warp-lang 1.0.1__py3-none-manylinux2014_aarch64.whl → 1.1.0__py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (346) hide show
  1. warp/__init__.py +108 -97
  2. warp/__init__.pyi +1 -1
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +115 -113
  6. warp/build_dll.py +383 -375
  7. warp/builtins.py +3425 -3354
  8. warp/codegen.py +2878 -2792
  9. warp/config.py +40 -36
  10. warp/constants.py +45 -45
  11. warp/context.py +5194 -5102
  12. warp/dlpack.py +442 -442
  13. warp/examples/__init__.py +16 -16
  14. warp/examples/assets/bear.usd +0 -0
  15. warp/examples/assets/bunny.usd +0 -0
  16. warp/examples/assets/cartpole.urdf +110 -110
  17. warp/examples/assets/crazyflie.usd +0 -0
  18. warp/examples/assets/cube.usd +0 -0
  19. warp/examples/assets/nv_ant.xml +92 -92
  20. warp/examples/assets/nv_humanoid.xml +183 -183
  21. warp/examples/assets/quadruped.urdf +267 -267
  22. warp/examples/assets/rocks.nvdb +0 -0
  23. warp/examples/assets/rocks.usd +0 -0
  24. warp/examples/assets/sphere.usd +0 -0
  25. warp/examples/benchmarks/benchmark_api.py +383 -383
  26. warp/examples/benchmarks/benchmark_cloth.py +278 -279
  27. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
  28. warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
  29. warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
  30. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
  31. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
  32. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
  33. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -146
  34. warp/examples/benchmarks/benchmark_launches.py +295 -295
  35. warp/examples/browse.py +29 -28
  36. warp/examples/core/example_dem.py +234 -221
  37. warp/examples/core/example_fluid.py +293 -267
  38. warp/examples/core/example_graph_capture.py +144 -129
  39. warp/examples/core/example_marching_cubes.py +188 -176
  40. warp/examples/core/example_mesh.py +174 -154
  41. warp/examples/core/example_mesh_intersect.py +205 -193
  42. warp/examples/core/example_nvdb.py +176 -169
  43. warp/examples/core/example_raycast.py +105 -89
  44. warp/examples/core/example_raymarch.py +199 -178
  45. warp/examples/core/example_render_opengl.py +185 -141
  46. warp/examples/core/example_sph.py +405 -389
  47. warp/examples/core/example_torch.py +222 -181
  48. warp/examples/core/example_wave.py +263 -249
  49. warp/examples/fem/bsr_utils.py +378 -380
  50. warp/examples/fem/example_apic_fluid.py +407 -391
  51. warp/examples/fem/example_convection_diffusion.py +182 -168
  52. warp/examples/fem/example_convection_diffusion_dg.py +219 -209
  53. warp/examples/fem/example_convection_diffusion_dg0.py +204 -194
  54. warp/examples/fem/example_deformed_geometry.py +177 -159
  55. warp/examples/fem/example_diffusion.py +201 -173
  56. warp/examples/fem/example_diffusion_3d.py +177 -152
  57. warp/examples/fem/example_diffusion_mgpu.py +221 -214
  58. warp/examples/fem/example_mixed_elasticity.py +244 -222
  59. warp/examples/fem/example_navier_stokes.py +259 -243
  60. warp/examples/fem/example_stokes.py +220 -192
  61. warp/examples/fem/example_stokes_transfer.py +265 -249
  62. warp/examples/fem/mesh_utils.py +133 -109
  63. warp/examples/fem/plot_utils.py +292 -287
  64. warp/examples/optim/example_bounce.py +260 -248
  65. warp/examples/optim/example_cloth_throw.py +222 -210
  66. warp/examples/optim/example_diffray.py +566 -535
  67. warp/examples/optim/example_drone.py +864 -835
  68. warp/examples/optim/example_inverse_kinematics.py +176 -169
  69. warp/examples/optim/example_inverse_kinematics_torch.py +185 -170
  70. warp/examples/optim/example_spring_cage.py +239 -234
  71. warp/examples/optim/example_trajectory.py +223 -201
  72. warp/examples/optim/example_walker.py +306 -292
  73. warp/examples/sim/example_cartpole.py +139 -128
  74. warp/examples/sim/example_cloth.py +196 -184
  75. warp/examples/sim/example_granular.py +124 -113
  76. warp/examples/sim/example_granular_collision_sdf.py +197 -185
  77. warp/examples/sim/example_jacobian_ik.py +236 -213
  78. warp/examples/sim/example_particle_chain.py +118 -106
  79. warp/examples/sim/example_quadruped.py +193 -179
  80. warp/examples/sim/example_rigid_chain.py +197 -189
  81. warp/examples/sim/example_rigid_contact.py +189 -176
  82. warp/examples/sim/example_rigid_force.py +127 -126
  83. warp/examples/sim/example_rigid_gyroscopic.py +109 -97
  84. warp/examples/sim/example_rigid_soft_contact.py +134 -124
  85. warp/examples/sim/example_soft_body.py +190 -178
  86. warp/fabric.py +337 -335
  87. warp/fem/__init__.py +60 -27
  88. warp/fem/cache.py +401 -388
  89. warp/fem/dirichlet.py +178 -179
  90. warp/fem/domain.py +262 -263
  91. warp/fem/field/__init__.py +100 -101
  92. warp/fem/field/field.py +148 -149
  93. warp/fem/field/nodal_field.py +298 -299
  94. warp/fem/field/restriction.py +22 -21
  95. warp/fem/field/test.py +180 -181
  96. warp/fem/field/trial.py +183 -183
  97. warp/fem/geometry/__init__.py +15 -19
  98. warp/fem/geometry/closest_point.py +69 -70
  99. warp/fem/geometry/deformed_geometry.py +270 -271
  100. warp/fem/geometry/element.py +744 -744
  101. warp/fem/geometry/geometry.py +184 -186
  102. warp/fem/geometry/grid_2d.py +380 -373
  103. warp/fem/geometry/grid_3d.py +441 -435
  104. warp/fem/geometry/hexmesh.py +953 -953
  105. warp/fem/geometry/partition.py +374 -376
  106. warp/fem/geometry/quadmesh_2d.py +532 -532
  107. warp/fem/geometry/tetmesh.py +840 -840
  108. warp/fem/geometry/trimesh_2d.py +577 -577
  109. warp/fem/integrate.py +1630 -1615
  110. warp/fem/operator.py +190 -191
  111. warp/fem/polynomial.py +214 -213
  112. warp/fem/quadrature/__init__.py +2 -2
  113. warp/fem/quadrature/pic_quadrature.py +243 -245
  114. warp/fem/quadrature/quadrature.py +295 -294
  115. warp/fem/space/__init__.py +294 -292
  116. warp/fem/space/basis_space.py +488 -489
  117. warp/fem/space/collocated_function_space.py +100 -105
  118. warp/fem/space/dof_mapper.py +236 -236
  119. warp/fem/space/function_space.py +148 -145
  120. warp/fem/space/grid_2d_function_space.py +267 -267
  121. warp/fem/space/grid_3d_function_space.py +305 -306
  122. warp/fem/space/hexmesh_function_space.py +350 -352
  123. warp/fem/space/partition.py +350 -350
  124. warp/fem/space/quadmesh_2d_function_space.py +368 -369
  125. warp/fem/space/restriction.py +158 -160
  126. warp/fem/space/shape/__init__.py +13 -15
  127. warp/fem/space/shape/cube_shape_function.py +738 -738
  128. warp/fem/space/shape/shape_function.py +102 -103
  129. warp/fem/space/shape/square_shape_function.py +611 -611
  130. warp/fem/space/shape/tet_shape_function.py +565 -567
  131. warp/fem/space/shape/triangle_shape_function.py +429 -429
  132. warp/fem/space/tetmesh_function_space.py +294 -292
  133. warp/fem/space/topology.py +297 -295
  134. warp/fem/space/trimesh_2d_function_space.py +223 -221
  135. warp/fem/types.py +77 -77
  136. warp/fem/utils.py +495 -495
  137. warp/jax.py +166 -141
  138. warp/jax_experimental.py +341 -339
  139. warp/native/array.h +1072 -1025
  140. warp/native/builtin.h +1560 -1560
  141. warp/native/bvh.cpp +398 -398
  142. warp/native/bvh.cu +525 -525
  143. warp/native/bvh.h +429 -429
  144. warp/native/clang/clang.cpp +495 -464
  145. warp/native/crt.cpp +31 -31
  146. warp/native/crt.h +334 -334
  147. warp/native/cuda_crt.h +1049 -1049
  148. warp/native/cuda_util.cpp +549 -540
  149. warp/native/cuda_util.h +288 -203
  150. warp/native/cutlass_gemm.cpp +34 -34
  151. warp/native/cutlass_gemm.cu +372 -372
  152. warp/native/error.cpp +66 -66
  153. warp/native/error.h +27 -27
  154. warp/native/fabric.h +228 -228
  155. warp/native/hashgrid.cpp +301 -278
  156. warp/native/hashgrid.cu +78 -77
  157. warp/native/hashgrid.h +227 -227
  158. warp/native/initializer_array.h +32 -32
  159. warp/native/intersect.h +1204 -1204
  160. warp/native/intersect_adj.h +365 -365
  161. warp/native/intersect_tri.h +322 -322
  162. warp/native/marching.cpp +2 -2
  163. warp/native/marching.cu +497 -497
  164. warp/native/marching.h +2 -2
  165. warp/native/mat.h +1498 -1498
  166. warp/native/matnn.h +333 -333
  167. warp/native/mesh.cpp +203 -203
  168. warp/native/mesh.cu +293 -293
  169. warp/native/mesh.h +1887 -1887
  170. warp/native/nanovdb/NanoVDB.h +4782 -4782
  171. warp/native/nanovdb/PNanoVDB.h +2553 -2553
  172. warp/native/nanovdb/PNanoVDBWrite.h +294 -294
  173. warp/native/noise.h +850 -850
  174. warp/native/quat.h +1084 -1084
  175. warp/native/rand.h +299 -299
  176. warp/native/range.h +108 -108
  177. warp/native/reduce.cpp +156 -156
  178. warp/native/reduce.cu +348 -348
  179. warp/native/runlength_encode.cpp +61 -61
  180. warp/native/runlength_encode.cu +46 -46
  181. warp/native/scan.cpp +30 -30
  182. warp/native/scan.cu +36 -36
  183. warp/native/scan.h +7 -7
  184. warp/native/solid_angle.h +442 -442
  185. warp/native/sort.cpp +94 -94
  186. warp/native/sort.cu +97 -97
  187. warp/native/sort.h +14 -14
  188. warp/native/sparse.cpp +337 -337
  189. warp/native/sparse.cu +544 -544
  190. warp/native/spatial.h +630 -630
  191. warp/native/svd.h +562 -562
  192. warp/native/temp_buffer.h +30 -30
  193. warp/native/vec.h +1132 -1132
  194. warp/native/volume.cpp +297 -297
  195. warp/native/volume.cu +32 -32
  196. warp/native/volume.h +538 -538
  197. warp/native/volume_builder.cu +425 -425
  198. warp/native/volume_builder.h +19 -19
  199. warp/native/warp.cpp +1057 -1052
  200. warp/native/warp.cu +2943 -2828
  201. warp/native/warp.h +313 -305
  202. warp/optim/__init__.py +9 -9
  203. warp/optim/adam.py +120 -120
  204. warp/optim/linear.py +1104 -939
  205. warp/optim/sgd.py +104 -92
  206. warp/render/__init__.py +10 -10
  207. warp/render/render_opengl.py +3217 -3204
  208. warp/render/render_usd.py +768 -749
  209. warp/render/utils.py +152 -150
  210. warp/sim/__init__.py +52 -59
  211. warp/sim/articulation.py +685 -685
  212. warp/sim/collide.py +1594 -1590
  213. warp/sim/import_mjcf.py +489 -481
  214. warp/sim/import_snu.py +220 -221
  215. warp/sim/import_urdf.py +536 -516
  216. warp/sim/import_usd.py +887 -881
  217. warp/sim/inertia.py +316 -317
  218. warp/sim/integrator.py +234 -233
  219. warp/sim/integrator_euler.py +1956 -1956
  220. warp/sim/integrator_featherstone.py +1910 -1991
  221. warp/sim/integrator_xpbd.py +3294 -3312
  222. warp/sim/model.py +4473 -4314
  223. warp/sim/particles.py +113 -112
  224. warp/sim/render.py +417 -403
  225. warp/sim/utils.py +413 -410
  226. warp/sparse.py +1227 -1227
  227. warp/stubs.py +2109 -2469
  228. warp/tape.py +1162 -225
  229. warp/tests/__init__.py +1 -1
  230. warp/tests/__main__.py +4 -4
  231. warp/tests/assets/torus.usda +105 -105
  232. warp/tests/aux_test_class_kernel.py +26 -26
  233. warp/tests/aux_test_compile_consts_dummy.py +10 -10
  234. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
  235. warp/tests/aux_test_dependent.py +22 -22
  236. warp/tests/aux_test_grad_customs.py +23 -23
  237. warp/tests/aux_test_reference.py +11 -11
  238. warp/tests/aux_test_reference_reference.py +10 -10
  239. warp/tests/aux_test_square.py +17 -17
  240. warp/tests/aux_test_unresolved_func.py +14 -14
  241. warp/tests/aux_test_unresolved_symbol.py +14 -14
  242. warp/tests/disabled_kinematics.py +239 -239
  243. warp/tests/run_coverage_serial.py +31 -31
  244. warp/tests/test_adam.py +157 -157
  245. warp/tests/test_arithmetic.py +1124 -1124
  246. warp/tests/test_array.py +2417 -2326
  247. warp/tests/test_array_reduce.py +150 -150
  248. warp/tests/test_async.py +668 -656
  249. warp/tests/test_atomic.py +141 -141
  250. warp/tests/test_bool.py +204 -149
  251. warp/tests/test_builtins_resolution.py +1292 -1292
  252. warp/tests/test_bvh.py +164 -171
  253. warp/tests/test_closest_point_edge_edge.py +228 -228
  254. warp/tests/test_codegen.py +566 -553
  255. warp/tests/test_compile_consts.py +97 -101
  256. warp/tests/test_conditional.py +246 -246
  257. warp/tests/test_copy.py +232 -215
  258. warp/tests/test_ctypes.py +632 -632
  259. warp/tests/test_dense.py +67 -67
  260. warp/tests/test_devices.py +91 -98
  261. warp/tests/test_dlpack.py +530 -529
  262. warp/tests/test_examples.py +400 -378
  263. warp/tests/test_fabricarray.py +955 -955
  264. warp/tests/test_fast_math.py +62 -54
  265. warp/tests/test_fem.py +1277 -1278
  266. warp/tests/test_fp16.py +130 -130
  267. warp/tests/test_func.py +338 -337
  268. warp/tests/test_generics.py +571 -571
  269. warp/tests/test_grad.py +746 -640
  270. warp/tests/test_grad_customs.py +333 -336
  271. warp/tests/test_hash_grid.py +210 -164
  272. warp/tests/test_import.py +39 -39
  273. warp/tests/test_indexedarray.py +1134 -1134
  274. warp/tests/test_intersect.py +67 -67
  275. warp/tests/test_jax.py +307 -307
  276. warp/tests/test_large.py +167 -164
  277. warp/tests/test_launch.py +354 -354
  278. warp/tests/test_lerp.py +261 -261
  279. warp/tests/test_linear_solvers.py +191 -171
  280. warp/tests/test_lvalue.py +421 -493
  281. warp/tests/test_marching_cubes.py +65 -65
  282. warp/tests/test_mat.py +1801 -1827
  283. warp/tests/test_mat_lite.py +115 -115
  284. warp/tests/test_mat_scalar_ops.py +2907 -2889
  285. warp/tests/test_math.py +126 -193
  286. warp/tests/test_matmul.py +500 -499
  287. warp/tests/test_matmul_lite.py +410 -410
  288. warp/tests/test_mempool.py +188 -190
  289. warp/tests/test_mesh.py +284 -324
  290. warp/tests/test_mesh_query_aabb.py +228 -241
  291. warp/tests/test_mesh_query_point.py +692 -702
  292. warp/tests/test_mesh_query_ray.py +292 -303
  293. warp/tests/test_mlp.py +276 -276
  294. warp/tests/test_model.py +110 -110
  295. warp/tests/test_modules_lite.py +39 -39
  296. warp/tests/test_multigpu.py +163 -163
  297. warp/tests/test_noise.py +248 -248
  298. warp/tests/test_operators.py +250 -250
  299. warp/tests/test_options.py +123 -125
  300. warp/tests/test_peer.py +133 -137
  301. warp/tests/test_pinned.py +78 -78
  302. warp/tests/test_print.py +54 -54
  303. warp/tests/test_quat.py +2086 -2086
  304. warp/tests/test_rand.py +288 -288
  305. warp/tests/test_reload.py +217 -217
  306. warp/tests/test_rounding.py +179 -179
  307. warp/tests/test_runlength_encode.py +190 -190
  308. warp/tests/test_sim_grad.py +243 -0
  309. warp/tests/test_sim_kinematics.py +91 -97
  310. warp/tests/test_smoothstep.py +168 -168
  311. warp/tests/test_snippet.py +305 -266
  312. warp/tests/test_sparse.py +468 -460
  313. warp/tests/test_spatial.py +2148 -2148
  314. warp/tests/test_streams.py +486 -473
  315. warp/tests/test_struct.py +710 -675
  316. warp/tests/test_tape.py +173 -148
  317. warp/tests/test_torch.py +743 -743
  318. warp/tests/test_transient_module.py +87 -87
  319. warp/tests/test_types.py +556 -659
  320. warp/tests/test_utils.py +490 -499
  321. warp/tests/test_vec.py +1264 -1268
  322. warp/tests/test_vec_lite.py +73 -73
  323. warp/tests/test_vec_scalar_ops.py +2099 -2099
  324. warp/tests/test_verify_fp.py +94 -94
  325. warp/tests/test_volume.py +737 -736
  326. warp/tests/test_volume_write.py +255 -265
  327. warp/tests/unittest_serial.py +37 -37
  328. warp/tests/unittest_suites.py +363 -359
  329. warp/tests/unittest_utils.py +603 -578
  330. warp/tests/unused_test_misc.py +71 -71
  331. warp/tests/walkthrough_debug.py +85 -85
  332. warp/thirdparty/appdirs.py +598 -598
  333. warp/thirdparty/dlpack.py +143 -143
  334. warp/thirdparty/unittest_parallel.py +566 -561
  335. warp/torch.py +321 -295
  336. warp/types.py +4504 -4450
  337. warp/utils.py +1008 -821
  338. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/LICENSE.md +126 -126
  339. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/METADATA +338 -400
  340. warp_lang-1.1.0.dist-info/RECORD +352 -0
  341. warp/examples/assets/cube.usda +0 -42
  342. warp/examples/assets/sphere.usda +0 -56
  343. warp/examples/assets/torus.usda +0 -105
  344. warp_lang-1.0.1.dist-info/RECORD +0 -352
  345. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/WHEEL +0 -0
  346. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py CHANGED
@@ -1,1615 +1,1630 @@
1
- from typing import List, Dict, Set, Optional, Any, Union
2
-
3
- import warp as wp
4
-
5
- import re
6
- import ast
7
-
8
- from warp.sparse import BsrMatrix, bsr_zeros, bsr_set_from_triplets, bsr_copy, bsr_assign
9
- from warp.types import type_length
10
- from warp.utils import array_cast
11
- from warp.codegen import get_annotations
12
-
13
- from warp.fem.domain import GeometryDomain
14
- from warp.fem.field import (
15
- TestField,
16
- TrialField,
17
- FieldLike,
18
- DiscreteField,
19
- FieldRestriction,
20
- make_restriction,
21
- )
22
- from warp.fem.quadrature import Quadrature, RegularQuadrature
23
- from warp.fem.operator import Operator, Integrand
24
- from warp.fem import cache
25
- from warp.fem.types import Domain, Field, Sample, DofIndex, NULL_DOF_INDEX, OUTSIDE, make_free_sample
26
-
27
-
28
- def _resolve_path(func, node):
29
- """
30
- Resolves variable and path from ast node/attribute (adapted from warp.codegen)
31
- """
32
-
33
- modules = []
34
-
35
- while isinstance(node, ast.Attribute):
36
- modules.append(node.attr)
37
- node = node.value
38
-
39
- if isinstance(node, ast.Name):
40
- modules.append(node.id)
41
-
42
- # reverse list since ast presents it backward order
43
- path = [*reversed(modules)]
44
-
45
- if len(path) == 0:
46
- return None, path
47
-
48
- # try and evaluate object path
49
- try:
50
- # Look up the closure info and append it to adj.func.__globals__
51
- # in case you want to define a kernel inside a function and refer
52
- # to variables you've declared inside that function:
53
- capturedvars = dict(
54
- zip(
55
- func.__code__.co_freevars,
56
- [c.cell_contents for c in (func.__closure__ or [])],
57
- )
58
- )
59
-
60
- vars_dict = {**func.__globals__, **capturedvars}
61
- func = eval(".".join(path), vars_dict)
62
- return func, path
63
- except (NameError, AttributeError):
64
- pass
65
-
66
- return None, path
67
-
68
-
69
- def _path_to_ast_attribute(name: str) -> ast.Attribute:
70
- path = name.split(".")
71
- path.reverse()
72
-
73
- node = ast.Name(id=path.pop(), ctx=ast.Load())
74
- while len(path):
75
- node = ast.Attribute(
76
- value=node,
77
- attr=path.pop(),
78
- ctx=ast.Load(),
79
- )
80
- return node
81
-
82
-
83
- class IntegrandTransformer(ast.NodeTransformer):
84
- def __init__(self, integrand: Integrand, field_args: Dict[str, FieldLike]):
85
- self._integrand = integrand
86
- self._field_args = field_args
87
-
88
- def visit_Call(self, call: ast.Call):
89
- call = self.generic_visit(call)
90
-
91
- callee = getattr(call.func, "id", None)
92
- if callee in self._field_args:
93
- # Shortcut for evaluating fields as f(x...)
94
- field = self._field_args[callee]
95
-
96
- arg_type = self._integrand.argspec.annotations[callee]
97
- operator = arg_type.call_operator
98
-
99
- call.func = ast.Attribute(
100
- value=_path_to_ast_attribute(f"{arg_type.__module__}.{arg_type.__qualname__}"),
101
- attr="call_operator",
102
- ctx=ast.Load(),
103
- )
104
- call.args = [ast.Name(id=callee, ctx=ast.Load())] + call.args
105
-
106
- self._replace_call_func(call, operator, field)
107
-
108
- return call
109
-
110
- func, _ = _resolve_path(self._integrand.func, call.func)
111
-
112
- if isinstance(func, Operator) and len(call.args) > 0:
113
- # Evaluating operators as op(field, x, ...)
114
- callee = getattr(call.args[0], "id", None)
115
- if callee in self._field_args:
116
- field = self._field_args[callee]
117
- self._replace_call_func(call, func, field)
118
-
119
- if isinstance(func, Integrand):
120
- key = self._translate_callee(func, call.args)
121
- call.func = ast.Attribute(
122
- value=call.func,
123
- attr=key,
124
- ctx=ast.Load(),
125
- )
126
-
127
- # print(ast.dump(call, indent=4))
128
-
129
- return call
130
-
131
- def _replace_call_func(self, call: ast.Call, operator: Operator, field: FieldLike):
132
- try:
133
- pointer = operator.resolver(field)
134
- setattr(operator, pointer.key, pointer)
135
- except AttributeError:
136
- raise ValueError(f"Operator {operator.func.__name__} is not defined for field {field.name}")
137
- call.func = ast.Attribute(value=call.func, attr=pointer.key, ctx=ast.Load())
138
-
139
- def _translate_callee(self, callee: Integrand, args: List[ast.AST]):
140
- # Get field types for call site arguments
141
- call_site_field_args = []
142
- for arg in args:
143
- name = getattr(arg, "id", None)
144
- if name in self._field_args:
145
- call_site_field_args.append(self._field_args[name])
146
-
147
- call_site_field_args.reverse()
148
-
149
- # Pass to callee in same order
150
- callee_field_args = {}
151
- for arg in callee.argspec.args:
152
- arg_type = callee.argspec.annotations[arg]
153
- if arg_type in (Field, Domain):
154
- callee_field_args[arg] = call_site_field_args.pop()
155
-
156
- return _translate_integrand(callee, callee_field_args).key
157
-
158
-
159
- def _translate_integrand(integrand: Integrand, field_args: Dict[str, FieldLike]) -> wp.Function:
160
- # Specialize field argument types
161
- argspec = integrand.argspec
162
- annotations = {}
163
- for arg in argspec.args:
164
- arg_type = argspec.annotations[arg]
165
- if arg_type == Field:
166
- annotations[arg] = field_args[arg].ElementEvalArg
167
- elif arg_type == Domain:
168
- annotations[arg] = field_args[arg].ElementArg
169
- else:
170
- annotations[arg] = arg_type
171
-
172
- # Transform field evaluation calls
173
- transformer = IntegrandTransformer(integrand, field_args)
174
-
175
- suffix = "_".join([f.name for f in field_args.values()])
176
-
177
- func = cache.get_integrand_function(
178
- integrand=integrand,
179
- suffix=suffix,
180
- annotations=annotations,
181
- code_transformers=[transformer],
182
- )
183
-
184
- key = func.key
185
- setattr(integrand, key, integrand.module.functions[key])
186
-
187
- return getattr(integrand, key)
188
-
189
-
190
- def _get_integrand_field_arguments(
191
- integrand: Integrand,
192
- fields: Dict[str, FieldLike],
193
- domain: GeometryDomain = None,
194
- ):
195
- # parse argument types
196
- field_args = {}
197
- value_args = {}
198
-
199
- domain_name = None
200
- sample_name = None
201
-
202
- argspec = integrand.argspec
203
- for arg in argspec.args:
204
- arg_type = argspec.annotations[arg]
205
- if arg_type == Field:
206
- if arg not in fields:
207
- raise ValueError(f"Missing field for argument '{arg}'")
208
- field_args[arg] = fields[arg]
209
- elif arg_type == Domain:
210
- domain_name = arg
211
- field_args[arg] = domain
212
- elif arg_type == Sample:
213
- sample_name = arg
214
- else:
215
- value_args[arg] = arg_type
216
-
217
- return field_args, value_args, domain_name, sample_name
218
-
219
-
220
- def _get_test_and_trial_fields(
221
- fields: Dict[str, FieldLike],
222
- ):
223
- test = None
224
- trial = None
225
- test_name = None
226
- trial_name = None
227
-
228
- for name, field in fields.items():
229
- if isinstance(field, TestField):
230
- if test is not None:
231
- raise ValueError("Duplicate test field argument")
232
- test = field
233
- test_name = name
234
- elif isinstance(field, TrialField):
235
- if trial is not None:
236
- raise ValueError("Duplicate test field argument")
237
- trial = field
238
- trial_name = name
239
-
240
- if trial is not None:
241
- if test is None:
242
- raise ValueError("A trial field cannot be provided without a test field")
243
-
244
- if test.domain != trial.domain:
245
- raise ValueError("Incompatible test and trial domains")
246
-
247
- return test, test_name, trial, trial_name
248
-
249
-
250
- def _gen_field_struct(field_args: Dict[str, FieldLike]):
251
- class Fields:
252
- pass
253
-
254
- annotations = get_annotations(Fields)
255
-
256
- for name, arg in field_args.items():
257
- if isinstance(arg, GeometryDomain):
258
- continue
259
- setattr(Fields, name, arg.EvalArg())
260
- annotations[name] = arg.EvalArg
261
-
262
- try:
263
- Fields.__annotations__ = annotations
264
- except AttributeError:
265
- setattr(Fields.__dict__, "__annotations__", annotations)
266
-
267
- suffix = "_".join([f"{name}_{arg_struct.cls.__qualname__}" for name, arg_struct in annotations.items()])
268
-
269
- return cache.get_struct(Fields, suffix=suffix)
270
-
271
-
272
- def _gen_value_struct(value_args: Dict[str, type]):
273
- class Values:
274
- pass
275
-
276
- annotations = get_annotations(Values)
277
-
278
- for name, arg_type in value_args.items():
279
- setattr(Values, name, None)
280
- annotations[name] = arg_type
281
-
282
- def arg_type_name(arg_type):
283
- if isinstance(arg_type, wp.codegen.Struct):
284
- return arg_type_name(arg_type.cls)
285
- return getattr(arg_type, "__name__", str(arg_type))
286
-
287
- def arg_type_name(arg_type):
288
- if isinstance(arg_type, wp.codegen.Struct):
289
- return arg_type_name(arg_type.cls)
290
- return getattr(arg_type, "__name__", str(arg_type))
291
-
292
- try:
293
- Values.__annotations__ = annotations
294
- except AttributeError:
295
- setattr(Values.__dict__, "__annotations__", annotations)
296
-
297
- suffix = "_".join([f"{name}_{arg_type_name(arg_type)}" for name, arg_type in annotations.items()])
298
-
299
- return cache.get_struct(Values, suffix=suffix)
300
-
301
-
302
- def _get_trial_arg():
303
- pass
304
-
305
-
306
- def _get_test_arg():
307
- pass
308
-
309
-
310
- class _FieldWrappers:
311
- pass
312
-
313
-
314
- def _register_integrand_field_wrappers(integrand_func: wp.Function, fields: Dict[str, FieldLike]):
315
- integrand_func._field_wrappers = _FieldWrappers()
316
- for name, field in fields.items():
317
- setattr(integrand_func._field_wrappers, name, field.ElementEvalArg)
318
-
319
-
320
- class PassFieldArgsToIntegrand(ast.NodeTransformer):
321
- def __init__(
322
- self,
323
- arg_names: List[str],
324
- field_args: Set[str],
325
- value_args: Set[str],
326
- sample_name: str,
327
- domain_name: str,
328
- test_name: str = None,
329
- trial_name: str = None,
330
- func_name: str = "integrand_func",
331
- fields_var_name: str = "fields",
332
- values_var_name: str = "values",
333
- domain_var_name: str = "domain_arg",
334
- sample_var_name: str = "sample",
335
- field_wrappers_attr: str = "_field_wrappers",
336
- ):
337
- self._arg_names = arg_names
338
- self._field_args = field_args
339
- self._value_args = value_args
340
- self._domain_name = domain_name
341
- self._sample_name = sample_name
342
- self._func_name = func_name
343
- self._test_name = test_name
344
- self._trial_name = trial_name
345
- self._fields_var_name = fields_var_name
346
- self._values_var_name = values_var_name
347
- self._domain_var_name = domain_var_name
348
- self._sample_var_name = sample_var_name
349
- self._field_wrappers_attr = field_wrappers_attr
350
-
351
- def visit_Call(self, call: ast.Call):
352
- call = self.generic_visit(call)
353
-
354
- callee = getattr(call.func, "id", None)
355
-
356
- if callee == self._func_name:
357
- # Replace function arguments with ours generated structs
358
- call.args.clear()
359
- for arg in self._arg_names:
360
- if arg == self._domain_name:
361
- call.args.append(
362
- ast.Name(id=self._domain_var_name, ctx=ast.Load()),
363
- )
364
- elif arg == self._sample_name:
365
- call.args.append(
366
- ast.Name(id=self._sample_var_name, ctx=ast.Load()),
367
- )
368
- elif arg in self._field_args:
369
- call.args.append(
370
- ast.Call(
371
- func=ast.Attribute(
372
- value=ast.Attribute(
373
- value=ast.Name(id=self._func_name, ctx=ast.Load()),
374
- attr=self._field_wrappers_attr,
375
- ctx=ast.Load(),
376
- ),
377
- attr=arg,
378
- ctx=ast.Load(),
379
- ),
380
- args=[
381
- ast.Name(id=self._domain_var_name, ctx=ast.Load()),
382
- ast.Attribute(
383
- value=ast.Name(id=self._fields_var_name, ctx=ast.Load()),
384
- attr=arg,
385
- ctx=ast.Load(),
386
- ),
387
- ],
388
- keywords=[],
389
- )
390
- )
391
- elif arg in self._value_args:
392
- call.args.append(
393
- ast.Attribute(
394
- value=ast.Name(id=self._values_var_name, ctx=ast.Load()),
395
- attr=arg,
396
- ctx=ast.Load(),
397
- )
398
- )
399
- else:
400
- raise RuntimeError(f"Unhandled argument {arg}")
401
- # print(ast.dump(call, indent=4))
402
- elif callee == _get_test_arg.__name__:
403
- # print(ast.dump(call, indent=4))
404
- call = ast.Attribute(
405
- value=ast.Name(id=self._fields_var_name, ctx=ast.Load()),
406
- attr=self._test_name,
407
- ctx=ast.Load(),
408
- )
409
- elif callee == _get_trial_arg.__name__:
410
- # print(ast.dump(call, indent=4))
411
- call = ast.Attribute(
412
- value=ast.Name(id=self._fields_var_name, ctx=ast.Load()),
413
- attr=self._trial_name,
414
- ctx=ast.Load(),
415
- )
416
-
417
- return call
418
-
419
-
420
- def get_integrate_constant_kernel(
421
- integrand_func: wp.Function,
422
- domain: GeometryDomain,
423
- quadrature: Quadrature,
424
- FieldStruct: wp.codegen.Struct,
425
- ValueStruct: wp.codegen.Struct,
426
- accumulate_dtype,
427
- ):
428
- def integrate_kernel_fn(
429
- qp_arg: quadrature.Arg,
430
- domain_arg: domain.ElementArg,
431
- domain_index_arg: domain.ElementIndexArg,
432
- fields: FieldStruct,
433
- values: ValueStruct,
434
- result: wp.array(dtype=accumulate_dtype),
435
- ):
436
- element_index = domain.element_index(domain_index_arg, wp.tid())
437
- elem_sum = accumulate_dtype(0.0)
438
-
439
- test_dof_index = NULL_DOF_INDEX
440
- trial_dof_index = NULL_DOF_INDEX
441
-
442
- qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
443
- for k in range(qp_point_count):
444
- qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
445
- coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
446
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
447
-
448
- sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
449
- vol = domain.element_measure(domain_arg, sample)
450
-
451
- val = integrand_func(sample, fields, values)
452
-
453
- elem_sum += accumulate_dtype(qp_weight * vol * val)
454
-
455
- wp.atomic_add(result, 0, elem_sum)
456
-
457
- return integrate_kernel_fn
458
-
459
-
460
def get_integrate_linear_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """Build the kernel function integrating a linear form with a quadrature rule.

    The returned function is launched on a 2D grid: the first thread index is a
    local node index of the test space restriction, the second is a test degree
    of freedom. Each thread sums ``qp_weight * vol * val`` over all quadrature
    points of all elements reported for its node, and writes the total to
    ``result[node_index, test_dof]``.

    Args:
        integrand_func: Function evaluated at each quadrature sample.
        domain: Integration domain; provides element indices and measures.
        quadrature: Provides point count, index, coordinates and weight per element.
        FieldStruct: Struct type of the ``fields`` kernel argument.
        ValueStruct: Struct type of the ``values`` kernel argument.
        test: Test field whose space restriction enumerates nodes and elements.
        output_dtype: Scalar type of the 2D ``result`` array.
        accumulate_dtype: Scalar type used for the running sum.

    Returns:
        The (not yet compiled) kernel function ``integrate_kernel_fn``.
    """

    # NOTE(review): identifiers such as `sample`, `fields`, `values` and `domain_arg` are
    # presumably matched by name when the kernel source is transformed -- confirm before renaming.
    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array2d(dtype=output_dtype),
    ):
        local_node_index, test_dof = wp.tid()
        node_index = test.space_restriction.node_partition_index(test_arg, local_node_index)
        element_count = test.space_restriction.node_element_count(test_arg, local_node_index)

        # Linear form: there is no trial function
        trial_dof_index = NULL_DOF_INDEX

        val_sum = accumulate_dtype(0.0)

        # Loop over the elements associated with this node
        for n in range(element_count):
            node_element_index = test.space_restriction.node_element_index(test_arg, local_node_index, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)

            qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
            for k in range(qp_point_count):
                qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
                qp_coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
                qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)

                # The element measure is evaluated on a sample built without quadrature/dof information
                vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))

                sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(qp_weight * vol * val)

        # Convert from the accumulation type to the output type only once, at the end
        result[node_index, test_dof] = output_dtype(val_sum)

    return integrate_kernel_fn
509
-
510
-
511
def get_integrate_linear_nodal_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """Build the kernel function integrating a linear form at the test space nodes.

    No quadrature object is used: the sample points are the node coordinates
    returned by ``test.space.node_coords_in_element`` and the weights come from
    ``test.space.node_quadrature_weight``. The kernel is launched on a 2D grid
    (local node index, degree of freedom) and writes one entry of ``result`` per thread.

    Args:
        integrand_func: Function evaluated at each nodal sample.
        domain: Integration domain; provides element indices and measures.
        FieldStruct: Struct type of the ``fields`` kernel argument.
        ValueStruct: Struct type of the ``values`` kernel argument.
        test: Test field providing the space and its restriction.
        output_dtype: Scalar type of the 2D ``result`` array.
        accumulate_dtype: Scalar type used for the running sum.

    Returns:
        The (not yet compiled) kernel function ``integrate_kernel_fn``.
    """

    def integrate_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_restriction_arg: test.space_restriction.NodeArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array2d(dtype=output_dtype),
    ):
        local_node_index, dof = wp.tid()

        node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
        element_count = test.space_restriction.node_element_count(test_restriction_arg, local_node_index)

        # Linear form: there is no trial function
        trial_dof_index = NULL_DOF_INDEX

        val_sum = accumulate_dtype(0.0)

        for n in range(element_count):
            node_element_index = test.space_restriction.node_element_index(test_restriction_arg, local_node_index, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            # NOTE(review): `_get_test_arg` is not defined in this function; presumably a placeholder
            # replaced when the kernel source is transformed -- confirm against the code transformer.
            coords = test.space.node_coords_in_element(
                domain_arg,
                _get_test_arg(),
                element_index,
                node_element_index.node_index_in_element,
            )

            # Contributions are skipped when the first coordinate is flagged as OUTSIDE
            if coords[0] != OUTSIDE:
                node_weight = test.space.node_quadrature_weight(
                    domain_arg,
                    _get_test_arg(),
                    element_index,
                    node_element_index.node_index_in_element,
                )

                test_dof_index = DofIndex(node_element_index.node_index_in_element, dof)

                # The node index and node weight stand in for the quadrature point index and weight
                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(node_weight * vol * val)

        result[node_index, dof] = output_dtype(val_sum)

    return integrate_kernel_fn
574
-
575
-
576
def get_integrate_bilinear_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    trial: TrialField,
    output_dtype,
    accumulate_dtype,
):
    """Build the kernel function integrating a bilinear form with a quadrature rule.

    The kernel is launched on a 4D grid (local test node, trial node in element,
    test dof, trial dof). For every element reported for the test node, it sums
    ``qp_weight * vol * val`` over the element's quadrature points and stores the
    result as one entry of a block in ``triplet_values``; the matching row and
    column indices are written to ``triplet_rows`` / ``triplet_cols``.

    Args:
        integrand_func: Function evaluated at each quadrature sample.
        domain: Integration domain; provides element indices and measures.
        quadrature: Provides point count, index, coordinates and weight per element.
        FieldStruct: Struct type of the ``fields`` kernel argument.
        ValueStruct: Struct type of the ``values`` kernel argument.
        test: Test field whose space restriction enumerates nodes and elements.
        trial: Trial field providing the topology and partition used for column indices.
        output_dtype: Scalar type of ``triplet_values``.
        accumulate_dtype: Scalar type used for the running sum.

    Returns:
        The (not yet compiled) kernel function ``integrate_kernel_fn``.
    """
    # Captured as a constant so that it can be used inside the kernel body
    NODES_PER_ELEMENT = trial.space.topology.NODES_PER_ELEMENT

    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        trial_partition_arg: trial.space_partition.PartitionArg,
        trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
        fields: FieldStruct,
        values: ValueStruct,
        row_offsets: wp.array(dtype=int),
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=output_dtype),
    ):
        test_local_node_index, trial_node, test_dof, trial_dof = wp.tid()

        element_count = test.space_restriction.node_element_count(test_arg, test_local_node_index)
        test_node_index = test.space_restriction.node_partition_index(test_arg, test_local_node_index)

        trial_dof_index = DofIndex(trial_node, trial_dof)

        for element in range(element_count):
            test_element_index = test.space_restriction.node_element_index(test_arg, test_local_node_index, element)
            element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
            qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)

            test_dof_index = DofIndex(
                test_element_index.node_index_in_element,
                test_dof,
            )

            # The sum is reset for each element: every (test node, element, trial node) gets its own block
            val_sum = accumulate_dtype(0.0)

            for k in range(qp_point_count):
                qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
                coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)

                qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
                vol = domain.element_measure(domain_arg, make_free_sample(element_index, coords))

                sample = Sample(
                    element_index,
                    coords,
                    qp_index,
                    qp_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                val = integrand_func(sample, fields, values)
                val_sum += accumulate_dtype(qp_weight * vol * val)

            # Blocks are laid out per test node (starting at row_offsets), then per element, then per trial node
            block_offset = (row_offsets[test_node_index] + element) * NODES_PER_ELEMENT + trial_node
            triplet_values[block_offset, test_dof, trial_dof] = output_dtype(val_sum)

            # Set row and column indices
            # (only the thread handling the first dof pair writes them, once per block)
            if test_dof == 0 and trial_dof == 0:
                trial_node_index = trial.space_partition.partition_node_index(
                    trial_partition_arg,
                    trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
                )
                triplet_rows[block_offset] = test_node_index
                triplet_cols[block_offset] = trial_node_index

    return integrate_kernel_fn
653
-
654
-
655
def get_integrate_bilinear_nodal_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """Build the kernel function integrating a bilinear form at the test space nodes.

    The kernel is launched on a 3D grid (local node, test dof, trial dof). Test and
    trial dof indices are both built from the same node, and the row and column
    written for each block are both ``node_index``, so only diagonal blocks are produced.

    Args:
        integrand_func: Function evaluated at each nodal sample.
        domain: Integration domain; provides element indices and measures.
        FieldStruct: Struct type of the ``fields`` kernel argument.
        ValueStruct: Struct type of the ``values`` kernel argument.
        test: Test field providing the space and its restriction.
        output_dtype: Scalar type of ``triplet_values``.
        accumulate_dtype: Scalar type used for the running sum.

    Returns:
        The (not yet compiled) kernel function ``integrate_kernel_fn``.
    """

    def integrate_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_restriction_arg: test.space_restriction.NodeArg,
        fields: FieldStruct,
        values: ValueStruct,
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=output_dtype),
    ):
        local_node_index, test_dof, trial_dof = wp.tid()

        element_count = test.space_restriction.node_element_count(test_restriction_arg, local_node_index)
        node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)

        val_sum = accumulate_dtype(0.0)

        for n in range(element_count):
            node_element_index = test.space_restriction.node_element_index(test_restriction_arg, local_node_index, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            # NOTE(review): `_get_test_arg` is not defined in this function; presumably a placeholder
            # replaced when the kernel source is transformed -- confirm against the code transformer.
            coords = test.space.node_coords_in_element(
                domain_arg,
                _get_test_arg(),
                element_index,
                node_element_index.node_index_in_element,
            )

            # Contributions are skipped when the first coordinate is flagged as OUTSIDE
            if coords[0] != OUTSIDE:
                node_weight = test.space.node_quadrature_weight(
                    domain_arg,
                    _get_test_arg(),
                    element_index,
                    node_element_index.node_index_in_element,
                )

                # Test and trial functions are both attached to the same node of the element
                test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)
                trial_dof_index = DofIndex(node_element_index.node_index_in_element, trial_dof)

                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(node_weight * vol * val)

        # One block per local node, with identical row and column index
        triplet_values[local_node_index, test_dof, trial_dof] = output_dtype(val_sum)
        triplet_rows[local_node_index] = node_index
        triplet_cols[local_node_index] = node_index

    return integrate_kernel_fn
721
-
722
-
723
def _generate_integrate_kernel(
    integrand: Integrand,
    domain: GeometryDomain,
    nodal: bool,
    quadrature: Quadrature,
    test: Optional[TestField],
    test_name: str,
    trial: Optional[TrialField],
    trial_name: str,
    fields: Dict[str, FieldLike],
    output_dtype: type,
    accumulate_dtype: type,
    kernel_options: Optional[Dict[str, Any]] = None,
) -> tuple:
    """Fetch from the cache, or generate, the integration kernel for an integrand.

    The kind of kernel (constant, linear or bilinear form; nodal or quadrature-based)
    is selected from the presence of ``test`` / ``trial`` and from the ``nodal`` flag.

    Args:
        integrand: Form to be integrated.
        domain: Integration domain.
        nodal: Whether to generate a nodal-integration kernel (linear and bilinear forms only).
        quadrature: Quadrature formula; only used when ``nodal`` is false or for constant forms.
        test: Test field, or None for a constant form.
        test_name: Name of the integrand argument bound to the test field.
        trial: Trial field, or None for a constant or linear form.
        trial_name: Name of the integrand argument bound to the trial field.
        fields: Fields passed to the integrand, keyed by integrand argument name.
        output_dtype: Output type; only its scalar type is used.
        accumulate_dtype: Scalar type used for accumulating integration samples.
        kernel_options: Options forwarded to the kernel builder. Defaults to no options.

    Returns:
        A ``(kernel, FieldStruct, ValueStruct)`` tuple: the kernel and the struct
        types that must be instantiated to pass fields and values at launch time.
    """
    # Avoid sharing a mutable default dictionary across calls
    if kernel_options is None:
        kernel_options = {}

    output_dtype = wp.types.type_scalar_type(output_dtype)

    # Extract field arguments from integrand
    field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
        integrand, fields=fields, domain=domain
    )

    FieldStruct = _gen_field_struct(field_args)
    ValueStruct = _gen_value_struct(value_args)

    # Check if kernel exist in cache
    kernel_suffix = f"_itg_{wp.types.type_typestr(output_dtype)}{wp.types.type_typestr(accumulate_dtype)}_{domain.name}_{FieldStruct.key}"
    if nodal:
        kernel_suffix += "_nodal"
    else:
        kernel_suffix += quadrature.name

    if test:
        kernel_suffix += f"_test_{test.space_partition.name}_{test.space.name}"
    if trial:
        kernel_suffix += f"_trial_{trial.space_partition.name}_{trial.space.name}"

    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        suffix=kernel_suffix,
    )
    if kernel is not None:
        return kernel, FieldStruct, ValueStruct

    # Not found in cache, transform integrand and generate kernel

    integrand_func = _translate_integrand(
        integrand,
        field_args,
    )

    _register_integrand_field_wrappers(integrand_func, fields)

    if test is None and trial is None:
        # Constant form
        integrate_kernel_fn = get_integrate_constant_kernel(
            integrand_func,
            domain,
            quadrature,
            FieldStruct,
            ValueStruct,
            accumulate_dtype=accumulate_dtype,
        )
    elif trial is None:
        # Linear form
        if nodal:
            integrate_kernel_fn = get_integrate_linear_nodal_kernel(
                integrand_func,
                domain,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
        else:
            integrate_kernel_fn = get_integrate_linear_kernel(
                integrand_func,
                domain,
                quadrature,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
    else:
        # Bilinear form
        if nodal:
            integrate_kernel_fn = get_integrate_bilinear_nodal_kernel(
                integrand_func,
                domain,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
        else:
            integrate_kernel_fn = get_integrate_bilinear_kernel(
                integrand_func,
                domain,
                quadrature,
                FieldStruct,
                ValueStruct,
                test=test,
                trial=trial,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )

    # The code transformer rewrites the integrand call inside the generated kernel source
    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        kernel_fn=integrate_kernel_fn,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
        code_transformers=[
            PassFieldArgsToIntegrand(
                arg_names=integrand.argspec.args,
                field_args=field_args.keys(),
                value_args=value_args.keys(),
                sample_name=sample_name,
                domain_name=domain_name,
                test_name=test_name,
                trial_name=trial_name,
            )
        ],
    )

    return kernel, FieldStruct, ValueStruct
849
-
850
-
851
def _launch_integrate_kernel(
    kernel: wp.Kernel,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    domain: GeometryDomain,
    nodal: bool,
    quadrature: Quadrature,
    test: Optional[TestField],
    trial: Optional[TrialField],
    fields: Dict[str, FieldLike],
    values: Dict[str, Any],
    accumulate_dtype: type,
    temporary_store: Optional[cache.TemporaryStore],
    output_dtype: type,
    output: Optional[Union[wp.array, BsrMatrix]],
    device,
):
    """Launch an integration kernel and return the integration result.

    Args:
        kernel: Integration kernel matching the form kind and the ``nodal`` flag.
        FieldStruct: Struct type instantiated to pass the fields to the kernel.
        ValueStruct: Struct type instantiated to pass the values to the kernel.
        domain: Integration domain.
        nodal: Whether ``kernel`` is a nodal-integration kernel.
        quadrature: Quadrature formula, may be None for nodal integration.
        test: Test field, or None for a constant form.
        trial: Trial field, or None for a constant or linear form.
        fields: Fields passed to the integrand, keyed by integrand argument name.
        values: Extra values passed to the integrand, keyed by integrand argument name.
        accumulate_dtype: Scalar type used for accumulating integration samples.
        temporary_store: Pool from which temporary arrays are borrowed.
        output_dtype: Type of the result when ``output`` is not provided.
        output: Optional array (constant and linear forms) or sparse matrix
            (bilinear form) into which to store the result.
        device: Device on which to launch the kernel.

    Returns:
        For a constant form, ``output`` if provided, otherwise the scalar result.
        For a linear form, ``output`` if provided, otherwise a newly borrowed array.
        For a bilinear form, ``output`` if provided, otherwise a new block-sparse matrix.

    Raises:
        RuntimeError: If ``output`` has an incompatible size, shape or type.
    """
    # Set-up launch arguments
    domain_elt_arg = domain.element_arg_value(device=device)
    domain_elt_index_arg = domain.element_index_arg_value(device=device)

    if quadrature is not None:
        qp_arg = quadrature.arg_value(device=device)

    field_arg_values = FieldStruct()
    for k, v in fields.items():
        setattr(field_arg_values, k, v.eval_arg_value(device=device))

    value_struct_values = ValueStruct()
    for k, v in values.items():
        setattr(value_struct_values, k, v)

    # Constant form
    if test is None and trial is None:
        if output is not None and output.dtype == accumulate_dtype:
            # Output already has the accumulation type, accumulate directly into it
            if output.size < 1:
                raise RuntimeError("Output array must be of size at least 1")
            accumulate_array = output
        else:
            accumulate_temporary = cache.borrow_temporary(
                shape=(1,),
                device=device,
                dtype=accumulate_dtype,
                temporary_store=temporary_store,
                requires_grad=output is not None and output.requires_grad,
            )
            accumulate_array = accumulate_temporary.array

        accumulate_array.zero_()
        wp.launch(
            kernel=kernel,
            dim=domain.element_count(),
            inputs=[
                qp_arg,
                domain_elt_arg,
                domain_elt_index_arg,
                field_arg_values,
                value_struct_values,
                accumulate_array,
            ],
            device=device,
        )

        # Identity test: we want to know whether we accumulated directly into `output`
        if output is accumulate_array:
            return output
        if output is None:
            return accumulate_array.numpy()[0]

        array_cast(in_array=accumulate_array, out_array=output)
        return output

    test_arg = test.space_restriction.node_arg(device=device)

    # Linear form
    if trial is None:
        # If an output array is provided with the correct type, accumulate directly into it
        # Otherwise, grab a temporary array
        if output is None:
            if type_length(output_dtype) == test.space.VALUE_DOF_COUNT:
                output_shape = (test.space_partition.node_count(),)
            elif type_length(output_dtype) == 1:
                output_shape = (test.space_partition.node_count(), test.space.VALUE_DOF_COUNT)
            else:
                raise RuntimeError(
                    f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
                )

            output_temporary = cache.borrow_temporary(
                temporary_store=temporary_store,
                shape=output_shape,
                dtype=output_dtype,
                device=device,
            )

            output = output_temporary.array

        else:
            output_temporary = None

            if output.shape[0] < test.space_partition.node_count():
                raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")

            output_dtype = output.dtype
            if type_length(output_dtype) != test.space.VALUE_DOF_COUNT:
                if type_length(output_dtype) != 1:
                    raise RuntimeError(
                        f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
                    )
                # A scalar array must be two-dimensional with one column per degree of freedom.
                # `or` (rather than `and`) so that 1d arrays are rejected before indexing shape[1],
                # and 2d arrays with a wrong number of columns are rejected as well.
                if output.ndim != 2 or output.shape[1] != test.space.VALUE_DOF_COUNT:
                    raise RuntimeError(
                        f"Incompatible output array shape, last dimension must be of size {test.space.VALUE_DOF_COUNT}"
                    )

        # Launch the integration on the kernel on a 2d scalar view of the actual array
        output.zero_()

        def as_2d_array(array):
            # Reinterpret the memory of `array` (and recursively of its gradient) as a 2d scalar array
            return wp.array(
                data=None,
                ptr=array.ptr,
                capacity=array.capacity,
                device=array.device,
                shape=(test.space_partition.node_count(), test.space.VALUE_DOF_COUNT),
                dtype=wp.types.type_scalar_type(output_dtype),
                grad=None if array.grad is None else as_2d_array(array.grad),
            )

        output_view = output if output.ndim == 2 else as_2d_array(output)

        if nodal:
            wp.launch(
                kernel=kernel,
                dim=(test.space_restriction.node_count(), test.space.VALUE_DOF_COUNT),
                inputs=[
                    domain_elt_arg,
                    domain_elt_index_arg,
                    test_arg,
                    field_arg_values,
                    value_struct_values,
                    output_view,
                ],
                device=device,
            )
        else:
            wp.launch(
                kernel=kernel,
                dim=(test.space_restriction.node_count(), test.space.VALUE_DOF_COUNT),
                inputs=[
                    qp_arg,
                    domain_elt_arg,
                    domain_elt_index_arg,
                    test_arg,
                    field_arg_values,
                    value_struct_values,
                    output_view,
                ],
                device=device,
            )

        if output_temporary is not None:
            return output_temporary.detach()

        return output

    # Bilinear form

    if test.space.VALUE_DOF_COUNT == 1 and trial.space.VALUE_DOF_COUNT == 1:
        block_type = output_dtype
    else:
        block_type = cache.cached_mat_type(
            shape=(test.space.VALUE_DOF_COUNT, trial.space.VALUE_DOF_COUNT), dtype=output_dtype
        )

    # Number of blocks written by the kernel
    if nodal:
        nnz = test.space_restriction.node_count()
    else:
        nnz = test.space_restriction.total_node_element_count() * trial.space.topology.NODES_PER_ELEMENT

    triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_values_temp = cache.borrow_temporary(
        temporary_store,
        shape=(
            nnz,
            test.space.VALUE_DOF_COUNT,
            trial.space.VALUE_DOF_COUNT,
        ),
        dtype=output_dtype,
        device=device,
    )
    triplet_cols = triplet_cols_temp.array
    triplet_rows = triplet_rows_temp.array
    triplet_values = triplet_values_temp.array

    triplet_values.zero_()

    if nodal:
        wp.launch(
            kernel=kernel,
            dim=triplet_values.shape,
            inputs=[
                domain_elt_arg,
                domain_elt_index_arg,
                test_arg,
                field_arg_values,
                value_struct_values,
                triplet_rows,
                triplet_cols,
                triplet_values,
            ],
            device=device,
        )

    else:
        offsets = test.space_restriction.partition_element_offsets()

        trial_partition_arg = trial.space_partition.partition_arg_value(device)
        trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
        wp.launch(
            kernel=kernel,
            dim=(
                test.space_restriction.node_count(),
                trial.space.topology.NODES_PER_ELEMENT,
                test.space.VALUE_DOF_COUNT,
                trial.space.VALUE_DOF_COUNT,
            ),
            inputs=[
                qp_arg,
                domain_elt_arg,
                domain_elt_index_arg,
                test_arg,
                trial_partition_arg,
                trial_topology_arg,
                field_arg_values,
                value_struct_values,
                offsets,
                triplet_rows,
                triplet_cols,
                triplet_values,
            ],
            device=device,
        )

    if output is not None:
        if output.nrow != test.space_partition.node_count() or output.ncol != trial.space_partition.node_count():
            raise RuntimeError(
                f"Output matrix must have {test.space_partition.node_count()} rows and {trial.space_partition.node_count()} columns of blocks"
            )

    else:
        output = bsr_zeros(
            rows_of_blocks=test.space_partition.node_count(),
            cols_of_blocks=trial.space_partition.node_count(),
            block_type=block_type,
            device=device,
        )

    bsr_set_from_triplets(output, triplet_rows, triplet_cols, triplet_values)

    # Do not wait for garbage collection
    triplet_values_temp.release()
    triplet_rows_temp.release()
    triplet_cols_temp.release()

    return output
1116
-
1117
-
1118
def integrate(
    integrand: Integrand,
    domain: Optional[GeometryDomain] = None,
    quadrature: Optional[Quadrature] = None,
    nodal: bool = False,
    fields: Optional[Dict[str, FieldLike]] = None,
    values: Optional[Dict[str, Any]] = None,
    accumulate_dtype: type = wp.float64,
    output_dtype: Optional[type] = None,
    output: Optional[Union[BsrMatrix, wp.array]] = None,
    device=None,
    temporary_store: Optional[cache.TemporaryStore] = None,
    kernel_options: Optional[Dict[str, Any]] = None,
):
    """
    Integrates a constant, linear or bilinear form, and returns a scalar, array, or sparse matrix, respectively.

    Args:
        integrand: Form to be integrated, must have :func:`integrand` decorator
        domain: Integration domain. If None, deduced from fields
        quadrature: Quadrature formula. If None, deduced from domain and fields degree.
        nodal: For linear or bilinear form only, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
        fields: Discrete, test, and trial fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names. Defaults to no fields.
        values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names. Defaults to no values.
        temporary_store: shared pool from which to allocate temporary arrays
        accumulate_dtype: Scalar type to be used for accumulating integration samples
        output: Sparse matrix or warp array into which to store the result of the integration
        output_dtype: Scalar type for returned results if `output` is not provided. If None, defaults to `accumulate_dtype`
        device: Device on which to perform the integration
        kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)

    Returns:
        The result of the integration; this is ``output`` when it is provided.

    Raises:
        ValueError: If the integrand is not decorated, if no domain can be deduced,
            or if the quadrature/nodal arguments are inconsistent.
        NotImplementedError: If the integration domain differs from the test field domain.
    """
    if not isinstance(integrand, Integrand):
        raise ValueError("integrand must be tagged with @warp.fem.integrand decorator")

    # Use fresh dictionaries rather than shared mutable default arguments
    if fields is None:
        fields = {}
    if values is None:
        values = {}
    if kernel_options is None:
        kernel_options = {}

    test, test_name, trial, trial_name = _get_test_and_trial_fields(fields)

    # Deduce the domain from the quadrature first, then from the test field
    if domain is None:
        if quadrature is not None:
            domain = quadrature.domain
        elif test is not None:
            domain = test.domain

    if domain is None:
        raise ValueError("Must provide at least one of domain, quadrature, or test field")
    if test is not None and domain != test.domain:
        raise NotImplementedError("Mixing integration and test domain is not supported yet")

    if nodal:
        if quadrature is not None:
            raise ValueError("Cannot specify quadrature for nodal integration")

        if test is None:
            raise ValueError("Nodal integration requires specifying a test function")

        if trial is not None and test.space_partition != trial.space_partition:
            raise ValueError(
                "Bilinear nodal integration requires test and trial to be defined on the same function space"
            )
    else:
        if quadrature is None:
            # Default quadrature order is the sum of the degrees of all fields
            order = sum(field.degree for field in fields.values())
            quadrature = RegularQuadrature(domain=domain, order=order)
        elif domain != quadrature.domain:
            raise ValueError("Incompatible integration and quadrature domain")

    # Canonicalize types
    accumulate_dtype = wp.types.type_to_warp(accumulate_dtype)
    if output is not None:
        # The type of a user-provided output takes precedence over `output_dtype`
        if isinstance(output, BsrMatrix):
            output_dtype = output.scalar_type
        else:
            output_dtype = output.dtype
    elif output_dtype is None:
        output_dtype = accumulate_dtype
    else:
        output_dtype = wp.types.type_to_warp(output_dtype)

    kernel, FieldStruct, ValueStruct = _generate_integrate_kernel(
        integrand=integrand,
        domain=domain,
        nodal=nodal,
        quadrature=quadrature,
        test=test,
        test_name=test_name,
        trial=trial,
        trial_name=trial_name,
        fields=fields,
        accumulate_dtype=accumulate_dtype,
        output_dtype=output_dtype,
        kernel_options=kernel_options,
    )

    return _launch_integrate_kernel(
        kernel=kernel,
        FieldStruct=FieldStruct,
        ValueStruct=ValueStruct,
        domain=domain,
        nodal=nodal,
        quadrature=quadrature,
        test=test,
        trial=trial,
        fields=fields,
        values=values,
        accumulate_dtype=accumulate_dtype,
        temporary_store=temporary_store,
        output_dtype=output_dtype,
        output=output,
        device=device,
    )
1227
-
1228
-
1229
def get_interpolate_to_field_function(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    dest: FieldRestriction,
):
    """Build the function evaluating an integrand at one node of a destination field.

    The returned function loops over the elements reported for the node, evaluates
    the integrand at the node coordinates in each element, and accumulates both the
    element measures and the measure-weighted values.

    Args:
        integrand_func: Function evaluated at each nodal sample.
        domain: Domain providing element indices and measures.
        FieldStruct: Struct type of the ``fields`` argument.
        ValueStruct: Struct type of the ``values`` argument.
        dest: Restriction of the destination field; provides the space, its
            restriction, and the value type of the accumulated sum.

    Returns:
        A function returning the pair ``(val_sum, vol_sum)`` for a given local node.
    """
    value_type = dest.space.dtype

    def interpolate_to_field_fn(
        local_node_index: int,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
        element_count = dest.space_restriction.node_element_count(dest_node_arg, local_node_index)

        # Interpolation involves neither test nor trial functions; the sample weight is fixed to one
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX
        node_weight = 1.0

        # Volume-weighted average across elements
        # Superfluous if the interpolated function is continuous, but helpful for visualizing discontinuous spaces

        val_sum = value_type(0.0)
        vol_sum = float(0.0)

        for n in range(element_count):
            node_element_index = dest.space_restriction.node_element_index(dest_node_arg, local_node_index, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            coords = dest.space.node_coords_in_element(
                domain_arg,
                dest_eval_arg.space_arg,
                element_index,
                node_element_index.node_index_in_element,
            )

            # Contributions are skipped when the first coordinate is flagged as OUTSIDE
            if coords[0] != OUTSIDE:
                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                vol_sum += vol
                val_sum += vol * val

        # The division by vol_sum is left to the caller
        return val_sum, vol_sum

    return interpolate_to_field_fn
1289
-
1290
-
1291
def get_interpolate_to_field_kernel(
    interpolate_to_field_fn: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    dest: FieldRestriction,
):
    """Build the kernel function writing interpolated values to a destination field.

    Each thread handles one local node: it calls ``interpolate_to_field_fn`` to get
    the weighted sum of values and the sum of weights, and stores their ratio as
    the node value when the sum of weights is strictly positive.

    Args:
        interpolate_to_field_fn: Function returning ``(val_sum, vol_sum)`` for a local node.
        domain: Domain providing the types of the element arguments.
        FieldStruct: Struct type of the ``fields`` kernel argument.
        ValueStruct: Struct type of the ``values`` kernel argument.
        dest: Restriction of the destination field.

    Returns:
        The (not yet compiled) kernel function ``interpolate_to_field_kernel_fn``.
    """

    def interpolate_to_field_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        local_node_index = wp.tid()

        val_sum, vol_sum = interpolate_to_field_fn(
            local_node_index, domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, fields, values
        )

        # Nodes that received no contribution keep their current value (and division by zero is avoided)
        if vol_sum > 0.0:
            node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
            dest.field.set_node_value(dest_eval_arg, node_index, val_sum / vol_sum)

    return interpolate_to_field_kernel_fn
1317
-
1318
-
1319
def get_interpolate_to_array_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    value_type: type,
):
    """Build the kernel function storing integrand values at quadrature points in an array.

    Each thread handles one domain element and writes the integrand value of
    every quadrature point of that element to ``result[qp_index]``.

    Args:
        integrand_func: Function evaluated at each quadrature sample.
        domain: Domain providing the element index of each thread.
        quadrature: Provides point count, index, coordinates and weight per element.
        FieldStruct: Struct type of the ``fields`` kernel argument.
        ValueStruct: Struct type of the ``values`` kernel argument.
        value_type: Data type of the ``result`` array.

    Returns:
        The (not yet compiled) kernel function ``interpolate_to_array_kernel_fn``.
    """

    # NOTE(review): argument types are taken from `quadrature.domain` while the element index
    # is computed with `domain`; presumably both refer to the same domain -- confirm at call sites.
    def interpolate_to_array_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: quadrature.domain.ElementArg,
        domain_index_arg: quadrature.domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=value_type),
    ):
        element_index = domain.element_index(domain_index_arg, wp.tid())

        # Interpolation involves neither test nor trial functions
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
        for k in range(qp_point_count):
            qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
            coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
            qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)

            sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)

            # The raw integrand value is stored; no weight or element measure is applied
            result[qp_index] = integrand_func(sample, fields, values)

    return interpolate_to_array_kernel_fn
1351
-
1352
-
1353
def get_interpolate_nonvalued_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
):
    """Build the kernel function evaluating an integrand at quadrature points, discarding its value.

    Each thread handles one domain element and calls the integrand at every
    quadrature point of that element; nothing is written by the kernel itself.

    Args:
        integrand_func: Function called at each quadrature sample.
        domain: Domain providing the element index of each thread.
        quadrature: Provides point count, index, coordinates and weight per element.
        FieldStruct: Struct type of the ``fields`` kernel argument.
        ValueStruct: Struct type of the ``values`` kernel argument.

    Returns:
        The (not yet compiled) kernel function ``interpolate_nonvalued_kernel_fn``.
    """

    # NOTE(review): argument types are taken from `quadrature.domain` while the element index
    # is computed with `domain`; presumably both refer to the same domain -- confirm at call sites.
    def interpolate_nonvalued_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: quadrature.domain.ElementArg,
        domain_index_arg: quadrature.domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        element_index = domain.element_index(domain_index_arg, wp.tid())

        # Interpolation involves neither test nor trial functions
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
        for k in range(qp_point_count):
            qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
            coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
            qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)

            sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
            # The return value, if any, is ignored
            integrand_func(sample, fields, values)

    return interpolate_nonvalued_kernel_fn
1382
-
1383
-
1384
def _generate_interpolate_kernel(
    integrand: Integrand,
    domain: GeometryDomain,
    dest: Optional[Union[FieldLike, wp.array]],
    quadrature: Optional[Quadrature],
    fields: Dict[str, FieldLike],
    kernel_options: Optional[Dict[str, Any]] = None,
) -> tuple:
    """Fetch from the cache, or generate, the interpolation kernel for an integrand.

    The kind of kernel depends on ``dest``: a field restriction (values are written
    to the field nodes), a warp array (values are written at quadrature points),
    or anything else (the integrand is evaluated and its value discarded).

    Args:
        integrand: Function to be interpolated.
        domain: Interpolation domain.
        dest: Interpolation destination; see above.
        quadrature: Quadrature formula; used when ``dest`` is not a field restriction.
        fields: Fields passed to the integrand, keyed by integrand argument name.
        kernel_options: Options forwarded to the kernel builder. Defaults to no options.

    Returns:
        A ``(kernel, FieldStruct, ValueStruct)`` tuple: the kernel and the struct
        types that must be instantiated to pass fields and values at launch time.
    """
    # Avoid sharing a mutable default dictionary across calls
    if kernel_options is None:
        kernel_options = {}

    # Extract field arguments from integrand
    field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
        integrand, fields=fields, domain=domain
    )

    # Generate field struct
    integrand_func = _translate_integrand(
        integrand,
        field_args,
    )

    _register_integrand_field_wrappers(integrand_func, fields)

    FieldStruct = _gen_field_struct(field_args)
    ValueStruct = _gen_value_struct(value_args)

    # Check if kernel exist in cache
    if isinstance(dest, FieldRestriction):
        kernel_suffix = (
            f"_itp_{FieldStruct.key}_{dest.domain.name}_{dest.space_restriction.space_partition.name}_{dest.space.name}"
        )
    elif wp.types.is_array(dest):
        kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}_{wp.types.type_repr(dest.dtype)}"
    else:
        kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}"

    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        suffix=kernel_suffix,
    )
    if kernel is not None:
        return kernel, FieldStruct, ValueStruct

    # Generate interpolation kernel
    if isinstance(dest, FieldRestriction):
        # need to split into kernel + function for differentiability
        interpolate_fn = get_interpolate_to_field_function(
            integrand_func,
            domain,
            dest=dest,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )

        interpolate_fn = cache.get_integrand_function(
            integrand=integrand,
            func=interpolate_fn,
            suffix=kernel_suffix,
            code_transformers=[
                PassFieldArgsToIntegrand(
                    arg_names=integrand.argspec.args,
                    field_args=field_args.keys(),
                    value_args=value_args.keys(),
                    sample_name=sample_name,
                    domain_name=domain_name,
                )
            ],
        )

        interpolate_kernel_fn = get_interpolate_to_field_kernel(
            interpolate_fn,
            domain,
            dest=dest,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )
    elif wp.types.is_array(dest):
        interpolate_kernel_fn = get_interpolate_to_array_kernel(
            integrand_func,
            domain=domain,
            quadrature=quadrature,
            value_type=dest.dtype,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )
    else:
        interpolate_kernel_fn = get_interpolate_nonvalued_kernel(
            integrand_func,
            domain=domain,
            quadrature=quadrature,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )

    # The code transformer rewrites the integrand call inside the generated kernel source
    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        kernel_fn=interpolate_kernel_fn,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
        code_transformers=[
            PassFieldArgsToIntegrand(
                arg_names=integrand.argspec.args,
                field_args=field_args.keys(),
                value_args=value_args.keys(),
                sample_name=sample_name,
                domain_name=domain_name,
            )
        ],
    )

    return kernel, FieldStruct, ValueStruct
1493
-
1494
-
1495
def _launch_interpolate_kernel(
    kernel: wp.Kernel,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    domain: GeometryDomain,
    dest: Optional[Union[FieldRestriction, wp.array]],
    quadrature: Optional[Quadrature],
    fields: Dict[str, FieldLike],
    values: Dict[str, Any],
    device,
) -> None:
    """
    Fills the launch arguments and launches an interpolation kernel.

    Args:
        kernel: Kernel to launch
        FieldStruct: Struct type instantiated to hold the evaluation arguments of `fields`
        ValueStruct: Struct type instantiated to hold `values`
        domain: Domain providing the element arguments and, for non-field destinations, the launch dimension
        dest: Interpolation destination; a field restriction launches over its nodes,
            a warp array is appended to the kernel inputs, anything else launches without destination
        quadrature: Quadrature formula, read when `dest` is not a field restriction
        fields: Fields passed to the integrand, keyed by argument name
        values: Additional values passed to the integrand, keyed by argument name
        device: Device on which to launch the kernel
    """
    # Set-up launch arguments
    elt_arg = domain.element_arg_value(device=device)
    elt_index_arg = domain.element_index_arg_value(device=device)

    field_arg_values = FieldStruct()
    for k, v in fields.items():
        setattr(field_arg_values, k, v.eval_arg_value(device=device))

    value_struct_values = ValueStruct()
    for k, v in values.items():
        setattr(value_struct_values, k, v)

    if isinstance(dest, FieldRestriction):
        dest_node_arg = dest.space_restriction.node_arg(device=device)
        dest_eval_arg = dest.field.eval_arg_value(device=device)

        wp.launch(
            kernel=kernel,
            dim=dest.space_restriction.node_count(),
            inputs=[
                elt_arg,
                elt_index_arg,
                dest_node_arg,
                dest_eval_arg,
                field_arg_values,
                value_struct_values,
            ],
            device=device,
        )
        return

    # Quadrature-based launches share their inputs; only the array variant has an output argument
    qp_arg = quadrature.arg_value(device)
    inputs = [qp_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values]
    if wp.types.is_array(dest):
        inputs.append(dest)

    wp.launch(
        kernel=kernel,
        dim=domain.element_count(),
        inputs=inputs,
        device=device,
    )
1551
-
1552
-
1553
def interpolate(
    integrand: Integrand,
    dest: Optional[Union[DiscreteField, FieldRestriction, wp.array]] = None,
    quadrature: Optional[Quadrature] = None,
    fields: Optional[Dict[str, FieldLike]] = None,
    values: Optional[Dict[str, Any]] = None,
    device=None,
    kernel_options: Optional[Dict[str, Any]] = None,
):
    """
    Interpolates a function at a finite set of sample points and optionally assigns the result to a discrete field or a raw warp array.

    Args:
        integrand: Function to be interpolated, must have :func:`integrand` decorator
        dest: Where to store the interpolation result. Can be either

            - a :class:`DiscreteField`, or restriction of a discrete field to a domain (from :func:`make_restriction`). In this case, interpolation will be performed at each node.
            - a normal warp array. In this case, the `quadrature` argument defining the interpolation locations must be provided and the result of the `integrand` at each quadrature point will be assigned to the array.
            - ``None``. In this case, the `quadrature` argument must also be provided and the `integrand` function is responsible for dealing with the interpolation result.
        quadrature: Quadrature formula defining the interpolation samples if `dest` is not a discrete field or field restriction.
        fields: Discrete fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names. ``None`` is treated as an empty dictionary.
        values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names. ``None`` is treated as an empty dictionary.
        device: Device on which to perform the interpolation
        kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``). ``None`` is treated as an empty dictionary.

    Raises:
        ValueError: if `integrand` is not an :class:`Integrand`, if a test or trial field is passed,
            or if neither a field destination nor a quadrature formula is provided
    """
    # Fresh dictionaries per call rather than mutable default arguments shared across calls
    if fields is None:
        fields = {}
    if values is None:
        values = {}
    if kernel_options is None:
        kernel_options = {}

    if not isinstance(integrand, Integrand):
        raise ValueError("integrand must be tagged with @integrand decorator")

    test, _, trial, __ = _get_test_and_trial_fields(fields)
    if test is not None or trial is not None:
        raise ValueError("Test or Trial fields should not be used for interpolation")

    if isinstance(dest, DiscreteField):
        dest = make_restriction(dest)

    if isinstance(dest, FieldRestriction):
        domain = dest.domain
    else:
        if quadrature is None:
            raise ValueError("When not interpolating to a field, a quadrature formula must be provided")

        domain = quadrature.domain

    kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
        integrand=integrand,
        domain=domain,
        dest=dest,
        quadrature=quadrature,
        fields=fields,
        kernel_options=kernel_options,
    )

    return _launch_interpolate_kernel(
        kernel=kernel,
        FieldStruct=FieldStruct,
        ValueStruct=ValueStruct,
        domain=domain,
        dest=dest,
        quadrature=quadrature,
        fields=fields,
        values=values,
        device=device,
    )
1
+ import ast
2
+ from typing import Any, Dict, List, Optional, Set, Union
3
+
4
+ import warp as wp
5
+ from warp.codegen import get_annotations
6
+ from warp.fem import cache
7
+ from warp.fem.domain import GeometryDomain
8
+ from warp.fem.field import (
9
+ DiscreteField,
10
+ FieldLike,
11
+ FieldRestriction,
12
+ TestField,
13
+ TrialField,
14
+ make_restriction,
15
+ )
16
+ from warp.fem.operator import Integrand, Operator
17
+ from warp.fem.quadrature import Quadrature, RegularQuadrature
18
+ from warp.fem.types import NULL_DOF_INDEX, OUTSIDE, DofIndex, Domain, Field, Sample, make_free_sample
19
+ from warp.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
20
+ from warp.types import type_length
21
+ from warp.utils import array_cast
22
+
23
+
24
def _resolve_path(func, node):
    """
    Resolves the object designated by a name/attribute ast node (adapted from warp.codegen).

    Returns an ``(object, path)`` pair where ``path`` lists the dotted name components
    collected from the node. ``object`` is ``None`` when the path is empty or cannot be
    evaluated against the globals and closure variables of `func`.
    """

    # The ast nests attribute accesses outermost-first, so components are collected backward
    reversed_parts = []
    current = node
    while isinstance(current, ast.Attribute):
        reversed_parts.append(current.attr)
        current = current.value

    if isinstance(current, ast.Name):
        reversed_parts.append(current.id)

    path = reversed_parts[::-1]
    if not path:
        return None, path

    try:
        # Closure variables take precedence over globals, so that integrands defined
        # inside a function can refer to names declared in that function
        cells = func.__closure__ or []
        scope = dict(func.__globals__)
        scope.update(zip(func.__code__.co_freevars, [cell.cell_contents for cell in cells]))
        return eval(".".join(path), scope), path
    except (NameError, AttributeError):
        return None, path
58
+
59
+
60
def _path_to_ast_attribute(name: str) -> ast.Attribute:
    """
    Converts a dotted path such as ``"a.b.c"`` to the equivalent load-context ast node.

    A path without any dot yields a plain ``ast.Name``.
    """
    root, *attributes = name.split(".")

    node = ast.Name(id=root, ctx=ast.Load())
    for attribute in attributes:
        node = ast.Attribute(value=node, attr=attribute, ctx=ast.Load())
    return node
72
+
73
+
74
class IntegrandTransformer(ast.NodeTransformer):
    """
    Ast transformer specializing the calls found in an integrand body.

    Three kinds of calls are rewritten:
      - ``f(x, ...)`` where ``f`` is a field argument, routed through the ``call_operator``
        of the annotated argument type
      - ``op(f, x, ...)`` where ``op`` resolves to an :class:`Operator` and ``f`` is a field argument
      - calls to other :class:`Integrand` objects, which get translated for the call-site fields
    """

    def __init__(self, integrand: Integrand, field_args: Dict[str, FieldLike]):
        self._integrand = integrand
        self._field_args = field_args

    def visit_Call(self, call: ast.Call):
        call = self.generic_visit(call)

        name = getattr(call.func, "id", None)
        if name in self._field_args:
            # Shortcut for evaluating fields as f(x...)
            field = self._field_args[name]

            annotated_type = self._integrand.argspec.annotations[name]
            operator = annotated_type.call_operator

            type_path = f"{annotated_type.__module__}.{annotated_type.__qualname__}"
            call.func = ast.Attribute(
                value=_path_to_ast_attribute(type_path),
                attr="call_operator",
                ctx=ast.Load(),
            )
            # The field itself becomes the first argument of the operator call
            call.args = [ast.Name(id=name, ctx=ast.Load()), *call.args]

            self._replace_call_func(call, operator, field)
            return call

        resolved, _ = _resolve_path(self._integrand.func, call.func)

        if isinstance(resolved, Operator) and call.args:
            # Evaluating operators as op(field, x, ...)
            first_arg_name = getattr(call.args[0], "id", None)
            if first_arg_name in self._field_args:
                self._replace_call_func(call, resolved, self._field_args[first_arg_name])

        if isinstance(resolved, Integrand):
            # Point the call at the callee variant translated for the call-site fields
            variant_key = self._translate_callee(resolved, call.args)
            call.func = ast.Attribute(value=call.func, attr=variant_key, ctx=ast.Load())

        return call

    def _replace_call_func(self, call: ast.Call, operator: Operator, field: FieldLike):
        """Redirects `call` to the implementation of `operator` resolved for `field`."""
        try:
            implementation = operator.resolver(field)
            # Stored on the operator so that `<operator>.<key>` can be looked up by name
            setattr(operator, implementation.key, implementation)
        except AttributeError as err:
            raise ValueError(f"Operator {operator.func.__name__} is not defined for field {field.name}") from err
        call.func = ast.Attribute(value=call.func, attr=implementation.key, ctx=ast.Load())

    def _translate_callee(self, callee: Integrand, args: List[ast.AST]):
        """Translates `callee` for the fields passed at the call site and returns the resulting key."""
        # Fields passed at the call site, in call order
        passed_fields = [
            self._field_args[arg_name]
            for arg_name in (getattr(arg, "id", None) for arg in args)
            if arg_name in self._field_args
        ]

        # Hand them to the callee's Field/Domain parameters in that same order
        callee_field_args = {}
        for param in callee.argspec.args:
            if callee.argspec.annotations[param] in (Field, Domain):
                callee_field_args[param] = passed_fields.pop(0)

        return _translate_integrand(callee, callee_field_args).key
148
+
149
+
150
def _translate_integrand(integrand: Integrand, field_args: Dict[str, FieldLike]) -> wp.Function:
    """
    Specializes `integrand` for the given field arguments.

    `Field` and `Domain` annotated parameters are re-annotated with the ``ElementEvalArg`` and
    ``ElementArg`` attributes of the matching entry of `field_args`; other annotations are kept.
    The function obtained from the cache is also stored as an attribute of `integrand` under its key.
    """
    argspec = integrand.argspec

    def specialized_annotation(arg):
        # Specialize field argument types
        declared = argspec.annotations[arg]
        if declared == Field:
            return field_args[arg].ElementEvalArg
        if declared == Domain:
            return field_args[arg].ElementArg
        return declared

    annotations = {arg: specialized_annotation(arg) for arg in argspec.args}

    # Transform field evaluation calls; the suffix distinguishes variants by field names
    func = cache.get_integrand_function(
        integrand=integrand,
        suffix="_".join([field.name for field in field_args.values()]),
        annotations=annotations,
        code_transformers=[IntegrandTransformer(integrand, field_args)],
    )

    variant_key = func.key
    setattr(integrand, variant_key, integrand.module.functions[variant_key])

    return getattr(integrand, variant_key)
179
+
180
+
181
def _get_integrand_field_arguments(
    integrand: Integrand,
    fields: Dict[str, FieldLike],
    domain: GeometryDomain = None,
):
    """
    Sorts the integrand parameters according to their annotation.

    Returns a ``(field_args, value_args, domain_name, sample_name)`` tuple where `field_args`
    maps `Field` parameters to the matching entry of `fields` and the `Domain` parameter to
    `domain`, `value_args` maps every other non-`Sample` parameter to its annotation, and the
    two names are those of the `Domain` and `Sample` parameters (``None`` when absent).

    Raises:
        ValueError: if a `Field` parameter has no entry in `fields`
    """
    field_args = {}
    value_args = {}
    domain_name = sample_name = None

    signature = integrand.argspec
    for param in signature.args:
        declared = signature.annotations[param]

        if declared == Field:
            if param not in fields:
                raise ValueError(f"Missing field for argument '{param}'")
            field_args[param] = fields[param]
            continue

        if declared == Domain:
            domain_name = param
            field_args[param] = domain
            continue

        if declared == Sample:
            sample_name = param
            continue

        # Anything else is a plain value argument
        value_args[param] = declared

    return field_args, value_args, domain_name, sample_name
209
+
210
+
211
def _get_test_and_trial_fields(
    fields: Dict[str, FieldLike],
):
    """
    Looks up the test and trial fields among `fields`.

    Returns:
        A ``(test, test_name, trial, trial_name)`` tuple; entries are ``None`` when no field of
        the corresponding kind is present.

    Raises:
        ValueError: if there is more than one test field or more than one trial field, if a trial
            field is given without a test field, or if the test and trial domains differ
    """
    test = None
    trial = None
    test_name = None
    trial_name = None

    for name, field in fields.items():
        if isinstance(field, TestField):
            if test is not None:
                raise ValueError("Duplicate test field argument")
            test = field
            test_name = name
        elif isinstance(field, TrialField):
            if trial is not None:
                # Message used to (wrongly) mention the test field
                raise ValueError("Duplicate trial field argument")
            trial = field
            trial_name = name

    if trial is not None:
        if test is None:
            raise ValueError("A trial field cannot be provided without a test field")

        if test.domain != trial.domain:
            raise ValueError("Incompatible test and trial domains")

    return test, test_name, trial, trial_name
239
+
240
+
241
def _gen_field_struct(field_args: Dict[str, FieldLike]):
    """
    Generates the struct gathering the evaluation arguments of all fields.

    Every entry of `field_args` that is not a :class:`GeometryDomain` contributes one member,
    annotated with its ``EvalArg`` type. The struct is obtained through the cache, with a
    suffix built from the member names and types.
    """

    class Fields:
        pass

    members = get_annotations(Fields)

    for member_name, field in field_args.items():
        if not isinstance(field, GeometryDomain):
            setattr(Fields, member_name, field.EvalArg())
            members[member_name] = field.EvalArg

    try:
        Fields.__annotations__ = members
    except AttributeError:
        Fields.__dict__.__annotations__ = members

    member_tags = [f"{member_name}_{eval_arg.cls.__qualname__}" for member_name, eval_arg in members.items()]

    return cache.get_struct(Fields, suffix="_".join(member_tags))
261
+
262
+
263
def _gen_value_struct(value_args: Dict[str, type]):
    """
    Generates the struct gathering all non-field integrand arguments.

    Each entry of `value_args` contributes one member, initialized to ``None`` and annotated
    with the given type. The struct is obtained through the cache, with a suffix built from
    the member names and type names.
    """

    class Values:
        pass

    annotations = get_annotations(Values)

    for name, arg_type in value_args.items():
        setattr(Values, name, None)
        annotations[name] = arg_type

    # Defined once; the original carried two identical copies of this helper
    def arg_type_name(arg_type):
        # Warp structs are named after the class they wrap
        if isinstance(arg_type, wp.codegen.Struct):
            return arg_type_name(arg_type.cls)
        return getattr(arg_type, "__name__", str(arg_type))

    try:
        Values.__annotations__ = annotations
    except AttributeError:
        Values.__dict__.__annotations__ = annotations

    suffix = "_".join([f"{name}_{arg_type_name(arg_type)}" for name, arg_type in annotations.items()])

    return cache.get_struct(Values, suffix=suffix)
291
+
292
+
293
def _get_trial_arg():
    """Placeholder; calls to this name are replaced by :class:`PassFieldArgsToIntegrand`."""
    return None
295
+
296
+
297
def _get_test_arg():
    """Placeholder; calls to this name are replaced by :class:`PassFieldArgsToIntegrand`."""
    return None
299
+
300
+
301
class _FieldWrappers:
    """Empty attribute holder; one attribute per field is set by `_register_integrand_field_wrappers`."""
303
+
304
+
305
def _register_integrand_field_wrappers(integrand_func: wp.Function, fields: Dict[str, FieldLike]):
    """
    Attaches a fresh ``_field_wrappers`` holder to `integrand_func`, exposing the
    ``ElementEvalArg`` of each field under the field's argument name.
    """
    wrappers = _FieldWrappers()
    integrand_func._field_wrappers = wrappers
    for arg_name, field in fields.items():
        setattr(wrappers, arg_name, field.ElementEvalArg)
309
+
310
+
311
class PassFieldArgsToIntegrand(ast.NodeTransformer):
    """
    Ast transformer expanding the generic ``integrand_func(sample, fields, values)`` call.

    The call arguments are replaced with one expression per integrand parameter, and calls to
    the `_get_test_arg` / `_get_trial_arg` placeholders are replaced with the corresponding
    member of the fields struct.
    """

    def __init__(
        self,
        arg_names: List[str],
        field_args: Set[str],
        value_args: Set[str],
        sample_name: str,
        domain_name: str,
        test_name: str = None,
        trial_name: str = None,
        func_name: str = "integrand_func",
        fields_var_name: str = "fields",
        values_var_name: str = "values",
        domain_var_name: str = "domain_arg",
        sample_var_name: str = "sample",
        field_wrappers_attr: str = "_field_wrappers",
    ):
        # Integrand signature
        self._arg_names = arg_names
        self._field_args = field_args
        self._value_args = value_args
        self._domain_name = domain_name
        self._sample_name = sample_name
        self._test_name = test_name
        self._trial_name = trial_name

        # Names used in the generated kernel source
        self._func_name = func_name
        self._fields_var_name = fields_var_name
        self._values_var_name = values_var_name
        self._domain_var_name = domain_var_name
        self._sample_var_name = sample_var_name
        self._field_wrappers_attr = field_wrappers_attr

    @staticmethod
    def _load(name: str) -> ast.Name:
        """Builds a load-context reference to the variable `name`."""
        return ast.Name(id=name, ctx=ast.Load())

    def _member(self, var_name: str, attr: str) -> ast.Attribute:
        """Builds the expression ``<var_name>.<attr>``."""
        return ast.Attribute(value=self._load(var_name), attr=attr, ctx=ast.Load())

    def _wrapped_field(self, arg: str) -> ast.Call:
        """Builds ``<func>.<wrappers>.<arg>(<domain var>, <fields var>.<arg>)``."""
        wrapper = ast.Attribute(
            value=self._member(self._func_name, self._field_wrappers_attr),
            attr=arg,
            ctx=ast.Load(),
        )
        return ast.Call(
            func=wrapper,
            args=[self._load(self._domain_var_name), self._member(self._fields_var_name, arg)],
            keywords=[],
        )

    def _integrand_argument(self, arg: str) -> ast.AST:
        """Builds the expression passed to the integrand for its parameter `arg`."""
        if arg == self._domain_name:
            return self._load(self._domain_var_name)
        if arg == self._sample_name:
            return self._load(self._sample_var_name)
        if arg in self._field_args:
            return self._wrapped_field(arg)
        if arg in self._value_args:
            return self._member(self._values_var_name, arg)
        raise RuntimeError(f"Unhandled argument {arg}")

    def visit_Call(self, call: ast.Call):
        call = self.generic_visit(call)

        callee = getattr(call.func, "id", None)

        if callee == self._func_name:
            # Replace function arguments with our generated structs
            call.args[:] = [self._integrand_argument(arg) for arg in self._arg_names]
            return call

        if callee == _get_test_arg.__name__:
            return self._member(self._fields_var_name, self._test_name)

        if callee == _get_trial_arg.__name__:
            return self._member(self._fields_var_name, self._trial_name)

        return call
409
+
410
+
411
def get_integrate_constant_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    accumulate_dtype,
):
    """
    Returns the kernel function integrating an integrand without test nor trial field.

    Each thread accumulates the weighted integrand values over the quadrature points of one
    element, then atomically adds its sum to the first entry of ``result``.
    """

    # NOTE: `integrand_func`, `sample`, `fields`, `values` and `domain_arg` are the default names
    # matched by `PassFieldArgsToIntegrand` and must be kept as-is
    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=accumulate_dtype),
    ):
        element_index = domain.element_index(domain_index_arg, wp.tid())

        # No degree of freedom is associated to the samples
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        element_total = accumulate_dtype(0.0)

        point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
        for qp in range(point_count):
            qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, qp)
            qp_coords = quadrature.point_coords(domain_arg, qp_arg, element_index, qp)
            qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, qp)

            sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
            measure = domain.element_measure(domain_arg, sample)

            integrand_value = integrand_func(sample, fields, values)

            element_total += accumulate_dtype(qp_weight * measure * integrand_value)

        wp.atomic_add(result, 0, element_total)

    return integrate_kernel_fn
449
+
450
+
451
def get_integrate_linear_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """
    Returns the kernel function integrating a linear form (test field only) with a quadrature formula.

    Each thread handles one (node, test dof) pair: it loops over the elements listed for the node
    by the test space restriction, accumulates the weighted integrand values over their quadrature
    points, and writes the sum to ``result[node_index, test_dof]``.
    """

    # NOTE: `integrand_func`, `sample`, `fields`, `values` and `domain_arg` are the default names
    # matched by `PassFieldArgsToIntegrand` and must be kept as-is
    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array2d(dtype=output_dtype),
    ):
        local_node_index, test_dof = wp.tid()
        node_index = test.space_restriction.node_partition_index(test_arg, local_node_index)
        element_count = test.space_restriction.node_element_count(test_arg, local_node_index)

        # Linear form: there is no trial degree of freedom
        trial_dof_index = NULL_DOF_INDEX

        dof_total = accumulate_dtype(0.0)

        for neighbor in range(element_count):
            node_element_index = test.space_restriction.node_element_index(test_arg, local_node_index, neighbor)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)

            point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
            for qp in range(point_count):
                qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, qp)
                qp_coords = quadrature.point_coords(domain_arg, qp_arg, element_index, qp)
                qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, qp)

                measure = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))

                sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
                integrand_value = integrand_func(sample, fields, values)

                dof_total += accumulate_dtype(qp_weight * measure * integrand_value)

        result[node_index, test_dof] = output_dtype(dof_total)

    return integrate_kernel_fn
500
+
501
+
502
def get_integrate_linear_nodal_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """
    Returns the kernel function integrating a linear form (test field only) at the test space nodes.

    Each thread handles one (node, dof) pair: for every element listed for the node by the test
    space restriction, the integrand is evaluated at the node coordinates in that element, unless
    their first component equals ``OUTSIDE``, and weighted by the node quadrature weight. The sum
    is written to ``result[node_index, dof]``.
    """

    # NOTE: `integrand_func`, `sample`, `fields`, `values` and `domain_arg` are the default names
    # matched by `PassFieldArgsToIntegrand`; `_get_test_arg()` calls are also substituted by it
    def integrate_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_restriction_arg: test.space_restriction.NodeArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array2d(dtype=output_dtype),
    ):
        local_node_index, dof = wp.tid()

        node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
        element_count = test.space_restriction.node_element_count(test_restriction_arg, local_node_index)

        # Linear form: there is no trial degree of freedom
        trial_dof_index = NULL_DOF_INDEX

        dof_total = accumulate_dtype(0.0)

        for neighbor in range(element_count):
            node_element_index = test.space_restriction.node_element_index(
                test_restriction_arg, local_node_index, neighbor
            )
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            node_coords = test.space.node_coords_in_element(
                domain_arg,
                _get_test_arg(),
                element_index,
                node_element_index.node_index_in_element,
            )

            if node_coords[0] != OUTSIDE:
                node_weight = test.space.node_quadrature_weight(
                    domain_arg,
                    _get_test_arg(),
                    element_index,
                    node_element_index.node_index_in_element,
                )

                test_dof_index = DofIndex(node_element_index.node_index_in_element, dof)

                # The node plays the role of the quadrature point
                sample = Sample(
                    element_index,
                    node_coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                measure = domain.element_measure(domain_arg, sample)
                integrand_value = integrand_func(sample, fields, values)

                dof_total += accumulate_dtype(node_weight * measure * integrand_value)

        result[node_index, dof] = output_dtype(dof_total)

    return integrate_kernel_fn
565
+
566
+
567
def get_integrate_bilinear_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    trial: TrialField,
    output_dtype,
    accumulate_dtype,
):
    """
    Returns the kernel function integrating a bilinear form with a quadrature formula.

    Each thread handles one (test node, trial node in element, test dof, trial dof) tuple. For
    every element listed for the test node by the test space restriction, the weighted integrand
    values are accumulated over the quadrature points and stored in ``triplet_values``; the
    thread with both dofs equal to zero also writes the block row and column indices.
    """
    NODES_PER_ELEMENT = trial.space.topology.NODES_PER_ELEMENT

    # NOTE: `integrand_func`, `sample`, `fields`, `values` and `domain_arg` are the default names
    # matched by `PassFieldArgsToIntegrand` and must be kept as-is
    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        trial_partition_arg: trial.space_partition.PartitionArg,
        trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
        fields: FieldStruct,
        values: ValueStruct,
        row_offsets: wp.array(dtype=int),
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=output_dtype),
    ):
        test_local_node_index, trial_node, test_dof, trial_dof = wp.tid()

        element_count = test.space_restriction.node_element_count(test_arg, test_local_node_index)
        test_node_index = test.space_restriction.node_partition_index(test_arg, test_local_node_index)

        trial_dof_index = DofIndex(trial_node, trial_dof)

        for neighbor in range(element_count):
            test_element_index = test.space_restriction.node_element_index(
                test_arg, test_local_node_index, neighbor
            )
            element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
            point_count = quadrature.point_count(domain_arg, qp_arg, element_index)

            test_dof_index = DofIndex(
                test_element_index.node_index_in_element,
                test_dof,
            )

            block_total = accumulate_dtype(0.0)

            for qp in range(point_count):
                qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, qp)
                qp_coords = quadrature.point_coords(domain_arg, qp_arg, element_index, qp)

                qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, qp)
                measure = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))

                sample = Sample(
                    element_index,
                    qp_coords,
                    qp_index,
                    qp_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                integrand_value = integrand_func(sample, fields, values)
                block_total += accumulate_dtype(qp_weight * measure * integrand_value)

            # One block per (test node, neighbor element, trial node in element)
            block_offset = (row_offsets[test_node_index] + neighbor) * NODES_PER_ELEMENT + trial_node
            triplet_values[block_offset, test_dof, trial_dof] = output_dtype(block_total)

            # Set row and column indices
            if test_dof == 0 and trial_dof == 0:
                trial_node_index = trial.space_partition.partition_node_index(
                    trial_partition_arg,
                    trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
                )
                triplet_rows[block_offset] = test_node_index
                triplet_cols[block_offset] = trial_node_index

    return integrate_kernel_fn
644
+
645
+
646
def get_integrate_bilinear_nodal_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """
    Returns the kernel function integrating a bilinear form at the test space nodes.

    Each thread handles one (node, test dof, trial dof) tuple; test and trial dof indices refer
    to the same node in element. The integrand is evaluated at the node coordinates in every
    element listed for the node by the test space restriction, unless their first component
    equals ``OUTSIDE``. The sum goes to ``triplet_values[local_node_index, test_dof, trial_dof]``
    and the node index is written as both row and column of the block.
    """

    # NOTE: `integrand_func`, `sample`, `fields`, `values` and `domain_arg` are the default names
    # matched by `PassFieldArgsToIntegrand`; `_get_test_arg()` calls are also substituted by it
    def integrate_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_restriction_arg: test.space_restriction.NodeArg,
        fields: FieldStruct,
        values: ValueStruct,
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=output_dtype),
    ):
        local_node_index, test_dof, trial_dof = wp.tid()

        element_count = test.space_restriction.node_element_count(test_restriction_arg, local_node_index)
        node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)

        block_total = accumulate_dtype(0.0)

        for neighbor in range(element_count):
            node_element_index = test.space_restriction.node_element_index(
                test_restriction_arg, local_node_index, neighbor
            )
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            node_coords = test.space.node_coords_in_element(
                domain_arg,
                _get_test_arg(),
                element_index,
                node_element_index.node_index_in_element,
            )

            if node_coords[0] != OUTSIDE:
                node_weight = test.space.node_quadrature_weight(
                    domain_arg,
                    _get_test_arg(),
                    element_index,
                    node_element_index.node_index_in_element,
                )

                test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)
                trial_dof_index = DofIndex(node_element_index.node_index_in_element, trial_dof)

                # The node plays the role of the quadrature point
                sample = Sample(
                    element_index,
                    node_coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                measure = domain.element_measure(domain_arg, sample)
                integrand_value = integrand_func(sample, fields, values)

                block_total += accumulate_dtype(node_weight * measure * integrand_value)

        # Blocks are diagonal: row and column both refer to the node itself
        triplet_values[local_node_index, test_dof, trial_dof] = output_dtype(block_total)
        triplet_rows[local_node_index] = node_index
        triplet_cols[local_node_index] = node_index

    return integrate_kernel_fn
712
+
713
+
714
def _generate_integrate_kernel(
    integrand: Integrand,
    domain: GeometryDomain,
    nodal: bool,
    quadrature: Quadrature,
    test: Optional[TestField],
    test_name: str,
    trial: Optional[TrialField],
    trial_name: str,
    fields: Dict[str, FieldLike],
    output_dtype: type,
    accumulate_dtype: type,
    kernel_options: Optional[Dict[str, Any]] = None,
) -> tuple:
    """
    Builds (or fetches from the cache) the kernel integrating `integrand`.

    Args:
        integrand: Integrand to be integrated
        domain: Integration domain
        nodal: Whether to generate a nodal kernel rather than a quadrature-based one
            (only affects the linear and bilinear cases)
        quadrature: Quadrature formula; its name is read for the cache key when `nodal` is false
        test: Test field, or ``None``
        test_name: Name of the integrand argument holding the test field
        trial: Trial field, or ``None``
        trial_name: Name of the integrand argument holding the trial field
        fields: Fields to be passed to the integrand, keyed by integrand argument name
        output_dtype: Output type; reduced to its scalar type
        accumulate_dtype: Type used for accumulating the integrand values
        kernel_options: Options forwarded to the kernel cache; ``None`` means no options

    Returns:
        A ``(kernel, FieldStruct, ValueStruct)`` tuple
    """
    if kernel_options is None:
        kernel_options = {}

    output_dtype = wp.types.type_scalar_type(output_dtype)

    # Extract field arguments from integrand
    field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
        integrand, fields=fields, domain=domain
    )

    FieldStruct = _gen_field_struct(field_args)
    ValueStruct = _gen_value_struct(value_args)

    # Check if kernel exists in cache
    kernel_suffix = f"_itg_{wp.types.type_typestr(output_dtype)}{wp.types.type_typestr(accumulate_dtype)}_{domain.name}_{FieldStruct.key}"
    if nodal:
        kernel_suffix += "_nodal"
    else:
        kernel_suffix += quadrature.name

    if test:
        kernel_suffix += f"_test_{test.space_partition.name}_{test.space.name}"
    if trial:
        kernel_suffix += f"_trial_{trial.space_partition.name}_{trial.space.name}"

    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        suffix=kernel_suffix,
    )
    if kernel is not None:
        return kernel, FieldStruct, ValueStruct

    # Not found in cache, transform integrand and generate kernel

    integrand_func = _translate_integrand(
        integrand,
        field_args,
    )

    _register_integrand_field_wrappers(integrand_func, fields)

    if test is None and trial is None:
        # Constant form
        integrate_kernel_fn = get_integrate_constant_kernel(
            integrand_func,
            domain,
            quadrature,
            FieldStruct,
            ValueStruct,
            accumulate_dtype=accumulate_dtype,
        )
    elif trial is None:
        # Linear form
        if nodal:
            integrate_kernel_fn = get_integrate_linear_nodal_kernel(
                integrand_func,
                domain,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
        else:
            integrate_kernel_fn = get_integrate_linear_kernel(
                integrand_func,
                domain,
                quadrature,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
    else:
        # Bilinear form
        if nodal:
            integrate_kernel_fn = get_integrate_bilinear_nodal_kernel(
                integrand_func,
                domain,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
        else:
            integrate_kernel_fn = get_integrate_bilinear_kernel(
                integrand_func,
                domain,
                quadrature,
                FieldStruct,
                ValueStruct,
                test=test,
                trial=trial,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )

    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        kernel_fn=integrate_kernel_fn,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
        code_transformers=[
            PassFieldArgsToIntegrand(
                arg_names=integrand.argspec.args,
                field_args=field_args.keys(),
                value_args=value_args.keys(),
                sample_name=sample_name,
                domain_name=domain_name,
                test_name=test_name,
                trial_name=trial_name,
            )
        ],
    )

    return kernel, FieldStruct, ValueStruct
843
+
844
+
845
def _launch_integrate_kernel(
    kernel: wp.Kernel,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    domain: GeometryDomain,
    nodal: bool,
    quadrature: Quadrature,
    test: Optional[TestField],
    trial: Optional[TrialField],
    fields: Dict[str, FieldLike],
    values: Dict[str, Any],
    accumulate_dtype: type,
    temporary_store: Optional[cache.TemporaryStore],
    output_dtype: type,
    output: Optional[Union[wp.array, BsrMatrix]],
    device,
):
    """Launch an integration kernel and return the integration result.

    Args:
        kernel: Integration kernel to launch
        FieldStruct: Struct type holding the field arguments of the kernel
        ValueStruct: Struct type holding the value arguments of the kernel
        domain: Integration domain
        nodal: Whether `kernel` is a nodal integration kernel (takes no quadrature argument)
        quadrature: Quadrature formula; may be None, in which case no quadrature argument is built
        test: Test field, or None
        trial: Trial field, or None
        fields: Fields passed to the integrand, keyed by argument name
        values: Additional values passed to the integrand, keyed by argument name
        accumulate_dtype: Scalar type used for accumulating integration samples
        temporary_store: Shared pool from which to allocate temporary arrays
        output_dtype: Scalar type of the result when `output` is not provided
        output: Array or sparse matrix into which to store the result, or None
        device: Device on which to launch the kernel

    Returns:
        - Constant form (no test, no trial): `output` if provided, otherwise the accumulated scalar.
        - Linear form (test only): `output` if provided, otherwise a newly borrowed array.
        - Bilinear form: `output` if provided, otherwise a new block-sparse matrix.

    Raises:
        RuntimeError: If `output` has an incompatible size, dtype or shape.
    """
    # Set-up launch arguments
    domain_elt_arg = domain.element_arg_value(device=device)
    domain_elt_index_arg = domain.element_index_arg_value(device=device)

    if quadrature is not None:
        qp_arg = quadrature.arg_value(device=device)

    field_arg_values = FieldStruct()
    for k, v in fields.items():
        setattr(field_arg_values, k, v.eval_arg_value(device=device))

    value_struct_values = ValueStruct()
    for k, v in values.items():
        setattr(value_struct_values, k, v)

    # Constant form
    if test is None and trial is None:
        # Accumulate directly into the output when it already has the accumulation type
        if output is not None and output.dtype == accumulate_dtype:
            if output.size < 1:
                raise RuntimeError("Output array must be of size at least 1")
            accumulate_array = output
        else:
            accumulate_temporary = cache.borrow_temporary(
                shape=(1,),
                device=device,
                dtype=accumulate_dtype,
                temporary_store=temporary_store,
                requires_grad=output is not None and output.requires_grad,
            )
            accumulate_array = accumulate_temporary.array

        accumulate_array.zero_()
        wp.launch(
            kernel=kernel,
            dim=domain.element_count(),
            inputs=[
                qp_arg,
                domain_elt_arg,
                domain_elt_index_arg,
                field_arg_values,
                value_struct_values,
                accumulate_array,
            ],
            device=device,
        )

        # Identity check: did we accumulate straight into the user-provided array?
        if output is accumulate_array:
            return output
        elif output is None:
            return accumulate_array.numpy()[0]
        else:
            array_cast(in_array=accumulate_array, out_array=output)
            return output

    test_arg = test.space_restriction.node_arg(device=device)

    # Linear form
    if trial is None:
        # If an output array is provided with the correct type, accumulate directly into it
        # Otherwise, grab a temporary array
        if output is None:
            if type_length(output_dtype) == test.space.VALUE_DOF_COUNT:
                output_shape = (test.space_partition.node_count(),)
            elif type_length(output_dtype) == 1:
                output_shape = (test.space_partition.node_count(), test.space.VALUE_DOF_COUNT)
            else:
                raise RuntimeError(
                    f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
                )

            output_temporary = cache.borrow_temporary(
                temporary_store=temporary_store,
                shape=output_shape,
                dtype=output_dtype,
                device=device,
            )

            output = output_temporary.array

        else:
            output_temporary = None

            if output.shape[0] < test.space_partition.node_count():
                raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")

            output_dtype = output.dtype
            if type_length(output_dtype) != test.space.VALUE_DOF_COUNT:
                if type_length(output_dtype) != 1:
                    raise RuntimeError(
                        f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
                    )
                # A scalar-typed output must be 2d with one column per value dof.
                # `or` short-circuits so that `shape[1]` is only read for 2d arrays.
                if output.ndim != 2 or output.shape[1] != test.space.VALUE_DOF_COUNT:
                    raise RuntimeError(
                        f"Incompatible output array shape, last dimension must be of size {test.space.VALUE_DOF_COUNT}"
                    )

        # Launch the integration on the kernel on a 2d scalar view of the actual array
        output.zero_()

        def as_2d_array(array):
            # Alias the same memory as a (node_count, VALUE_DOF_COUNT) scalar array,
            # recursively aliasing the gradient array as well when there is one
            return wp.array(
                data=None,
                ptr=array.ptr,
                capacity=array.capacity,
                device=array.device,
                shape=(test.space_partition.node_count(), test.space.VALUE_DOF_COUNT),
                dtype=wp.types.type_scalar_type(output_dtype),
                grad=None if array.grad is None else as_2d_array(array.grad),
            )

        output_view = output if output.ndim == 2 else as_2d_array(output)

        if nodal:
            wp.launch(
                kernel=kernel,
                dim=(test.space_restriction.node_count(), test.space.VALUE_DOF_COUNT),
                inputs=[
                    domain_elt_arg,
                    domain_elt_index_arg,
                    test_arg,
                    field_arg_values,
                    value_struct_values,
                    output_view,
                ],
                device=device,
            )
        else:
            wp.launch(
                kernel=kernel,
                dim=(test.space_restriction.node_count(), test.space.VALUE_DOF_COUNT),
                inputs=[
                    qp_arg,
                    domain_elt_arg,
                    domain_elt_index_arg,
                    test_arg,
                    field_arg_values,
                    value_struct_values,
                    output_view,
                ],
                device=device,
            )

        if output_temporary is not None:
            return output_temporary.detach()

        return output

    # Bilinear form

    if test.space.VALUE_DOF_COUNT == 1 and trial.space.VALUE_DOF_COUNT == 1:
        block_type = output_dtype
    else:
        block_type = cache.cached_mat_type(
            shape=(test.space.VALUE_DOF_COUNT, trial.space.VALUE_DOF_COUNT), dtype=output_dtype
        )

    # Number of triplets (blocks) written by the kernel
    if nodal:
        nnz = test.space_restriction.node_count()
    else:
        nnz = test.space_restriction.total_node_element_count() * trial.space.topology.NODES_PER_ELEMENT

    triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_values_temp = cache.borrow_temporary(
        temporary_store,
        shape=(
            nnz,
            test.space.VALUE_DOF_COUNT,
            trial.space.VALUE_DOF_COUNT,
        ),
        dtype=output_dtype,
        device=device,
    )
    triplet_cols = triplet_cols_temp.array
    triplet_rows = triplet_rows_temp.array
    triplet_values = triplet_values_temp.array

    triplet_values.zero_()

    if nodal:
        wp.launch(
            kernel=kernel,
            dim=triplet_values.shape,
            inputs=[
                domain_elt_arg,
                domain_elt_index_arg,
                test_arg,
                field_arg_values,
                value_struct_values,
                triplet_rows,
                triplet_cols,
                triplet_values,
            ],
            device=device,
        )

    else:
        offsets = test.space_restriction.partition_element_offsets()

        trial_partition_arg = trial.space_partition.partition_arg_value(device)
        trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
        wp.launch(
            kernel=kernel,
            dim=(
                test.space_restriction.node_count(),
                trial.space.topology.NODES_PER_ELEMENT,
                test.space.VALUE_DOF_COUNT,
                trial.space.VALUE_DOF_COUNT,
            ),
            inputs=[
                qp_arg,
                domain_elt_arg,
                domain_elt_index_arg,
                test_arg,
                trial_partition_arg,
                trial_topology_arg,
                field_arg_values,
                value_struct_values,
                offsets,
                triplet_rows,
                triplet_cols,
                triplet_values,
            ],
            device=device,
        )

    if output is not None:
        if output.nrow != test.space_partition.node_count() or output.ncol != trial.space_partition.node_count():
            raise RuntimeError(
                f"Output matrix must have {test.space_partition.node_count()} rows and {trial.space_partition.node_count()} columns of blocks"
            )

    else:
        output = bsr_zeros(
            rows_of_blocks=test.space_partition.node_count(),
            cols_of_blocks=trial.space_partition.node_count(),
            block_type=block_type,
            device=device,
        )

    bsr_set_from_triplets(output, triplet_rows, triplet_cols, triplet_values)

    # Do not wait for garbage collection
    triplet_values_temp.release()
    triplet_rows_temp.release()
    triplet_cols_temp.release()

    return output
1110
+
1111
+
1112
def integrate(
    integrand: Integrand,
    domain: Optional[GeometryDomain] = None,
    quadrature: Optional[Quadrature] = None,
    nodal: bool = False,
    fields: Optional[Dict[str, FieldLike]] = None,
    values: Optional[Dict[str, Any]] = None,
    accumulate_dtype: type = wp.float64,
    output_dtype: Optional[type] = None,
    output: Optional[Union[BsrMatrix, wp.array]] = None,
    device=None,
    temporary_store: Optional[cache.TemporaryStore] = None,
    kernel_options: Optional[Dict[str, Any]] = None,
):
    """
    Integrates a constant, linear or bilinear form, and returns a scalar, array, or sparse matrix, respectively.

    Args:
        integrand: Form to be integrated, must have :func:`integrand` decorator
        domain: Integration domain. If None, deduced from the quadrature, or else from the test field
        quadrature: Quadrature formula. If None, deduced from domain and fields degree.
        nodal: For linear or bilinear form only, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
        fields: Discrete, test, and trial fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names.
        values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
        temporary_store: shared pool from which to allocate temporary arrays
        accumulate_dtype: Scalar type to be used for accumulating integration samples
        output: Sparse matrix or warp array into which to store the result of the integration
        output_dtype: Scalar type for returned results if `output` is not provided. If None, defaults to `accumulate_dtype`
        device: Device on which to perform the integration
        kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)

    Raises:
        ValueError: If the integrand is not an :class:`Integrand`, no domain can be deduced,
            or the nodal/quadrature arguments are inconsistent
        NotImplementedError: If the integration domain differs from the test field domain
    """
    fields = {} if fields is None else fields
    values = {} if values is None else values
    kernel_options = {} if kernel_options is None else kernel_options

    if not isinstance(integrand, Integrand):
        raise ValueError("integrand must be tagged with @warp.fem.integrand decorator")

    test, test_name, trial, trial_name = _get_test_and_trial_fields(fields)

    # Fall back to the quadrature domain first, then to the test field domain
    if domain is None and quadrature is not None:
        domain = quadrature.domain
    elif domain is None and test is not None:
        domain = test.domain

    if domain is None:
        raise ValueError("Must provide at least one of domain, quadrature, or test field")
    if test is not None and domain != test.domain:
        raise NotImplementedError("Mixing integration and test domain is not supported yet")

    if not nodal:
        if quadrature is None:
            # Default quadrature order is the sum of the degrees of all fields
            total_degree = sum(f.degree for f in fields.values())
            quadrature = RegularQuadrature(domain=domain, order=total_degree)
        elif domain != quadrature.domain:
            raise ValueError("Incompatible integration and quadrature domain")
    else:
        if quadrature is not None:
            raise ValueError("Cannot specify quadrature for nodal integration")

        if test is None:
            raise ValueError("Nodal integration requires specifying a test function")

        if trial is not None and test.space_partition != trial.space_partition:
            raise ValueError(
                "Bilinear nodal integration requires test and trial to be defined on the same function space"
            )

    # Canonicalize types
    accumulate_dtype = wp.types.type_to_warp(accumulate_dtype)
    if output is None:
        if output_dtype is None:
            output_dtype = accumulate_dtype
        else:
            output_dtype = wp.types.type_to_warp(output_dtype)
    elif isinstance(output, BsrMatrix):
        output_dtype = output.scalar_type
    else:
        output_dtype = output.dtype

    # Arguments common to kernel generation and kernel launch
    shared_args = {
        "domain": domain,
        "nodal": nodal,
        "quadrature": quadrature,
        "test": test,
        "trial": trial,
        "fields": fields,
        "accumulate_dtype": accumulate_dtype,
        "output_dtype": output_dtype,
    }

    kernel, FieldStruct, ValueStruct = _generate_integrate_kernel(
        integrand=integrand,
        test_name=test_name,
        trial_name=trial_name,
        kernel_options=kernel_options,
        **shared_args,
    )

    return _launch_integrate_kernel(
        kernel=kernel,
        FieldStruct=FieldStruct,
        ValueStruct=ValueStruct,
        values=values,
        temporary_store=temporary_store,
        output=output,
        device=device,
        **shared_args,
    )
1230
+
1231
+
1232
def get_interpolate_to_field_function(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    dest: FieldRestriction,
):
    """Build the per-node function evaluating the integrand at a node of `dest`.

    The returned function loops over the elements listed for the node in the
    space restriction, evaluates `integrand_func` at the node's coordinates in
    each of them, and returns the pair ``(val_sum, vol_sum)``: the sum of the
    integrand values weighted by the element measure, and the sum of measures.

    Args:
        integrand_func: Integrand function to evaluate
        domain: Domain providing the element argument types, indices and measure
        FieldStruct: Struct type holding the field arguments
        ValueStruct: Struct type holding the value arguments
        dest: Field restriction whose nodes are evaluated

    Returns:
        The Python function ``interpolate_to_field_fn``.
    """
    # Type of the accumulated value, taken from the destination space
    value_type = dest.space.dtype

    def interpolate_to_field_fn(
        local_node_index: int,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
        element_count = dest.space_restriction.node_element_count(dest_node_arg, local_node_index)

        # No test or trial function is involved when interpolating
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX
        node_weight = 1.0

        # Volume-weighted average across elements
        # Superfluous if the interpolated function is continuous, but helpful for visualizing discontinuous spaces

        val_sum = value_type(0.0)
        vol_sum = float(0.0)

        for n in range(element_count):
            node_element_index = dest.space_restriction.node_element_index(dest_node_arg, local_node_index, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            coords = dest.space.node_coords_in_element(
                domain_arg,
                dest_eval_arg.space_arg,
                element_index,
                node_element_index.node_index_in_element,
            )

            # Elements for which the node coordinates are flagged OUTSIDE do not contribute
            if coords[0] != OUTSIDE:
                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                vol_sum += vol
                val_sum += vol * val

        return val_sum, vol_sum

    return interpolate_to_field_fn
1292
+
1293
+
1294
def get_interpolate_to_field_kernel(
    interpolate_to_field_fn: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    dest: FieldRestriction,
):
    """Build the kernel function assigning interpolated values to the nodes of `dest`.

    Each thread handles one local node: it calls `interpolate_to_field_fn` to get
    the weighted value sum and the measure sum, then stores their quotient as
    the node value of the destination field.

    Args:
        interpolate_to_field_fn: Function returning ``(val_sum, vol_sum)`` for a local node
        domain: Domain providing the element argument types
        FieldStruct: Struct type holding the field arguments
        ValueStruct: Struct type holding the value arguments
        dest: Field restriction whose node values are written

    Returns:
        The Python function ``interpolate_to_field_kernel_fn``.
    """

    def interpolate_to_field_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        local_node_index = wp.tid()

        val_sum, vol_sum = interpolate_to_field_fn(
            local_node_index, domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, fields, values
        )

        # Nodes that received no contribution are left untouched (also avoids dividing by zero)
        if vol_sum > 0.0:
            node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
            dest.field.set_node_value(dest_eval_arg, node_index, val_sum / vol_sum)

    return interpolate_to_field_kernel_fn
1320
+
1321
+
1322
def get_interpolate_to_array_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    value_type: type,
):
    """Build the kernel function storing integrand values at quadrature points into an array.

    Each thread handles one domain element: it loops over the quadrature points
    of that element and writes the value of `integrand_func` at each point to
    ``result[qp_index]``.

    Args:
        integrand_func: Integrand function to evaluate
        domain: Domain used to map thread indices to element indices
        quadrature: Quadrature formula defining the sample points
        FieldStruct: Struct type holding the field arguments
        ValueStruct: Struct type holding the value arguments
        value_type: Data type of the `result` array

    Returns:
        The Python function ``interpolate_to_array_kernel_fn``.
    """

    def interpolate_to_array_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: quadrature.domain.ElementArg,
        domain_index_arg: quadrature.domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=value_type),
    ):
        element_index = domain.element_index(domain_index_arg, wp.tid())

        # No test or trial function is involved when interpolating
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
        for k in range(qp_point_count):
            qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
            coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
            qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)

            sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)

            # One output entry per quadrature point, addressed by the point index
            result[qp_index] = integrand_func(sample, fields, values)

    return interpolate_to_array_kernel_fn
1354
+
1355
+
1356
def get_interpolate_nonvalued_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
):
    """Build the kernel function calling the integrand at quadrature points, discarding its result.

    Each thread handles one domain element: it loops over the quadrature points
    of that element and calls `integrand_func` at each point. Nothing is stored
    by the kernel itself.

    Args:
        integrand_func: Integrand function to call
        domain: Domain used to map thread indices to element indices
        quadrature: Quadrature formula defining the sample points
        FieldStruct: Struct type holding the field arguments
        ValueStruct: Struct type holding the value arguments

    Returns:
        The Python function ``interpolate_nonvalued_kernel_fn``.
    """

    def interpolate_nonvalued_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: quadrature.domain.ElementArg,
        domain_index_arg: quadrature.domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        element_index = domain.element_index(domain_index_arg, wp.tid())

        # No test or trial function is involved when interpolating
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
        for k in range(qp_point_count):
            qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
            coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
            qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)

            sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
            # The return value, if any, is ignored
            integrand_func(sample, fields, values)

    return interpolate_nonvalued_kernel_fn
1385
+
1386
+
1387
def _generate_interpolate_kernel(
    integrand: Integrand,
    domain: GeometryDomain,
    dest: Optional[Union[FieldLike, wp.array]],
    quadrature: Optional[Quadrature],
    fields: Dict[str, FieldLike],
    kernel_options: Optional[Dict[str, Any]] = None,
) -> tuple:
    """Fetch from the cache, or generate, the kernel interpolating `integrand`.

    The kind of kernel depends on `dest`:

    - a :class:`FieldRestriction`: the integrand is evaluated at each node of the restriction;
    - a warp array: the integrand is evaluated at each quadrature point and stored in the array;
    - anything else: the integrand is called at each quadrature point and its result discarded.

    Args:
        integrand: Integrand to be interpolated
        domain: Domain over which the interpolation is performed
        dest: Interpolation destination, see above
        quadrature: Quadrature formula; its name is read when `dest` is not a field restriction
        fields: Fields passed to the integrand, keyed by argument name
        kernel_options: Options forwarded to the kernel builder

    Returns:
        A ``(kernel, FieldStruct, ValueStruct)`` tuple, where the two struct
        types hold the field and value arguments of the kernel.
    """
    if kernel_options is None:
        kernel_options = {}

    # Extract field arguments from integrand
    field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
        integrand, fields=fields, domain=domain
    )

    # Generate field struct
    integrand_func = _translate_integrand(
        integrand,
        field_args,
    )

    _register_integrand_field_wrappers(integrand_func, fields)

    FieldStruct = _gen_field_struct(field_args)
    ValueStruct = _gen_value_struct(value_args)

    # Check if kernel exist in cache
    if isinstance(dest, FieldRestriction):
        kernel_suffix = (
            f"_itp_{FieldStruct.key}_{dest.domain.name}_{dest.space_restriction.space_partition.name}_{dest.space.name}"
        )
    elif wp.types.is_array(dest):
        kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}_{wp.types.type_repr(dest.dtype)}"
    else:
        kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}"

    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        suffix=kernel_suffix,
    )
    if kernel is not None:
        return kernel, FieldStruct, ValueStruct

    # Generate interpolation kernel
    if isinstance(dest, FieldRestriction):
        # need to split into kernel + function for differentiability
        interpolate_fn = get_interpolate_to_field_function(
            integrand_func,
            domain,
            dest=dest,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )

        interpolate_fn = cache.get_integrand_function(
            integrand=integrand,
            func=interpolate_fn,
            suffix=kernel_suffix,
            code_transformers=[
                PassFieldArgsToIntegrand(
                    arg_names=integrand.argspec.args,
                    field_args=field_args.keys(),
                    value_args=value_args.keys(),
                    sample_name=sample_name,
                    domain_name=domain_name,
                )
            ],
        )

        interpolate_kernel_fn = get_interpolate_to_field_kernel(
            interpolate_fn,
            domain,
            dest=dest,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )
    elif wp.types.is_array(dest):
        interpolate_kernel_fn = get_interpolate_to_array_kernel(
            integrand_func,
            domain=domain,
            quadrature=quadrature,
            value_type=dest.dtype,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )
    else:
        interpolate_kernel_fn = get_interpolate_nonvalued_kernel(
            integrand_func,
            domain=domain,
            quadrature=quadrature,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )

    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        kernel_fn=interpolate_kernel_fn,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
        code_transformers=[
            PassFieldArgsToIntegrand(
                arg_names=integrand.argspec.args,
                field_args=field_args.keys(),
                value_args=value_args.keys(),
                sample_name=sample_name,
                domain_name=domain_name,
            )
        ],
    )

    return kernel, FieldStruct, ValueStruct
1499
+
1500
+
1501
def _launch_interpolate_kernel(
    kernel: wp.Kernel,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    domain: GeometryDomain,
    dest: Optional[Union[FieldRestriction, wp.array]],
    quadrature: Optional[Quadrature],
    fields: Dict[str, FieldLike],
    values: Dict[str, Any],
    device,
) -> None:
    """Launch an interpolation kernel.

    Args:
        kernel: Interpolation kernel to launch
        FieldStruct: Struct type holding the field arguments of the kernel
        ValueStruct: Struct type holding the value arguments of the kernel
        domain: Domain over which the interpolation is performed
        dest: Interpolation destination. For a :class:`FieldRestriction`, one thread is
            launched per node of the restriction; otherwise one thread is launched per
            domain element, and `dest` is passed to the kernel only if it is a warp array.
        quadrature: Quadrature formula, required when `dest` is not a field restriction
        fields: Fields passed to the integrand, keyed by argument name
        values: Additional values passed to the integrand, keyed by argument name
        device: Device on which to launch the kernel
    """
    # Set-up launch arguments
    elt_arg = domain.element_arg_value(device=device)
    elt_index_arg = domain.element_index_arg_value(device=device)

    field_arg_values = FieldStruct()
    for k, v in fields.items():
        setattr(field_arg_values, k, v.eval_arg_value(device=device))

    value_struct_values = ValueStruct()
    for k, v in values.items():
        setattr(value_struct_values, k, v)

    if isinstance(dest, FieldRestriction):
        dest_node_arg = dest.space_restriction.node_arg(device=device)
        dest_eval_arg = dest.field.eval_arg_value(device=device)

        wp.launch(
            kernel=kernel,
            dim=dest.space_restriction.node_count(),
            inputs=[
                elt_arg,
                elt_index_arg,
                dest_node_arg,
                dest_eval_arg,
                field_arg_values,
                value_struct_values,
            ],
            device=device,
        )
    else:
        # Quadrature-based kernels share their leading arguments;
        # the array destination, when there is one, comes last
        qp_arg = quadrature.arg_value(device)
        inputs = [qp_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values]
        if wp.types.is_array(dest):
            inputs.append(dest)

        wp.launch(
            kernel=kernel,
            dim=domain.element_count(),
            inputs=inputs,
            device=device,
        )
1557
+
1558
+
1559
def interpolate(
    integrand: Integrand,
    dest: Optional[Union[DiscreteField, FieldRestriction, wp.array]] = None,
    quadrature: Optional[Quadrature] = None,
    fields: Optional[Dict[str, FieldLike]] = None,
    values: Optional[Dict[str, Any]] = None,
    device=None,
    kernel_options: Optional[Dict[str, Any]] = None,
):
    """
    Interpolates a function at a finite set of sample points and optionally assigns the result to a discrete field or a raw warp array.

    Args:
        integrand: Function to be interpolated, must have :func:`integrand` decorator
        dest: Where to store the interpolation result. Can be either

            - a :class:`DiscreteField`, or restriction of a discrete field to a domain (from :func:`make_restriction`). In this case, interpolation will be performed at each node.
            - a normal warp array. In this case, the `quadrature` argument defining the interpolation locations must be provided and the result of the `integrand` at each quadrature point will be assigned to the array.
            - ``None``. In this case, the `quadrature` argument must also be provided and the `integrand` function is responsible for dealing with the interpolation result.
        quadrature: Quadrature formula defining the interpolation samples if `dest` is not a discrete field or field restriction.
        fields: Discrete fields to be passed to the integrand. Keys in the dictionary must match integrand parameters names.
        values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
        device: Device on which to perform the interpolation
        kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)

    Raises:
        ValueError: If the integrand is not an :class:`Integrand`, if a test or trial field is passed,
            or if `quadrature` is missing while `dest` is not a field or field restriction
    """
    fields = {} if fields is None else fields
    values = {} if values is None else values
    kernel_options = {} if kernel_options is None else kernel_options

    if not isinstance(integrand, Integrand):
        raise ValueError("integrand must be tagged with @integrand decorator")

    test_field, _, trial_field, _ = _get_test_and_trial_fields(fields)
    if not (test_field is None and trial_field is None):
        raise ValueError("Test or Trial fields should not be used for interpolation")

    # A full discrete field is interpolated through its restriction
    if isinstance(dest, DiscreteField):
        dest = make_restriction(dest)

    if isinstance(dest, FieldRestriction):
        domain = dest.domain
    elif quadrature is None:
        raise ValueError("When not interpolating to a field, a quadrature formula must be provided")
    else:
        domain = quadrature.domain

    # Arguments common to kernel generation and kernel launch
    shared_args = {
        "domain": domain,
        "dest": dest,
        "quadrature": quadrature,
        "fields": fields,
    }

    kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
        integrand=integrand,
        kernel_options=kernel_options,
        **shared_args,
    )

    return _launch_interpolate_kernel(
        kernel=kernel,
        FieldStruct=FieldStruct,
        ValueStruct=ValueStruct,
        values=values,
        device=device,
        **shared_args,
    )