warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +10 -4
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +868 -507
- warp/codegen.py +1074 -638
- warp/config.py +3 -3
- warp/constants.py +6 -0
- warp/context.py +715 -222
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +147 -44
- warp/native/builtin.h +122 -149
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +34 -43
- warp/native/clang/clang.cpp +13 -8
- warp/native/crt.h +2 -0
- warp/native/cuda_crt.h +5 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -952
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +1 -1
- warp/native/marching.cu +157 -161
- warp/native/mat.h +80 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -23
- warp/native/mesh.h +446 -46
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +1 -1
- warp/native/reduce.cu +10 -12
- warp/native/runlength_encode.cu +6 -10
- warp/native/scan.cu +8 -11
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +164 -154
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +14 -30
- warp/native/vec.h +107 -23
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +560 -30
- warp/native/warp.cu +431 -44
- warp/native/warp.h +13 -4
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +335 -119
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +8 -0
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +158 -16
- warp/sim/model.py +795 -291
- warp/sim/render.py +3 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +640 -150
- warp/stubs.py +606 -267
- warp/tape.py +61 -10
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +212 -97
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +42 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +208 -130
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +178 -109
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +32 -31
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +140 -22
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +9 -6
- warp/types.py +1089 -366
- warp/utils.py +93 -387
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/tests/test_all.py +0 -219
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.10.1.dist-info/METADATA +0 -21
- warp_lang-0.10.1.dist-info/RECORD +0 -188
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py
ADDED
|
@@ -0,0 +1,1616 @@
|
|
|
1
|
+
from typing import List, Dict, Set, Optional, Any, Union
|
|
2
|
+
|
|
3
|
+
import warp as wp
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
import ast
|
|
7
|
+
|
|
8
|
+
from warp.sparse import BsrMatrix, bsr_zeros, bsr_set_from_triplets, bsr_copy, bsr_assign
|
|
9
|
+
from warp.types import type_length
|
|
10
|
+
from warp.utils import array_cast
|
|
11
|
+
from warp.codegen import get_annotations
|
|
12
|
+
|
|
13
|
+
from warp.fem.domain import GeometryDomain
|
|
14
|
+
from warp.fem.field import (
|
|
15
|
+
TestField,
|
|
16
|
+
TrialField,
|
|
17
|
+
FieldLike,
|
|
18
|
+
DiscreteField,
|
|
19
|
+
FieldRestriction,
|
|
20
|
+
make_restriction,
|
|
21
|
+
)
|
|
22
|
+
from warp.fem.quadrature import Quadrature, RegularQuadrature
|
|
23
|
+
from warp.fem.operator import Operator, Integrand
|
|
24
|
+
from warp.fem import cache
|
|
25
|
+
from warp.fem.types import Domain, Field, Sample, DofIndex, NULL_DOF_INDEX, OUTSIDE, make_free_sample
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _resolve_path(func, node):
    """
    Resolves variable and path from ast node/attribute (adapted from warp.codegen)

    Args:
        func: Python function whose globals and closure variables form the lookup scope
        node: AST node to resolve, typically an ``ast.Name`` or a chain of ``ast.Attribute`` nodes

    Returns:
        A tuple ``(obj, path)``. ``path`` is the list of identifiers collected from ``node``,
        outermost first. ``obj`` is the object the dotted path evaluates to in the scope of
        ``func``, or ``None`` when the path is empty, is not rooted at a plain name, or cannot
        be evaluated (``NameError`` / ``AttributeError``).
    """

    modules = []

    while isinstance(node, ast.Attribute):
        modules.append(node.attr)
        node = node.value

    # Only a chain rooted at a plain name (`a.b.c`) spells out a complete dotted path
    rooted_at_name = isinstance(node, ast.Name)
    if rooted_at_name:
        modules.append(node.id)

    # reverse list since ast presents it in backward order
    path = [*reversed(modules)]

    # A chain hanging off a call, subscript, etc. (e.g. `f().attr`) only yields the trailing
    # attribute names; evaluating those could wrongly bind to an unrelated variable that
    # merely shares the attribute's name, so such paths are reported as unresolved.
    if not rooted_at_name:
        return None, path

    # try and evaluate object path
    try:
        # Look up the closure info and merge it with func.__globals__
        # in case the function was defined inside another function and refers
        # to variables declared inside that enclosing function:
        closure_cells = func.__closure__ or []
        capturedvars = dict(zip(func.__code__.co_freevars, [c.cell_contents for c in closure_cells]))

        vars_dict = {**func.__globals__, **capturedvars}

        # `path` only holds identifiers taken from Name/Attribute nodes, so the evaluated
        # expression is a plain dotted attribute access
        resolved = eval(".".join(path), vars_dict)
        return resolved, path
    except (NameError, AttributeError):
        pass

    return None, path
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _path_to_ast_attribute(name: str) -> ast.Attribute:
    """
    Builds the AST load expression corresponding to the dotted path `name`.

    The first component becomes an ``ast.Name`` and every following component wraps the
    expression built so far in an ``ast.Attribute``. A name without any dot therefore
    yields a bare ``ast.Name`` node.
    """
    root, *attributes = name.split(".")

    node = ast.Name(id=root, ctx=ast.Load())
    for attribute in attributes:
        node = ast.Attribute(value=node, attr=attribute, ctx=ast.Load())
    return node
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class IntegrandTransformer(ast.NodeTransformer):
    """
    AST transformer specializing the field-related calls found in an integrand.

    Three kinds of calls are rewritten:
      - ``f(x, ...)`` where ``f`` is a field argument becomes
        ``<ArgType>.call_operator.<key>(f, x, ...)``
      - ``op(f, ...)`` where ``op`` resolves to an :class:`Operator` and ``f`` is a field
        argument becomes ``op.<key>(f, ...)``
      - calls to another :class:`Integrand` become ``callee.<key>(...)`` where ``<key>`` is
        the key of the callee translated for the fields passed at the call site

    Args:
        integrand: the integrand whose source is being transformed
        field_args: mapping from integrand argument name to the field (or domain) bound to it
    """

    def __init__(self, integrand: Integrand, field_args: Dict[str, FieldLike]):
        self._integrand = integrand
        self._field_args = field_args

    def visit_Call(self, call: ast.Call):
        # Children first, so that nested calls are already specialized
        call = self.generic_visit(call)

        called_name = getattr(call.func, "id", None)
        if called_name in self._field_args:
            return self._rewrite_field_evaluation(call, called_name)

        resolved, _ = _resolve_path(self._integrand.func, call.func)

        if isinstance(resolved, Operator) and len(call.args) > 0:
            # Evaluating operators as op(field, x, ...)
            first_arg_name = getattr(call.args[0], "id", None)
            if first_arg_name in self._field_args:
                self._replace_call_func(call, resolved, self._field_args[first_arg_name])

        if isinstance(resolved, Integrand):
            key = self._translate_callee(resolved, call.args)
            call.func = ast.Attribute(value=call.func, attr=key, ctx=ast.Load())

        return call

    def _rewrite_field_evaluation(self, call: ast.Call, field_name: str):
        """Handles the shortcut for evaluating fields as f(x...)"""
        field = self._field_args[field_name]

        arg_type = self._integrand.argspec.annotations[field_name]
        operator = arg_type.call_operator

        # The call now targets the `call_operator` of the annotated argument type ...
        type_path = f"{arg_type.__module__}.{arg_type.__qualname__}"
        call.func = ast.Attribute(
            value=_path_to_ast_attribute(type_path),
            attr="call_operator",
            ctx=ast.Load(),
        )
        # ... and receives the field itself as first argument
        call.args = [ast.Name(id=field_name, ctx=ast.Load()), *call.args]

        self._replace_call_func(call, operator, field)

        return call

    def _replace_call_func(self, call: ast.Call, operator: Operator, field: FieldLike):
        """
        Points `call` to the implementation of `operator` resolved for `field`.

        The resolved implementation is stored as an attribute of `operator` named after its
        key, and the called expression is extended with that attribute.

        Raises:
            ValueError: if resolving the operator for this field raises ``AttributeError``
        """
        try:
            resolved = operator.resolver(field)
            setattr(operator, resolved.key, resolved)
        except AttributeError:
            raise ValueError(f"Operator {operator.func.__name__} is not defined for field {field.name}")
        call.func = ast.Attribute(value=call.func, attr=resolved.key, ctx=ast.Load())

    def _translate_callee(self, callee: Integrand, args: List[ast.AST]):
        """
        Translates the nested integrand `callee` for the fields passed at the call site.

        Returns:
            The key of the translated callee function
        """
        # Get field types for call site arguments
        arg_names = (getattr(arg, "id", None) for arg in args)
        pending_fields = [self._field_args[name] for name in arg_names if name in self._field_args]

        # Pass to callee in same order
        callee_field_args = {}
        for arg in callee.argspec.args:
            if callee.argspec.annotations[arg] in (Field, Domain):
                callee_field_args[arg] = pending_fields.pop(0)

        return _translate_integrand(callee, callee_field_args).key
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _translate_integrand(integrand: Integrand, field_args: Dict[str, FieldLike]) -> wp.Function:
    """
    Specializes `integrand` for the given field arguments.

    Arguments annotated as ``Field`` are re-annotated with the ``ElementEvalArg`` of the bound
    field, arguments annotated as ``Domain`` with the ``ElementArg`` of the bound domain; all
    other annotations are kept. The function itself is obtained from
    ``cache.get_integrand_function`` with an :class:`IntegrandTransformer` as code transformer,
    then stored as an attribute of `integrand` under its key.

    Returns:
        The attribute of `integrand` registered under the key of the generated function
    """
    argspec = integrand.argspec

    def specialized_annotation(arg):
        declared_type = argspec.annotations[arg]
        if declared_type == Field:
            return field_args[arg].ElementEvalArg
        if declared_type == Domain:
            return field_args[arg].ElementArg
        return declared_type

    # Specialize field argument types
    annotations = {arg: specialized_annotation(arg) for arg in argspec.args}

    # Transform field evaluation calls
    transformer = IntegrandTransformer(integrand, field_args)

    # The suffix is built from the names of the bound fields
    suffix = "_".join(field.name for field in field_args.values())

    func = cache.get_integrand_function(
        integrand=integrand,
        suffix=suffix,
        annotations=annotations,
        code_transformers=[transformer],
    )

    key = func.key
    setattr(integrand, key, integrand.module.functions[key])

    return getattr(integrand, key)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _get_integrand_field_arguments(
    integrand: Integrand,
    fields: Dict[str, FieldLike],
    domain: GeometryDomain = None,
):
    """
    Sorts the arguments of `integrand` according to their annotated type.

    Args:
        integrand: integrand whose ``argspec`` is inspected
        fields: fields provided by the caller, keyed by argument name
        domain: domain bound to the argument annotated as ``Domain``, if any

    Returns:
        A tuple ``(field_args, value_args, domain_name, sample_name)`` where ``field_args``
        maps ``Field`` argument names to their field and the ``Domain`` argument name to
        `domain`, ``value_args`` maps every remaining argument name to its annotated type,
        and ``domain_name`` / ``sample_name`` are the names of the ``Domain`` / ``Sample``
        arguments (``None`` when absent).

    Raises:
        ValueError: if an argument annotated as ``Field`` has no entry in `fields`
    """
    field_args = {}
    value_args = {}
    domain_name = None
    sample_name = None

    # parse argument types
    argspec = integrand.argspec
    for arg in argspec.args:
        arg_type = argspec.annotations[arg]

        if arg_type == Field:
            if arg not in fields:
                raise ValueError(f"Missing field for argument '{arg}'")
            field_args[arg] = fields[arg]
            continue

        if arg_type == Domain:
            # The domain is passed along with the fields
            domain_name = arg
            field_args[arg] = domain
            continue

        if arg_type == Sample:
            sample_name = arg
            continue

        value_args[arg] = arg_type

    return field_args, value_args, domain_name, sample_name
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def _get_test_and_trial_fields(
    fields: Dict[str, FieldLike],
):
    """
    Looks for the test and trial fields among `fields`.

    Args:
        fields: fields keyed by integrand argument name

    Returns:
        A tuple ``(test, test_name, trial, trial_name)``; entries are ``None`` when no field
        of the corresponding kind was found.

    Raises:
        ValueError: if more than one test field or more than one trial field is provided,
            if a trial field is provided without a test field, or if the test and trial
            fields are defined over different domains.
    """
    test = None
    trial = None
    test_name = None
    trial_name = None

    for name, field in fields.items():
        if isinstance(field, TestField):
            if test is not None:
                raise ValueError("Duplicate test field argument")
            test = field
            test_name = name
        elif isinstance(field, TrialField):
            if trial is not None:
                # The message used to mention the test field, misreporting which argument was duplicated
                raise ValueError("Duplicate trial field argument")
            trial = field
            trial_name = name

    if trial is not None:
        if test is None:
            raise ValueError("A trial field cannot be provided without a test field")

        if test.domain != trial.domain:
            raise ValueError("Incompatible test and trial domains")

    return test, test_name, trial, trial_name
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def _gen_field_struct(field_args: Dict[str, FieldLike]):
    """
    Generates the struct gathering the evaluation arguments of all fields.

    Each entry of `field_args` that is not a ``GeometryDomain`` contributes one member, named
    after the argument and annotated with the field's ``EvalArg`` type. The struct itself is
    obtained from ``cache.get_struct`` with a suffix built from the member names and the
    qualified names of their types.
    """

    class Fields:
        pass

    annotations = get_annotations(Fields)

    for name, arg in field_args.items():
        # Domains are not stored in the fields struct
        if not isinstance(arg, GeometryDomain):
            setattr(Fields, name, arg.EvalArg())
            annotations[name] = arg.EvalArg

    try:
        Fields.__annotations__ = annotations
    except AttributeError:
        # NOTE(review): a class `__dict__` is a read-only mappingproxy, so this fallback
        # would presumably fail as well if it were ever reached; confirm it is still needed
        setattr(Fields.__dict__, "__annotations__", annotations)

    suffix = "_".join(f"{name}_{arg_struct.cls.__qualname__}" for name, arg_struct in annotations.items())

    return cache.get_struct(Fields, suffix=suffix)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def _gen_value_struct(value_args: Dict[str, type]):
    """
    Generates the struct gathering all non-field ("value") arguments of an integrand.

    Each entry of `value_args` contributes one member, named after the argument, annotated
    with the argument type and initialized to ``None``. The struct itself is obtained from
    ``cache.get_struct`` with a suffix built from the member names and type names.

    Args:
        value_args: mapping from argument name to argument type
    """

    class Values:
        pass

    annotations = get_annotations(Values)

    for name, arg_type in value_args.items():
        setattr(Values, name, None)
        annotations[name] = arg_type

    # This helper used to be defined twice in a row with identical bodies; one copy suffices
    def arg_type_name(arg_type):
        # Warp structs are named after the class they wrap
        if isinstance(arg_type, wp.codegen.Struct):
            return arg_type_name(arg_type.cls)
        return getattr(arg_type, "__name__", str(arg_type))

    try:
        Values.__annotations__ = annotations
    except AttributeError:
        # NOTE(review): a class `__dict__` is a read-only mappingproxy, so this fallback
        # would presumably fail as well if it were ever reached; confirm it is still needed
        setattr(Values.__dict__, "__annotations__", annotations)

    suffix = "_".join([f"{name}_{arg_type_name(arg_type)}" for name, arg_type in annotations.items()])

    return cache.get_struct(Values, suffix=suffix)
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def _get_trial_arg():
    """
    Placeholder standing for the trial field argument.

    Calls to a function with this name are replaced by :class:`PassFieldArgsToIntegrand`
    with an access to the trial member of the fields struct; the function itself does nothing.
    """
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def _get_test_arg():
    """
    Placeholder standing for the test field argument.

    Calls to a function with this name are replaced by :class:`PassFieldArgsToIntegrand`
    with an access to the test member of the fields struct; the function itself does nothing.
    """
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
class _FieldWrappers:
    """Empty attribute container; `_register_integrand_field_wrappers` sets one attribute per field on its instances"""
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _register_integrand_field_wrappers(integrand_func: wp.Function, fields: Dict[str, FieldLike]):
    """
    Attaches the ``ElementEvalArg`` of each field to `integrand_func`.

    A fresh `_FieldWrappers` container is stored as ``integrand_func._field_wrappers``, then
    one attribute per entry of `fields` is set on it, named after the field's argument name.
    """
    wrappers = integrand_func._field_wrappers = _FieldWrappers()
    for arg_name, field in fields.items():
        setattr(wrappers, arg_name, field.ElementEvalArg)
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
class PassFieldArgsToIntegrand(ast.NodeTransformer):
    """
    AST transformer adapting generic calls to the arguments of a specific integrand.

    Two kinds of calls are rewritten:
      - calls to `func_name` have their arguments replaced with one expression per name in
        `arg_names`: the domain variable, the sample variable, a call to the field wrapper
        registered on the function for field arguments, or a member of the values struct
        for value arguments
      - calls to `_get_test_arg` / `_get_trial_arg` are replaced with an access to the
        test / trial member of the fields struct

    Args:
        arg_names: names of the integrand arguments, in call order
        field_args: names of the arguments bound to fields
        value_args: names of the arguments passed by value
        sample_name: name of the integrand argument receiving the sample
        domain_name: name of the integrand argument receiving the domain
        test_name: name of the test field argument, if any
        trial_name: name of the trial field argument, if any
        func_name: name under which the integrand is called in the transformed code
        fields_var_name: name of the variable holding the fields struct
        values_var_name: name of the variable holding the values struct
        domain_var_name: name of the variable holding the domain argument
        sample_var_name: name of the variable holding the sample
        field_wrappers_attr: name of the attribute of the integrand holding the field wrappers
    """

    def __init__(
        self,
        arg_names: List[str],
        field_args: Set[str],
        value_args: Set[str],
        sample_name: str,
        domain_name: str,
        test_name: str = None,
        trial_name: str = None,
        func_name: str = "integrand_func",
        fields_var_name: str = "fields",
        values_var_name: str = "values",
        domain_var_name: str = "domain_arg",
        sample_var_name: str = "sample",
        field_wrappers_attr: str = "_field_wrappers",
    ):
        # Description of the integrand arguments
        self._arg_names = arg_names
        self._field_args = field_args
        self._value_args = value_args
        self._domain_name = domain_name
        self._sample_name = sample_name
        self._test_name = test_name
        self._trial_name = trial_name

        # Names used in the code being transformed
        self._func_name = func_name
        self._fields_var_name = fields_var_name
        self._values_var_name = values_var_name
        self._domain_var_name = domain_var_name
        self._sample_var_name = sample_var_name
        self._field_wrappers_attr = field_wrappers_attr

    @staticmethod
    def _load_name(name):
        """Returns the AST node reading the variable `name`"""
        return ast.Name(id=name, ctx=ast.Load())

    @staticmethod
    def _load_attr(value, attr):
        """Returns the AST node reading attribute `attr` of the expression `value`"""
        return ast.Attribute(value=value, attr=attr, ctx=ast.Load())

    def _fields_member(self, member):
        """Returns the AST node reading `member` from the fields struct variable"""
        return self._load_attr(self._load_name(self._fields_var_name), member)

    def _integrand_argument(self, arg):
        """Returns the AST expression to pass for the integrand argument named `arg`"""
        if arg == self._domain_name:
            return self._load_name(self._domain_var_name)

        if arg == self._sample_name:
            return self._load_name(self._sample_var_name)

        if arg in self._field_args:
            # <func>.<wrappers>.<arg>(<domain>, <fields>.<arg>)
            wrappers = self._load_attr(self._load_name(self._func_name), self._field_wrappers_attr)
            return ast.Call(
                func=self._load_attr(wrappers, arg),
                args=[self._load_name(self._domain_var_name), self._fields_member(arg)],
                keywords=[],
            )

        if arg in self._value_args:
            return self._load_attr(self._load_name(self._values_var_name), arg)

        raise RuntimeError(f"Unhandled argument {arg}")

    def visit_Call(self, call: ast.Call):
        call = self.generic_visit(call)

        callee = getattr(call.func, "id", None)

        if callee == self._func_name:
            # Replace function arguments with our generated structs
            call.args.clear()
            for arg in self._arg_names:
                call.args.append(self._integrand_argument(arg))
            return call

        if callee == _get_test_arg.__name__:
            return self._fields_member(self._test_name)

        if callee == _get_trial_arg.__name__:
            return self._fields_member(self._trial_name)

        return call
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
def get_integrate_constant_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    accumulate_dtype,
):
    """
    Builds the kernel function integrating `integrand_func` into a single scalar.

    Args:
        integrand_func: function evaluated at each quadrature point
        domain: integration domain, providing element indices and measures
        quadrature: quadrature providing the points, coordinates and weights of each element
        FieldStruct: type of the struct holding the field arguments
        ValueStruct: type of the struct holding the value arguments
        accumulate_dtype: scalar type used for summing the contributions

    Returns:
        The Python function ``integrate_kernel_fn``. Nothing is compiled or launched here.
    """

    # NOTE(review): the local names `integrand_func`, `fields`, `values`, `domain_arg` and
    # `sample` match the default names of `PassFieldArgsToIntegrand`, which presumably
    # rewrites the arguments of the `integrand_func(...)` call below; confirm with the caller
    # before renaming anything
    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=accumulate_dtype),
    ):
        # One thread per domain element
        element_index = domain.element_index(domain_index_arg, wp.tid())
        elem_sum = accumulate_dtype(0.0)

        # No test nor trial function is involved when integrating to a constant
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        # Sum the weighted integrand values over the quadrature points of this element
        qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
        for k in range(qp_point_count):
            qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
            coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
            qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)

            sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
            vol = domain.element_measure(domain_arg, sample)

            val = integrand_func(sample, fields, values)

            elem_sum += accumulate_dtype(qp_weight * vol * val)

        # All elements accumulate into the first entry of the result array
        wp.atomic_add(result, 0, elem_sum)

    return integrate_kernel_fn
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
def get_integrate_linear_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """Create the kernel function used to integrate a linear form (test field only) with a quadrature formula.

    The nested function is only defined here, not launched. Each thread handles one
    (local node index, test dof) pair and writes a single entry of the 2D ``result`` array.

    Args:
        integrand_func: Function evaluated at each quadrature point as ``integrand_func(sample, fields, values)``
        domain: Integration domain; provides the element indexing and the element measure
        quadrature: Quadrature formula; provides per-element point count, index, coordinates and weight
        FieldStruct: Struct type annotating the ``fields`` kernel argument
        ValueStruct: Struct type annotating the ``values`` kernel argument
        test: Test field; its ``space_restriction`` provides the node/element adjacency
        output_dtype: Scalar type of the ``result`` array elements
        accumulate_dtype: Type used to accumulate the sum before conversion to ``output_dtype``

    Returns:
        The ``integrate_kernel_fn`` function.
    """

    # NOTE(review): the names `integrand_func`, `sample`, `fields`, `values` and `domain_arg` look like they are
    # matched by name by the integrand code transformer -- confirm before renaming any of them.
    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array2d(dtype=output_dtype),
    ):
        # Thread index: (node index local to the space restriction, test dof)
        local_node_index, test_dof = wp.tid()
        node_index = test.space_restriction.node_partition_index(test_arg, local_node_index)
        element_count = test.space_restriction.node_element_count(test_arg, local_node_index)

        # Linear form: no trial degree of freedom
        trial_dof_index = NULL_DOF_INDEX

        val_sum = accumulate_dtype(0.0)

        # Loop over the elements reported by the space restriction for this node
        for n in range(element_count):
            node_element_index = test.space_restriction.node_element_index(test_arg, local_node_index, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)

            qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
            for k in range(qp_point_count):
                qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
                qp_coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
                qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)

                # The element measure is evaluated on a sample built from the element and coordinates only
                vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))

                sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(qp_weight * vol * val)

        # Plain (non-atomic) write: each thread owns its (node_index, test_dof) entry
        result[node_index, test_dof] = output_dtype(val_sum)

    return integrate_kernel_fn
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
def get_integrate_linear_nodal_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """Create the kernel function used to integrate a linear form using the test space nodes as quadrature points.

    The nested function is only defined here, not launched. Each thread handles one
    (local node index, dof) pair and writes a single entry of the 2D ``result`` array.

    Args:
        integrand_func: Function evaluated at each node as ``integrand_func(sample, fields, values)``
        domain: Integration domain; provides the element indexing and the element measure
        FieldStruct: Struct type annotating the ``fields`` kernel argument
        ValueStruct: Struct type annotating the ``values`` kernel argument
        test: Test field; its space provides node coordinates and node quadrature weights,
            its ``space_restriction`` provides the node/element adjacency
        output_dtype: Scalar type of the ``result`` array elements
        accumulate_dtype: Type used to accumulate the sum before conversion to ``output_dtype``

    Returns:
        The ``integrate_kernel_fn`` function.
    """

    # NOTE(review): the names `integrand_func`, `sample`, `fields`, `values` and `domain_arg` look like they are
    # matched by name by the integrand code transformer -- confirm before renaming any of them.
    def integrate_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_restriction_arg: test.space_restriction.NodeArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array2d(dtype=output_dtype),
    ):
        # Thread index: (node index local to the space restriction, dof)
        local_node_index, dof = wp.tid()

        node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
        element_count = test.space_restriction.node_element_count(test_restriction_arg, local_node_index)

        # Linear form: no trial degree of freedom
        trial_dof_index = NULL_DOF_INDEX

        val_sum = accumulate_dtype(0.0)

        # Loop over the elements reported by the space restriction for this node
        for n in range(element_count):
            node_element_index = test.space_restriction.node_element_index(test_restriction_arg, local_node_index, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            # `_get_test_arg()` is a placeholder call; presumably it is rewritten at kernel generation
            # into an access to the test field argument -- TODO confirm against the code transformer
            coords = test.space.node_coords_in_element(
                domain_arg,
                _get_test_arg(),
                element_index,
                node_element_index.node_index_in_element,
            )

            # Contributions are skipped when the first coordinate equals the OUTSIDE sentinel
            if coords[0] != OUTSIDE:
                node_weight = test.space.node_quadrature_weight(
                    domain_arg,
                    _get_test_arg(),
                    element_index,
                    node_element_index.node_index_in_element,
                )

                test_dof_index = DofIndex(node_element_index.node_index_in_element, dof)

                # The node index and node weight take the place of the quadrature point index and weight
                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(node_weight * vol * val)

        # Plain (non-atomic) write: each thread owns its (node_index, dof) entry
        result[node_index, dof] = output_dtype(val_sum)

    return integrate_kernel_fn
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
def get_integrate_bilinear_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    trial: TrialField,
    output_dtype,
    accumulate_dtype,
):
    """Create the kernel function used to integrate a bilinear form (test and trial fields) with a quadrature formula.

    The nested function is only defined here, not launched. Each thread handles one
    (local test node, trial node in element, test dof, trial dof) tuple and, for every element
    adjacent to the test node, fills one entry of a block of the triplet arrays.

    Args:
        integrand_func: Function evaluated at each quadrature point as ``integrand_func(sample, fields, values)``
        domain: Integration domain; provides the element indexing and the element measure
        quadrature: Quadrature formula; provides per-element point count, index, coordinates and weight
        FieldStruct: Struct type annotating the ``fields`` kernel argument
        ValueStruct: Struct type annotating the ``values`` kernel argument
        test: Test field; its ``space_restriction`` provides the node/element adjacency
        trial: Trial field; its space topology and partition provide the column (trial node) indices
        output_dtype: Scalar type of the ``triplet_values`` array elements
        accumulate_dtype: Type used to accumulate each block entry before conversion to ``output_dtype``

    Returns:
        The ``integrate_kernel_fn`` function.
    """
    # Captured by the kernel as a constant: number of trial nodes per element, i.e. blocks per (test node, element)
    NODES_PER_ELEMENT = trial.space.topology.NODES_PER_ELEMENT

    # NOTE(review): the names `integrand_func`, `sample`, `fields`, `values` and `domain_arg` look like they are
    # matched by name by the integrand code transformer -- confirm before renaming any of them.
    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        trial_partition_arg: trial.space_partition.PartitionArg,
        trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
        fields: FieldStruct,
        values: ValueStruct,
        row_offsets: wp.array(dtype=int),
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=output_dtype),
    ):
        # Thread index: (local test node, trial node within element, test dof, trial dof)
        test_local_node_index, trial_node, test_dof, trial_dof = wp.tid()

        element_count = test.space_restriction.node_element_count(test_arg, test_local_node_index)
        test_node_index = test.space_restriction.node_partition_index(test_arg, test_local_node_index)

        trial_dof_index = DofIndex(trial_node, trial_dof)

        # Loop over the elements reported by the space restriction for this test node
        for element in range(element_count):
            test_element_index = test.space_restriction.node_element_index(test_arg, test_local_node_index, element)
            element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
            qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)

            test_dof_index = DofIndex(
                test_element_index.node_index_in_element,
                test_dof,
            )

            # The sum is reset for each element: every (test node, element, trial node) gets its own block
            val_sum = accumulate_dtype(0.0)

            for k in range(qp_point_count):
                qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
                coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)

                qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
                # The element measure is evaluated on a sample built from the element and coordinates only
                vol = domain.element_measure(domain_arg, make_free_sample(element_index, coords))

                sample = Sample(
                    element_index,
                    coords,
                    qp_index,
                    qp_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                val = integrand_func(sample, fields, values)
                val_sum += accumulate_dtype(qp_weight * vol * val)

            # Blocks are laid out by test node (through row_offsets), then adjacent element, then trial node
            block_offset = (row_offsets[test_node_index] + element) * NODES_PER_ELEMENT + trial_node
            triplet_values[block_offset, test_dof, trial_dof] = output_dtype(val_sum)

            # Set row and column indices
            # (only the thread with both dofs equal to zero writes them, once per block)
            if test_dof == 0 and trial_dof == 0:
                trial_node_index = trial.space_partition.partition_node_index(
                    trial_partition_arg,
                    trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
                )
                triplet_rows[block_offset] = test_node_index
                triplet_cols[block_offset] = trial_node_index

    return integrate_kernel_fn
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
def get_integrate_bilinear_nodal_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """Create the kernel function used to integrate a bilinear form using the test space nodes as quadrature points.

    The nested function is only defined here, not launched. Each thread handles one
    (local node index, test dof, trial dof) tuple. Row and column indices are both set to the
    node index, so only diagonal blocks are produced.

    Args:
        integrand_func: Function evaluated at each node as ``integrand_func(sample, fields, values)``
        domain: Integration domain; provides the element indexing and the element measure
        FieldStruct: Struct type annotating the ``fields`` kernel argument
        ValueStruct: Struct type annotating the ``values`` kernel argument
        test: Test field; its space provides node coordinates and node quadrature weights,
            its ``space_restriction`` provides the node/element adjacency
        output_dtype: Scalar type of the ``triplet_values`` array elements
        accumulate_dtype: Type used to accumulate each block entry before conversion to ``output_dtype``

    Returns:
        The ``integrate_kernel_fn`` function.
    """

    # NOTE(review): the names `integrand_func`, `sample`, `fields`, `values` and `domain_arg` look like they are
    # matched by name by the integrand code transformer -- confirm before renaming any of them.
    def integrate_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_restriction_arg: test.space_restriction.NodeArg,
        fields: FieldStruct,
        values: ValueStruct,
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=output_dtype),
    ):
        # Thread index: (node index local to the space restriction, test dof, trial dof)
        local_node_index, test_dof, trial_dof = wp.tid()

        element_count = test.space_restriction.node_element_count(test_restriction_arg, local_node_index)
        node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)

        val_sum = accumulate_dtype(0.0)

        # Loop over the elements reported by the space restriction for this node
        for n in range(element_count):
            node_element_index = test.space_restriction.node_element_index(test_restriction_arg, local_node_index, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            # `_get_test_arg()` is a placeholder call; presumably it is rewritten at kernel generation
            # into an access to the test field argument -- TODO confirm against the code transformer
            coords = test.space.node_coords_in_element(
                domain_arg,
                _get_test_arg(),
                element_index,
                node_element_index.node_index_in_element,
            )

            # Contributions are skipped when the first coordinate equals the OUTSIDE sentinel
            if coords[0] != OUTSIDE:
                node_weight = test.space.node_quadrature_weight(
                    domain_arg,
                    _get_test_arg(),
                    element_index,
                    node_element_index.node_index_in_element,
                )

                # Test and trial dofs are both attached to the same node of the element
                test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)
                trial_dof_index = DofIndex(node_element_index.node_index_in_element, trial_dof)

                # The node index and node weight take the place of the quadrature point index and weight
                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(node_weight * vol * val)

        # One block per local node; every (test_dof, trial_dof) thread of that node writes the same indices
        triplet_values[local_node_index, test_dof, trial_dof] = output_dtype(val_sum)
        triplet_rows[local_node_index] = node_index
        triplet_cols[local_node_index] = node_index

    return integrate_kernel_fn
|
|
721
|
+
|
|
722
|
+
|
|
723
|
+
def _generate_integrate_kernel(
    integrand: Integrand,
    domain: GeometryDomain,
    nodal: bool,
    quadrature: Optional[Quadrature],
    test: Optional[TestField],
    test_name: str,
    trial: Optional[TrialField],
    trial_name: str,
    fields: Dict[str, FieldLike],
    output_dtype: type,
    accumulate_dtype: type,
    kernel_options: Optional[Dict[str, Any]] = None,
) -> tuple:
    """Get or build the integration kernel for the given integrand, domain, quadrature and fields.

    The kernel is first looked up in the integrand kernel cache under a suffix built from the
    output/accumulation types, the domain, the field struct, the quadrature (or ``_nodal``) and the
    test/trial spaces. On a cache miss, the integrand is translated and the kernel function matching
    the form (constant, linear or bilinear, nodal or not) is generated and registered.

    Args:
        integrand: Form to be integrated
        domain: Integration domain
        nodal: Whether to use the test function nodes as quadrature points
        quadrature: Quadrature formula; only read when ``nodal`` is False
        test: Test field, or None for a constant form
        test_name: Name of the integrand argument holding the test field
        trial: Trial field, or None for a constant or linear form
        trial_name: Name of the integrand argument holding the trial field
        fields: Fields passed to the integrand, keyed by integrand argument name
        output_dtype: Type of the integration result; reduced to its scalar type
        accumulate_dtype: Scalar type used for accumulating integration samples
        kernel_options: Options forwarded to the kernel builder. ``None`` is equivalent to no option.

    Returns:
        A ``(kernel, FieldStruct, ValueStruct)`` tuple: the kernel and the struct types
        of its ``fields`` and ``values`` arguments.
    """
    # A shared mutable default would be visible to every call; use a fresh dict instead
    if kernel_options is None:
        kernel_options = {}

    output_dtype = wp.types.type_scalar_type(output_dtype)

    # Extract field arguments from integrand
    field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
        integrand, fields=fields, domain=domain
    )

    FieldStruct = _gen_field_struct(field_args)
    ValueStruct = _gen_value_struct(value_args)

    # Check if kernel exist in cache
    kernel_suffix = f"_itg_{wp.types.type_typestr(output_dtype)}{wp.types.type_typestr(accumulate_dtype)}_{domain.name}_{FieldStruct.key}"
    if nodal:
        kernel_suffix += "_nodal"
    else:
        kernel_suffix += quadrature.name

    # Use the same `is not None` tests as the kernel dispatch below, so that the cache key
    # always describes the kernel that actually gets generated
    if test is not None:
        kernel_suffix += f"_test_{test.space_partition.name}_{test.space.name}"
    if trial is not None:
        kernel_suffix += f"_trial_{trial.space_partition.name}_{trial.space.name}"

    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        suffix=kernel_suffix,
    )
    if kernel is not None:
        return kernel, FieldStruct, ValueStruct

    # Not found in cache, transform integrand and generate kernel

    integrand_func = _translate_integrand(
        integrand,
        field_args,
    )

    _register_integrand_field_wrappers(integrand_func, fields)

    if test is None and trial is None:
        # Constant form
        integrate_kernel_fn = get_integrate_constant_kernel(
            integrand_func,
            domain,
            quadrature,
            FieldStruct,
            ValueStruct,
            accumulate_dtype=accumulate_dtype,
        )
    elif trial is None:
        # Linear form
        if nodal:
            integrate_kernel_fn = get_integrate_linear_nodal_kernel(
                integrand_func,
                domain,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
        else:
            integrate_kernel_fn = get_integrate_linear_kernel(
                integrand_func,
                domain,
                quadrature,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
    else:
        # Bilinear form
        if nodal:
            integrate_kernel_fn = get_integrate_bilinear_nodal_kernel(
                integrand_func,
                domain,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
        else:
            integrate_kernel_fn = get_integrate_bilinear_kernel(
                integrand_func,
                domain,
                quadrature,
                FieldStruct,
                ValueStruct,
                test=test,
                trial=trial,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )

    # Register the new kernel; the code transformer receives the integrand argument names
    # so that the generic kernel function can be adapted to this integrand's signature
    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        kernel_fn=integrate_kernel_fn,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
        code_transformers=[
            PassFieldArgsToIntegrand(
                arg_names=integrand.argspec.args,
                field_args=field_args.keys(),
                value_args=value_args.keys(),
                sample_name=sample_name,
                domain_name=domain_name,
                test_name=test_name,
                trial_name=trial_name,
            )
        ],
    )

    return kernel, FieldStruct, ValueStruct
|
|
849
|
+
|
|
850
|
+
|
|
851
|
+
def _launch_integrate_kernel(
    kernel: wp.Kernel,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    domain: GeometryDomain,
    nodal: bool,
    quadrature: Quadrature,
    test: Optional[TestField],
    trial: Optional[TrialField],
    fields: Dict[str, FieldLike],
    values: Dict[str, Any],
    accumulate_dtype: type,
    temporary_store: Optional[cache.TemporaryStore],
    output_dtype: type,
    output: Optional[Union[wp.array, BsrMatrix]],
    device,
):
    """Launch an integration kernel and return the result in the form matching the integrated form.

    Args:
        kernel: Integration kernel matching the form (constant, linear or bilinear) and ``nodal`` flag
        FieldStruct: Struct type of the kernel ``fields`` argument
        ValueStruct: Struct type of the kernel ``values`` argument
        domain: Integration domain
        nodal: Whether ``kernel`` is a nodal integration kernel (no quadrature argument)
        quadrature: Quadrature formula; may be None, in which case no quadrature argument is built
        test: Test field, or None for a constant form
        trial: Trial field, or None for a constant or linear form
        fields: Fields passed to the integrand, keyed by integrand argument name
        values: Additional values passed to the integrand, keyed by integrand argument name
        accumulate_dtype: Scalar type used for accumulating integration samples
        temporary_store: Shared pool from which to allocate temporary arrays
        output_dtype: Type of the result when ``output`` is not provided
        output: Array or sparse matrix into which to store the result, or None
        device: Device on which to perform the integration

    Returns:
        - Constant form: ``output`` if provided, otherwise the value of the integral
        - Linear form: ``output`` if provided, otherwise a newly allocated array
        - Bilinear form: ``output`` if provided, otherwise a newly allocated sparse matrix

    Raises:
        RuntimeError: If the provided ``output`` has an incompatible size, shape or type.
    """
    # Set-up launch arguments
    domain_elt_arg = domain.element_arg_value(device=device)
    domain_elt_index_arg = domain.element_index_arg_value(device=device)

    if quadrature is not None:
        qp_arg = quadrature.arg_value(device=device)

    field_arg_values = FieldStruct()
    for k, v in fields.items():
        setattr(field_arg_values, k, v.eval_arg_value(device=device))

    value_struct_values = ValueStruct()
    for k, v in values.items():
        setattr(value_struct_values, k, v)

    # Constant form
    if test is None and trial is None:
        # Accumulate directly into the output when its type allows it, otherwise into a temporary
        if output is not None and output.dtype == accumulate_dtype:
            if output.size < 1:
                raise RuntimeError("Output array must be of size at least 1")
            accumulate_array = output
        else:
            accumulate_temporary = cache.borrow_temporary(
                shape=(1,),
                device=device,
                dtype=accumulate_dtype,
                temporary_store=temporary_store,
                requires_grad=output is not None and output.requires_grad,
            )
            accumulate_array = accumulate_temporary.array

        accumulate_array.zero_()
        wp.launch(
            kernel=kernel,
            dim=domain.element_count(),
            inputs=[
                qp_arg,
                domain_elt_arg,
                domain_elt_index_arg,
                field_arg_values,
                value_struct_values,
                accumulate_array,
            ],
            device=device,
        )

        # Identity test: we want to know whether the kernel wrote straight into `output`,
        # not whether two arrays happen to compare equal
        if output is accumulate_array:
            return output
        elif output is None:
            return accumulate_array.numpy()[0]
        else:
            array_cast(in_array=accumulate_array, out_array=output)
            return output

    test_arg = test.space_restriction.node_arg(device=device)

    # Linear form
    if trial is None:
        # If an output array is provided with the correct type, accumulate directly into it
        # Otherwise, grab a temporary array
        if output is None:
            if type_length(output_dtype) == test.space.VALUE_DOF_COUNT:
                output_shape = (test.space_partition.node_count(),)
            elif type_length(output_dtype) == 1:
                output_shape = (test.space_partition.node_count(), test.space.VALUE_DOF_COUNT)
            else:
                raise RuntimeError(
                    f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
                )

            output_temporary = cache.borrow_temporary(
                temporary_store=temporary_store,
                shape=output_shape,
                dtype=output_dtype,
                device=device,
            )

            output = output_temporary.array

        else:
            output_temporary = None

            if output.shape[0] < test.space_partition.node_count():
                raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")

            output_dtype = output.dtype
            if type_length(output_dtype) != test.space.VALUE_DOF_COUNT:
                if type_length(output_dtype) != 1:
                    raise RuntimeError(
                        f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
                    )
                # A scalar array must be 2D with one column per value dof. Use `or` so that
                # the shape is only indexed for 2D arrays and a wrong column count is rejected
                # instead of letting the kernel write past the end of each row.
                if output.ndim != 2 or output.shape[1] != test.space.VALUE_DOF_COUNT:
                    raise RuntimeError(
                        f"Incompatible output array shape, last dimension must be of size {test.space.VALUE_DOF_COUNT}"
                    )

        # Launch the integration on the kernel on a 2d scalar view of the actual array
        output.zero_()

        def as_2d_array(array):
            # Non-owning view sharing the memory of `array` (and of its gradient, if any)
            return wp.array(
                data=None,
                ptr=array.ptr,
                capacity=array.capacity,
                owner=False,
                device=array.device,
                shape=(test.space_partition.node_count(), test.space.VALUE_DOF_COUNT),
                dtype=wp.types.type_scalar_type(output_dtype),
                grad=None if array.grad is None else as_2d_array(array.grad),
            )

        output_view = output if output.ndim == 2 else as_2d_array(output)

        if nodal:
            wp.launch(
                kernel=kernel,
                dim=(test.space_restriction.node_count(), test.space.VALUE_DOF_COUNT),
                inputs=[
                    domain_elt_arg,
                    domain_elt_index_arg,
                    test_arg,
                    field_arg_values,
                    value_struct_values,
                    output_view,
                ],
                device=device,
            )
        else:
            wp.launch(
                kernel=kernel,
                dim=(test.space_restriction.node_count(), test.space.VALUE_DOF_COUNT),
                inputs=[
                    qp_arg,
                    domain_elt_arg,
                    domain_elt_index_arg,
                    test_arg,
                    field_arg_values,
                    value_struct_values,
                    output_view,
                ],
                device=device,
            )

        if output_temporary is not None:
            return output_temporary.detach()

        return output

    # Bilinear form

    # Validate a user-provided matrix before allocating temporaries and launching the kernel
    if output is not None:
        if output.nrow != test.space_partition.node_count() or output.ncol != trial.space_partition.node_count():
            raise RuntimeError(
                f"Output matrix must have {test.space_partition.node_count()} rows and {trial.space_partition.node_count()} columns of blocks"
            )

    if test.space.VALUE_DOF_COUNT == 1 and trial.space.VALUE_DOF_COUNT == 1:
        block_type = output_dtype
    else:
        block_type = cache.cached_mat_type(
            shape=(test.space.VALUE_DOF_COUNT, trial.space.VALUE_DOF_COUNT), dtype=output_dtype
        )

    # Number of blocks of the triplet representation
    if nodal:
        nnz = test.space_restriction.node_count()
    else:
        nnz = test.space_restriction.total_node_element_count() * trial.space.topology.NODES_PER_ELEMENT

    triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_values_temp = cache.borrow_temporary(
        temporary_store,
        shape=(
            nnz,
            test.space.VALUE_DOF_COUNT,
            trial.space.VALUE_DOF_COUNT,
        ),
        dtype=output_dtype,
        device=device,
    )
    triplet_cols = triplet_cols_temp.array
    triplet_rows = triplet_rows_temp.array
    triplet_values = triplet_values_temp.array

    triplet_values.zero_()

    if nodal:
        wp.launch(
            kernel=kernel,
            dim=triplet_values.shape,
            inputs=[
                domain_elt_arg,
                domain_elt_index_arg,
                test_arg,
                field_arg_values,
                value_struct_values,
                triplet_rows,
                triplet_cols,
                triplet_values,
            ],
            device=device,
        )

    else:
        offsets = test.space_restriction.partition_element_offsets()

        trial_partition_arg = trial.space_partition.partition_arg_value(device)
        trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
        wp.launch(
            kernel=kernel,
            dim=(
                test.space_restriction.node_count(),
                trial.space.topology.NODES_PER_ELEMENT,
                test.space.VALUE_DOF_COUNT,
                trial.space.VALUE_DOF_COUNT,
            ),
            inputs=[
                qp_arg,
                domain_elt_arg,
                domain_elt_index_arg,
                test_arg,
                trial_partition_arg,
                trial_topology_arg,
                field_arg_values,
                value_struct_values,
                offsets,
                triplet_rows,
                triplet_cols,
                triplet_values,
            ],
            device=device,
        )

    if output is None:
        output = bsr_zeros(
            rows_of_blocks=test.space_partition.node_count(),
            cols_of_blocks=trial.space_partition.node_count(),
            block_type=block_type,
            device=device,
        )

    bsr_set_from_triplets(output, triplet_rows, triplet_cols, triplet_values)

    # Do not wait for garbage collection
    triplet_values_temp.release()
    triplet_rows_temp.release()
    triplet_cols_temp.release()

    return output
|
|
1117
|
+
|
|
1118
|
+
|
|
1119
|
+
def integrate(
    integrand: Integrand,
    domain: Optional[GeometryDomain] = None,
    quadrature: Optional[Quadrature] = None,
    nodal: bool = False,
    fields: Optional[Dict[str, FieldLike]] = None,
    values: Optional[Dict[str, Any]] = None,
    accumulate_dtype: type = wp.float64,
    output_dtype: Optional[type] = None,
    output: Optional[Union[BsrMatrix, wp.array]] = None,
    device=None,
    temporary_store: Optional[cache.TemporaryStore] = None,
    kernel_options: Optional[Dict[str, Any]] = None,
):
    """
    Integrates a constant, linear or bilinear form, and returns a scalar, array, or sparse matrix, respectively.

    Args:
        integrand: Form to be integrated, must have :func:`integrand` decorator
        domain: Integration domain. If None, deduced from the quadrature if provided, otherwise from the test field
        quadrature: Quadrature formula. If None, deduced from domain and fields degree.
        nodal: For linear or bilinear form only, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
        fields: Discrete, test, and trial fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names. ``None`` is equivalent to an empty dictionary.
        values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names. ``None`` is equivalent to an empty dictionary.
        temporary_store: shared pool from which to allocate temporary arrays
        accumulate_dtype: Scalar type to be used for accumulating integration samples
        output: Sparse matrix or warp array into which to store the result of the integration
        output_dtype: Scalar type for returned results if `output` is not provided. If None, defaults to `accumulate_dtype`
        device: Device on which to perform the integration
        kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``). ``None`` is equivalent to an empty dictionary.

    Returns:
        `output` if it was provided; otherwise the value of the integral for a constant form,
        a newly allocated array for a linear form, or a newly allocated sparse matrix for a bilinear form.

    Raises:
        ValueError: If `integrand` is not an :class:`Integrand`, if no domain can be deduced,
            or if the `nodal`, `quadrature`, `domain` and field arguments are inconsistent
        NotImplementedError: If the integration domain differs from the test field domain
    """
    if not isinstance(integrand, Integrand):
        raise ValueError("integrand must be tagged with @warp.fem.integrand decorator")

    # Fresh containers per call instead of shared mutable default arguments
    if fields is None:
        fields = {}
    if values is None:
        values = {}
    if kernel_options is None:
        kernel_options = {}

    test, test_name, trial, trial_name = _get_test_and_trial_fields(fields)

    # Deduce the integration domain: the quadrature takes precedence over the test field
    if domain is None:
        if quadrature is not None:
            domain = quadrature.domain
        elif test is not None:
            domain = test.domain

    if domain is None:
        raise ValueError("Must provide at least one of domain, quadrature, or test field")
    if test is not None and domain != test.domain:
        raise NotImplementedError("Mixing integration and test domain is not supported yet")

    if nodal:
        if quadrature is not None:
            raise ValueError("Cannot specify quadrature for nodal integration")

        if test is None:
            raise ValueError("Nodal integration requires specifying a test function")

        if trial is not None and test.space_partition != trial.space_partition:
            raise ValueError(
                "Bilinear nodal integration requires test and trial to be defined on the same function space"
            )
    else:
        if quadrature is None:
            # Default quadrature order: sum of the degrees of all the fields passed to the integrand
            order = sum(field.degree for field in fields.values())
            quadrature = RegularQuadrature(domain=domain, order=order)
        elif domain != quadrature.domain:
            raise ValueError("Incompatible integration and quadrature domain")

    # Canonicalize types
    accumulate_dtype = wp.types.type_to_warp(accumulate_dtype)
    if output is not None:
        # The type of a provided output takes precedence over `output_dtype`
        if isinstance(output, BsrMatrix):
            output_dtype = output.scalar_type
        else:
            output_dtype = output.dtype
    elif output_dtype is None:
        output_dtype = accumulate_dtype
    else:
        output_dtype = wp.types.type_to_warp(output_dtype)

    kernel, FieldStruct, ValueStruct = _generate_integrate_kernel(
        integrand=integrand,
        domain=domain,
        nodal=nodal,
        quadrature=quadrature,
        test=test,
        test_name=test_name,
        trial=trial,
        trial_name=trial_name,
        fields=fields,
        accumulate_dtype=accumulate_dtype,
        output_dtype=output_dtype,
        kernel_options=kernel_options,
    )

    return _launch_integrate_kernel(
        kernel=kernel,
        FieldStruct=FieldStruct,
        ValueStruct=ValueStruct,
        domain=domain,
        nodal=nodal,
        quadrature=quadrature,
        test=test,
        trial=trial,
        fields=fields,
        values=values,
        accumulate_dtype=accumulate_dtype,
        temporary_store=temporary_store,
        output_dtype=output_dtype,
        output=output,
        device=device,
    )
|
|
1228
|
+
|
|
1229
|
+
|
|
1230
|
+
def get_interpolate_to_field_function(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    dest: FieldRestriction,
):
    """Build the per-node evaluation function used when interpolating to a field restriction.

    The returned (not yet compiled) function evaluates ``integrand_func`` at one node of
    ``dest`` from every element reported for that node by ``dest.space_restriction`` and
    returns the pair ``(val_sum, vol_sum)``: the sum of the integrand values weighted by
    ``domain.element_measure`` and the sum of those weights.

    Args:
        integrand_func: Function called as ``integrand_func(sample, fields, values)``
        domain: Domain providing ``element_index`` and ``element_measure``
        FieldStruct: Struct type used to annotate the ``fields`` argument
        ValueStruct: Struct type used to annotate the ``values`` argument
        dest: Destination field restriction; its space ``dtype`` is the accumulator type

    Returns:
        The nested ``interpolate_to_field_fn`` Python function.
    """
    value_type = dest.space.dtype

    # NOTE(review): this function is later handed to `cache.get_integrand_function` together with a
    # `PassFieldArgsToIntegrand` code transformer; the identifiers used below (`integrand_func`,
    # `sample`, `fields`, `values`, `domain_arg`) may be significant to that rewrite -- confirm
    # before renaming anything.
    def interpolate_to_field_fn(
        local_node_index: int,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        # Map the restriction-local node to its partition index and get the number of
        # elements over which the loop below iterates for this node
        node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
        element_count = dest.space_restriction.node_element_count(dest_node_arg, local_node_index)

        # No test/trial functions are involved in interpolation
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX
        node_weight = 1.0

        # Volume-weighted average across elements
        # Superfluous if the interpolated function is continuous, but helpful for visualizing discontinuous spaces

        val_sum = value_type(0.0)
        vol_sum = float(0.0)

        for n in range(element_count):
            node_element_index = dest.space_restriction.node_element_index(dest_node_arg, local_node_index, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            # Coordinates of the node expressed in the current element
            coords = dest.space.node_coords_in_element(
                domain_arg,
                dest_eval_arg.space_arg,
                element_index,
                node_element_index.node_index_in_element,
            )

            # Elements whose first coordinate equals the OUTSIDE sentinel are skipped
            # (presumably the node has no valid location in that element -- TODO confirm)
            if coords[0] != OUTSIDE:
                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                vol_sum += vol
                val_sum += vol * val

        # The caller is responsible for dividing val_sum by vol_sum
        return val_sum, vol_sum

    return interpolate_to_field_fn
|
|
1290
|
+
|
|
1291
|
+
|
|
1292
|
+
def get_interpolate_to_field_kernel(
    interpolate_to_field_fn: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    dest: FieldRestriction,
):
    """Build the kernel body that writes interpolated node values into ``dest``.

    Args:
        interpolate_to_field_fn: Function returning ``(val_sum, vol_sum)`` for a local node index
        domain: Domain whose ``ElementArg`` / ``ElementIndexArg`` types annotate the kernel arguments
        FieldStruct: Struct type used to annotate the ``fields`` argument
        ValueStruct: Struct type used to annotate the ``values`` argument
        dest: Destination field restriction whose node values are assigned

    Returns:
        The nested ``interpolate_to_field_kernel_fn`` Python function.
    """

    # NOTE(review): this body is passed to `cache.get_integrand_kernel` with code transformers;
    # identifier names may matter to that rewrite -- confirm before renaming.
    def interpolate_to_field_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        # The thread index is used as the restriction-local node index
        local_node_index = wp.tid()

        val_sum, vol_sum = interpolate_to_field_fn(
            local_node_index, domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, fields, values
        )

        # Only write when some weight was accumulated; this also avoids dividing by zero.
        # Nodes with a zero weight sum keep whatever value they already had.
        if vol_sum > 0.0:
            node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
            dest.field.set_node_value(dest_eval_arg, node_index, val_sum / vol_sum)

    return interpolate_to_field_kernel_fn
|
|
1318
|
+
|
|
1319
|
+
|
|
1320
|
+
def get_interpolate_to_array_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    value_type: type,
):
    """Build the kernel body that stores the integrand value at each quadrature point in an array.

    Args:
        integrand_func: Function called as ``integrand_func(sample, fields, values)``
        domain: Domain used to map the thread index to an element index
        quadrature: Quadrature providing the per-element points, coordinates and weights
        FieldStruct: Struct type used to annotate the ``fields`` argument
        ValueStruct: Struct type used to annotate the ``values`` argument
        value_type: Data type of the ``result`` array

    Returns:
        The nested ``interpolate_to_array_kernel_fn`` Python function.
    """

    # NOTE(review): this body is passed to `cache.get_integrand_kernel` with a
    # `PassFieldArgsToIntegrand` code transformer; identifiers such as `integrand_func`, `sample`,
    # `fields`, `values` and `domain_arg` may be significant to that rewrite -- confirm before renaming.
    def interpolate_to_array_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: quadrature.domain.ElementArg,
        domain_index_arg: quadrature.domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=value_type),
    ):
        # The thread index is mapped to an element index through the domain
        element_index = domain.element_index(domain_index_arg, wp.tid())

        # No test/trial functions are involved in interpolation
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
        for k in range(qp_point_count):
            qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
            coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
            qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)

            sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)

            # The output slot is the index returned by `quadrature.point_index`
            result[qp_index] = integrand_func(sample, fields, values)

    return interpolate_to_array_kernel_fn
|
|
1352
|
+
|
|
1353
|
+
|
|
1354
|
+
def get_interpolate_nonvalued_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
):
    """Build the kernel body that calls the integrand at each quadrature point and discards the result.

    Args:
        integrand_func: Function called as ``integrand_func(sample, fields, values)``
        domain: Domain used to map the thread index to an element index
        quadrature: Quadrature providing the per-element points, coordinates and weights
        FieldStruct: Struct type used to annotate the ``fields`` argument
        ValueStruct: Struct type used to annotate the ``values`` argument

    Returns:
        The nested ``interpolate_nonvalued_kernel_fn`` Python function.
    """

    # NOTE(review): this body is passed to `cache.get_integrand_kernel` with a
    # `PassFieldArgsToIntegrand` code transformer; identifiers such as `integrand_func`, `sample`,
    # `fields`, `values` and `domain_arg` may be significant to that rewrite -- confirm before renaming.
    def interpolate_nonvalued_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: quadrature.domain.ElementArg,
        domain_index_arg: quadrature.domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        # The thread index is mapped to an element index through the domain
        element_index = domain.element_index(domain_index_arg, wp.tid())

        # No test/trial functions are involved in interpolation
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
        for k in range(qp_point_count):
            qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
            coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
            qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)

            sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
            # Nothing is stored here: any output must be produced by the integrand itself
            integrand_func(sample, fields, values)

    return interpolate_nonvalued_kernel_fn
|
|
1383
|
+
|
|
1384
|
+
|
|
1385
|
+
def _generate_interpolate_kernel(
    integrand: Integrand,
    domain: GeometryDomain,
    dest: Optional[Union[FieldLike, wp.array]],
    quadrature: Optional[Quadrature],
    fields: Dict[str, FieldLike],
    kernel_options: Optional[Dict[str, Any]] = None,
) -> tuple:
    """Fetch from the cache, or generate, the interpolation kernel for ``integrand``.

    Three kinds of kernel are produced depending on ``dest``:

    - a :class:`FieldRestriction`: one thread per node, result assigned to the field nodes;
    - a warp array: one thread per element, result stored at each quadrature point;
    - anything else (e.g. ``None``): one thread per element, integrand result discarded.

    Args:
        integrand: Integrand to interpolate
        domain: Domain over which the interpolation is performed
        dest: Interpolation destination, see above
        quadrature: Quadrature formula; must not be ``None`` unless ``dest`` is a :class:`FieldRestriction`
        fields: Discrete fields passed to the integrand, keyed by integrand parameter name
        kernel_options: Options forwarded to the kernel builder; ``None`` is treated as an empty dict

    Returns:
        The tuple ``(kernel, FieldStruct, ValueStruct)``.
    """
    # Never share a mutable default between calls
    if kernel_options is None:
        kernel_options = {}

    # Extract field arguments from integrand
    field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
        integrand, fields=fields, domain=domain
    )

    # Translate the integrand and register the field wrappers on the result
    integrand_func = _translate_integrand(
        integrand,
        field_args,
    )
    _register_integrand_field_wrappers(integrand_func, fields)

    # Generate field and value structs
    FieldStruct = _gen_field_struct(field_args)
    ValueStruct = _gen_value_struct(value_args)

    # Classify the destination once; it drives both the cache key and the kernel flavor
    to_field = isinstance(dest, FieldRestriction)
    to_array = not to_field and wp.types.is_array(dest)

    if to_field:
        kernel_suffix = (
            f"_itp_{FieldStruct.key}_{dest.domain.name}_{dest.space_restriction.space_partition.name}_{dest.space.name}"
        )
    elif to_array:
        kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}_{wp.types.type_repr(dest.dtype)}"
    else:
        kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}"

    # Check if kernel exists in cache (the lookup is done with the integrand and suffix only)
    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        suffix=kernel_suffix,
    )
    if kernel is not None:
        return kernel, FieldStruct, ValueStruct

    def make_code_transformers():
        # A fresh transformer list is built for each consumer, as in the original code
        return [
            PassFieldArgsToIntegrand(
                arg_names=integrand.argspec.args,
                field_args=field_args.keys(),
                value_args=value_args.keys(),
                sample_name=sample_name,
                domain_name=domain_name,
            )
        ]

    # Generate interpolation kernel
    if to_field:
        # need to split into kernel + function for differentiability
        interpolate_fn = get_interpolate_to_field_function(
            integrand_func,
            domain,
            dest=dest,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )

        interpolate_fn = cache.get_integrand_function(
            integrand=integrand,
            func=interpolate_fn,
            suffix=kernel_suffix,
            code_transformers=make_code_transformers(),
        )

        interpolate_kernel_fn = get_interpolate_to_field_kernel(
            interpolate_fn,
            domain,
            dest=dest,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )
    elif to_array:
        interpolate_kernel_fn = get_interpolate_to_array_kernel(
            integrand_func,
            domain=domain,
            quadrature=quadrature,
            value_type=dest.dtype,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )
    else:
        interpolate_kernel_fn = get_interpolate_nonvalued_kernel(
            integrand_func,
            domain=domain,
            quadrature=quadrature,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )

    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        kernel_fn=interpolate_kernel_fn,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
        code_transformers=make_code_transformers(),
    )

    return kernel, FieldStruct, ValueStruct
|
|
1494
|
+
|
|
1495
|
+
|
|
1496
|
+
def _launch_interpolate_kernel(
    kernel: wp.Kernel,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    domain: GeometryDomain,
    dest: Optional[Union[FieldRestriction, wp.array]],
    quadrature: Optional[Quadrature],
    fields: Dict[str, FieldLike],
    values: Dict[str, Any],
    device,
) -> None:
    """Fill the kernel argument structs and launch an interpolation kernel.

    Args:
        kernel: Kernel obtained from :func:`_generate_interpolate_kernel`
        FieldStruct: Struct type instantiated to hold the evaluation argument of each field
        ValueStruct: Struct type instantiated to hold each entry of ``values``
        domain: Interpolation domain
        dest: Interpolation destination. For a :class:`FieldRestriction` the launch has one thread
            per restricted node; otherwise it has one thread per domain element, and ``dest`` is
            appended to the kernel inputs when it is a warp array.
        quadrature: Quadrature formula; must not be ``None`` unless ``dest`` is a :class:`FieldRestriction`
        fields: Discrete fields, keyed by the name of the corresponding ``FieldStruct`` member
        values: Additional values, keyed by the name of the corresponding ``ValueStruct`` member
        device: Device on which to launch the kernel
    """
    # Set-up launch arguments
    elt_arg = domain.element_arg_value(device=device)
    elt_index_arg = domain.element_index_arg_value(device=device)

    field_arg_values = FieldStruct()
    for name, field in fields.items():
        setattr(field_arg_values, name, field.eval_arg_value(device=device))

    value_struct_values = ValueStruct()
    for name, value in values.items():
        setattr(value_struct_values, name, value)

    if isinstance(dest, FieldRestriction):
        # Node-based launch: the kernel assigns the destination field values itself
        dest_node_arg = dest.space_restriction.node_arg(device=device)
        dest_eval_arg = dest.field.eval_arg_value(device=device)

        wp.launch(
            kernel=kernel,
            dim=dest.space_restriction.node_count(),
            inputs=[
                elt_arg,
                elt_index_arg,
                dest_node_arg,
                dest_eval_arg,
                field_arg_values,
                value_struct_values,
            ],
            device=device,
        )
        return

    # Element-based launch over the quadrature points; the array and "no destination"
    # kernels take the same leading arguments, the array one takes `dest` as a last input
    qp_arg = quadrature.arg_value(device)
    launch_inputs = [qp_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values]
    if wp.types.is_array(dest):
        launch_inputs.append(dest)

    wp.launch(
        kernel=kernel,
        dim=domain.element_count(),
        inputs=launch_inputs,
        device=device,
    )
|
|
1552
|
+
|
|
1553
|
+
|
|
1554
|
+
def interpolate(
    integrand: Integrand,
    dest: Optional[Union[DiscreteField, FieldRestriction, wp.array]] = None,
    quadrature: Optional[Quadrature] = None,
    fields: Optional[Dict[str, FieldLike]] = None,
    values: Optional[Dict[str, Any]] = None,
    device=None,
    kernel_options: Optional[Dict[str, Any]] = None,
):
    """
    Interpolates a function at a finite set of sample points and optionally assigns the result to a discrete field or a raw warp array.

    Args:
        integrand: Function to be interpolated, must have :func:`integrand` decorator
        dest: Where to store the interpolation result. Can be either

            - a :class:`DiscreteField`, or restriction of a discrete field to a domain (from :func:`make_restriction`). In this case, interpolation will be performed at each node.
            - a normal warp array. In this case, the `quadrature` argument defining the interpolation locations must be provided and the result of the `integrand` at each quadrature point will be assigned to the array.
            - ``None``. In this case, the `quadrature` argument must also be provided and the `integrand` function is responsible for dealing with the interpolation result.
        quadrature: Quadrature formula defining the interpolation samples if `dest` is not a discrete field or field restriction.
        fields: Discrete fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names. ``None`` is equivalent to an empty dictionary.
        values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names. ``None`` is equivalent to an empty dictionary.
        device: Device on which to perform the interpolation
        kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``). ``None`` is equivalent to an empty dictionary.

    Raises:
        ValueError: if `integrand` is not an :class:`Integrand`, if `fields` contains a test or trial field,
            or if `quadrature` is missing while `dest` is neither a discrete field nor a field restriction.
    """
    if not isinstance(integrand, Integrand):
        raise ValueError("integrand must be tagged with @integrand decorator")

    # Use fresh containers instead of mutable default arguments shared between calls
    fields = {} if fields is None else fields
    values = {} if values is None else values
    kernel_options = {} if kernel_options is None else kernel_options

    test, _, trial, __ = _get_test_and_trial_fields(fields)
    if test is not None or trial is not None:
        raise ValueError("Test or Trial fields should not be used for interpolation")

    # A full discrete field is interpolated through its restriction
    if isinstance(dest, DiscreteField):
        dest = make_restriction(dest)

    if isinstance(dest, FieldRestriction):
        # Node-based interpolation: the domain comes from the restriction
        domain = dest.domain
    else:
        if quadrature is None:
            raise ValueError("When not interpolating to a field, a quadrature formula must be provided")

        domain = quadrature.domain

    kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
        integrand=integrand,
        domain=domain,
        dest=dest,
        quadrature=quadrature,
        fields=fields,
        kernel_options=kernel_options,
    )

    return _launch_interpolate_kernel(
        kernel=kernel,
        FieldStruct=FieldStruct,
        ValueStruct=ValueStruct,
        domain=domain,
        dest=dest,
        quadrature=quadrature,
        fields=fields,
        values=values,
        device=device,
    )
|