warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +139 -0
- warp/__init__.pyi +1 -0
- warp/autograd.py +1142 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +557 -0
- warp/build_dll.py +405 -0
- warp/builtins.py +6855 -0
- warp/codegen.py +3969 -0
- warp/config.py +158 -0
- warp/constants.py +57 -0
- warp/context.py +6812 -0
- warp/dlpack.py +462 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +194 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +193 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +423 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +178 -0
- warp/examples/fem/example_convection_diffusion_dg.py +204 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +220 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_magnetostatics.py +240 -0
- warp/examples/fem/example_mixed_elasticity.py +291 -0
- warp/examples/fem/example_navier_stokes.py +261 -0
- warp/examples/fem/example_nonconforming_contact.py +298 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +352 -0
- warp/examples/fem/utils.py +1000 -0
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +266 -0
- warp/examples/optim/example_cloth_throw.py +228 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_drone.py +870 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +182 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
- warp/examples/optim/example_softbody_properties.py +400 -0
- warp/examples/optim/example_spring_cage.py +245 -0
- warp/examples/optim/example_trajectory.py +227 -0
- warp/examples/sim/example_cartpole.py +143 -0
- warp/examples/sim/example_cloth.py +225 -0
- warp/examples/sim/example_cloth_self_contact.py +322 -0
- warp/examples/sim/example_granular.py +130 -0
- warp/examples/sim/example_granular_collision_sdf.py +202 -0
- warp/examples/sim/example_jacobian_ik.py +244 -0
- warp/examples/sim/example_particle_chain.py +124 -0
- warp/examples/sim/example_quadruped.py +203 -0
- warp/examples/sim/example_rigid_chain.py +203 -0
- warp/examples/sim/example_rigid_contact.py +195 -0
- warp/examples/sim/example_rigid_force.py +133 -0
- warp/examples/sim/example_rigid_gyroscopic.py +115 -0
- warp/examples/sim/example_rigid_soft_contact.py +140 -0
- warp/examples/sim/example_soft_body.py +196 -0
- warp/examples/tile/example_tile_cholesky.py +87 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mlp.py +383 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/examples/tile/example_tile_walker.py +327 -0
- warp/fabric.py +355 -0
- warp/fem/__init__.py +106 -0
- warp/fem/adaptivity.py +508 -0
- warp/fem/cache.py +572 -0
- warp/fem/dirichlet.py +202 -0
- warp/fem/domain.py +411 -0
- warp/fem/field/__init__.py +125 -0
- warp/fem/field/field.py +619 -0
- warp/fem/field/nodal_field.py +326 -0
- warp/fem/field/restriction.py +37 -0
- warp/fem/field/virtual.py +848 -0
- warp/fem/geometry/__init__.py +32 -0
- warp/fem/geometry/adaptive_nanogrid.py +857 -0
- warp/fem/geometry/closest_point.py +84 -0
- warp/fem/geometry/deformed_geometry.py +221 -0
- warp/fem/geometry/element.py +776 -0
- warp/fem/geometry/geometry.py +362 -0
- warp/fem/geometry/grid_2d.py +392 -0
- warp/fem/geometry/grid_3d.py +452 -0
- warp/fem/geometry/hexmesh.py +911 -0
- warp/fem/geometry/nanogrid.py +571 -0
- warp/fem/geometry/partition.py +389 -0
- warp/fem/geometry/quadmesh.py +663 -0
- warp/fem/geometry/tetmesh.py +855 -0
- warp/fem/geometry/trimesh.py +806 -0
- warp/fem/integrate.py +2335 -0
- warp/fem/linalg.py +419 -0
- warp/fem/operator.py +293 -0
- warp/fem/polynomial.py +229 -0
- warp/fem/quadrature/__init__.py +17 -0
- warp/fem/quadrature/pic_quadrature.py +299 -0
- warp/fem/quadrature/quadrature.py +591 -0
- warp/fem/space/__init__.py +228 -0
- warp/fem/space/basis_function_space.py +468 -0
- warp/fem/space/basis_space.py +667 -0
- warp/fem/space/dof_mapper.py +251 -0
- warp/fem/space/function_space.py +309 -0
- warp/fem/space/grid_2d_function_space.py +177 -0
- warp/fem/space/grid_3d_function_space.py +227 -0
- warp/fem/space/hexmesh_function_space.py +257 -0
- warp/fem/space/nanogrid_function_space.py +201 -0
- warp/fem/space/partition.py +367 -0
- warp/fem/space/quadmesh_function_space.py +223 -0
- warp/fem/space/restriction.py +179 -0
- warp/fem/space/shape/__init__.py +143 -0
- warp/fem/space/shape/cube_shape_function.py +1105 -0
- warp/fem/space/shape/shape_function.py +133 -0
- warp/fem/space/shape/square_shape_function.py +926 -0
- warp/fem/space/shape/tet_shape_function.py +834 -0
- warp/fem/space/shape/triangle_shape_function.py +672 -0
- warp/fem/space/tetmesh_function_space.py +271 -0
- warp/fem/space/topology.py +424 -0
- warp/fem/space/trimesh_function_space.py +194 -0
- warp/fem/types.py +99 -0
- warp/fem/utils.py +420 -0
- warp/jax.py +187 -0
- warp/jax_experimental/__init__.py +16 -0
- warp/jax_experimental/custom_call.py +351 -0
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +244 -0
- warp/native/array.h +1145 -0
- warp/native/builtin.h +1800 -0
- warp/native/bvh.cpp +492 -0
- warp/native/bvh.cu +791 -0
- warp/native/bvh.h +554 -0
- warp/native/clang/clang.cpp +536 -0
- warp/native/coloring.cpp +613 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +362 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +646 -0
- warp/native/cuda_util.h +307 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +1878 -0
- warp/native/fabric.h +245 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +87 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1230 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +339 -0
- warp/native/marching.cpp +19 -0
- warp/native/marching.cu +514 -0
- warp/native/marching.h +19 -0
- warp/native/mat.h +2220 -0
- warp/native/mathdx.cpp +87 -0
- warp/native/matnn.h +343 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +404 -0
- warp/native/mesh.h +1980 -0
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1371 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +139 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +364 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +53 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +277 -0
- warp/native/sort.h +33 -0
- warp/native/sparse.cpp +378 -0
- warp/native/sparse.cu +524 -0
- warp/native/spatial.h +657 -0
- warp/native/svd.h +702 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +2584 -0
- warp/native/tile_reduce.h +264 -0
- warp/native/vec.h +1426 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +67 -0
- warp/native/volume.h +969 -0
- warp/native/volume_builder.cu +477 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1082 -0
- warp/native/warp.cu +3636 -0
- warp/native/warp.h +381 -0
- warp/optim/__init__.py +17 -0
- warp/optim/adam.py +163 -0
- warp/optim/linear.py +1137 -0
- warp/optim/sgd.py +112 -0
- warp/paddle.py +407 -0
- warp/render/__init__.py +18 -0
- warp/render/render_opengl.py +3518 -0
- warp/render/render_usd.py +784 -0
- warp/render/utils.py +160 -0
- warp/sim/__init__.py +65 -0
- warp/sim/articulation.py +793 -0
- warp/sim/collide.py +2395 -0
- warp/sim/graph_coloring.py +300 -0
- warp/sim/import_mjcf.py +790 -0
- warp/sim/import_snu.py +227 -0
- warp/sim/import_urdf.py +579 -0
- warp/sim/import_usd.py +894 -0
- warp/sim/inertia.py +324 -0
- warp/sim/integrator.py +242 -0
- warp/sim/integrator_euler.py +1997 -0
- warp/sim/integrator_featherstone.py +2101 -0
- warp/sim/integrator_vbd.py +2048 -0
- warp/sim/integrator_xpbd.py +3292 -0
- warp/sim/model.py +4791 -0
- warp/sim/particles.py +121 -0
- warp/sim/render.py +427 -0
- warp/sim/utils.py +428 -0
- warp/sparse.py +2057 -0
- warp/stubs.py +3333 -0
- warp/tape.py +1203 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +634 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +200 -0
- warp/tests/geometry/test_hash_grid.py +221 -0
- warp/tests/geometry/test_marching_cubes.py +74 -0
- warp/tests/geometry/test_mesh.py +316 -0
- warp/tests/geometry/test_mesh_query_aabb.py +399 -0
- warp/tests/geometry/test_mesh_query_point.py +932 -0
- warp/tests/geometry/test_mesh_query_ray.py +311 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +729 -0
- warp/tests/interop/test_jax.py +371 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/sim/__init__.py +0 -0
- warp/tests/sim/disabled_kinematics.py +244 -0
- warp/tests/sim/flaky_test_sim_grad.py +290 -0
- warp/tests/sim/test_collision.py +604 -0
- warp/tests/sim/test_coloring.py +258 -0
- warp/tests/sim/test_model.py +224 -0
- warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
- warp/tests/sim/test_sim_kinematics.py +98 -0
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +163 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +2972 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +250 -0
- warp/tests/test_atomic.py +153 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +1298 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +810 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +252 -0
- warp/tests/test_context.py +42 -0
- warp/tests/test_copy.py +238 -0
- warp/tests/test_ctypes.py +638 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +97 -0
- warp/tests/test_examples.py +482 -0
- warp/tests/test_fabricarray.py +996 -0
- warp/tests/test_fast_math.py +74 -0
- warp/tests/test_fem.py +2003 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +454 -0
- warp/tests/test_future_annotations.py +98 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +73 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +193 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_mat.py +2089 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +178 -0
- warp/tests/test_mlp.py +282 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +44 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +339 -0
- warp/tests/test_quat.py +2315 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +302 -0
- warp/tests/test_rounding.py +185 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +105 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +582 -0
- warp/tests/test_spatial.py +2229 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +592 -0
- warp/tests/test_struct.py +734 -0
- warp/tests/test_tape.py +204 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +145 -0
- warp/tests/test_types.py +562 -0
- warp/tests/test_utils.py +588 -0
- warp/tests/test_vec.py +1487 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +780 -0
- warp/tests/tile/test_tile_load.py +407 -0
- warp/tests/tile/test_tile_mathdx.py +208 -0
- warp/tests/tile/test_tile_mlp.py +402 -0
- warp/tests/tile/test_tile_reduce.py +447 -0
- warp/tests/tile/test_tile_shared_memory.py +247 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +427 -0
- warp/tests/unittest_utils.py +468 -0
- warp/tests/walkthrough_debug.py +93 -0
- warp/thirdparty/__init__.py +0 -0
- warp/thirdparty/appdirs.py +598 -0
- warp/thirdparty/dlpack.py +145 -0
- warp/thirdparty/unittest_parallel.py +570 -0
- warp/torch.py +391 -0
- warp/types.py +5230 -0
- warp/utils.py +1137 -0
- warp_lang-1.7.0.dist-info/METADATA +516 -0
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- warp_lang-1.7.0.dist-info/WHEEL +5 -0
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/fem/integrate.py
ADDED
|
@@ -0,0 +1,2335 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import ast
|
|
17
|
+
import inspect
|
|
18
|
+
import textwrap
|
|
19
|
+
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union
|
|
20
|
+
|
|
21
|
+
import warp as wp
|
|
22
|
+
from warp.codegen import get_annotations
|
|
23
|
+
from warp.fem import cache
|
|
24
|
+
from warp.fem.domain import GeometryDomain
|
|
25
|
+
from warp.fem.field import (
|
|
26
|
+
DiscreteField,
|
|
27
|
+
FieldLike,
|
|
28
|
+
FieldRestriction,
|
|
29
|
+
GeometryField,
|
|
30
|
+
LocalTestField,
|
|
31
|
+
LocalTrialField,
|
|
32
|
+
TestField,
|
|
33
|
+
TrialField,
|
|
34
|
+
make_restriction,
|
|
35
|
+
)
|
|
36
|
+
from warp.fem.field.virtual import make_bilinear_dispatch_kernel, make_linear_dispatch_kernel
|
|
37
|
+
from warp.fem.linalg import array_axpy, basis_coefficient
|
|
38
|
+
from warp.fem.operator import Integrand, Operator, at_node, integrand
|
|
39
|
+
from warp.fem.quadrature import Quadrature, RegularQuadrature
|
|
40
|
+
from warp.fem.types import (
|
|
41
|
+
NULL_DOF_INDEX,
|
|
42
|
+
NULL_ELEMENT_INDEX,
|
|
43
|
+
NULL_NODE_INDEX,
|
|
44
|
+
OUTSIDE,
|
|
45
|
+
Coords,
|
|
46
|
+
DofIndex,
|
|
47
|
+
Domain,
|
|
48
|
+
Field,
|
|
49
|
+
Sample,
|
|
50
|
+
make_free_sample,
|
|
51
|
+
)
|
|
52
|
+
from warp.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
|
|
53
|
+
from warp.types import type_length
|
|
54
|
+
from warp.utils import array_cast
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _resolve_path(func, node):
    """
    Resolves variable and path from ast node/attribute (adapted from warp.codegen)

    The dotted path is rebuilt by walking nested ``ast.Attribute`` nodes down to the
    base ``ast.Name`` (if any), then evaluated against the globals and closure
    variables of ``func``.

    Args:
        func: Python function whose ``__globals__`` and closure provide the lookup namespace
        node: AST expression to resolve (e.g. the ``func`` attribute of an ``ast.Call``)

    Returns:
        A ``(value, path)`` tuple. ``path`` is the list of name components in source order;
        it is empty when ``node`` contains neither attribute accesses nor a name.
        ``value`` is the evaluated object, or ``None`` when the path is empty or
        evaluating it raises ``NameError`` or ``AttributeError``.
    """

    modules = []

    while isinstance(node, ast.Attribute):
        modules.append(node.attr)
        node = node.value

    if isinstance(node, ast.Name):
        modules.append(node.id)

    # reverse list since ast presents it in backward order
    path = [*reversed(modules)]

    if not path:
        return None, path

    # try and evaluate object path
    try:
        # Look up the closure info and layer it over func.__globals__
        # in case you want to define a kernel inside a function and refer
        # to variables you've declared inside that function:
        captured_vars = {}
        for name, cell in zip(func.__code__.co_freevars, func.__closure__ or ()):
            try:
                captured_vars[name] = cell.cell_contents
            except ValueError:
                # Empty cell: the enclosing function has not bound this variable yet.
                # Map it to None so that it still shadows any global of the same name
                # and is reported as unresolved instead of aborting the whole lookup.
                captured_vars[name] = None

        vars_dict = {**func.__globals__, **captured_vars}
        return eval(".".join(path), vars_dict), path
    except (NameError, AttributeError):
        pass

    return None, path
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class IntegrandVisitor(ast.NodeTransformer):
    """
    Base AST visitor tracking which expressions of an integrand refer to fields or domains.

    Field information is looked up by variable name (``_field_symbols``) or by AST node
    (``_field_nodes``, for the result of operator calls that themselves yield a field).
    Operator calls and nested integrand calls are dispatched to ``_process_operator_call``
    and ``_process_integrand_call``; these are not defined here and must be provided
    by subclasses.
    """

    class FieldInfo(NamedTuple):
        # Field or domain object bound to the symbol
        field: FieldLike
        # Type annotation of the corresponding integrand argument
        abstract_type: type
        # Concrete argument type (``ElementEvalArg`` for fields, ``ElementArg`` otherwise)
        concrete_type: type
        # Name of the originating integrand argument, suffixed with ".<operator name>"
        # for fields produced by an operator call
        root_arg_name: str

    def __init__(
        self,
        integrand: Integrand,
        field_info: Dict[str, FieldInfo],
    ):
        """
        Args:
            integrand: integrand whose source is going to be visited
            field_info: field information for each field-like argument, keyed by argument name.
                The mapping is shallow-copied, so aliases recorded while visiting
                do not leak back to the caller.
        """
        self._integrand = integrand
        self._field_symbols = field_info.copy()
        self._field_nodes = {}

    @staticmethod
    def _build_field_info(integrand: Integrand, field_args: Dict[str, FieldLike]):
        """Builds the ``FieldInfo`` mapping for the top-level field arguments of ``integrand``"""

        def get_concrete_type(field: Union[FieldLike, Domain]):
            if isinstance(field, FieldLike):
                return field.ElementEvalArg
            return field.ElementArg

        return {
            name: IntegrandVisitor.FieldInfo(
                field=field,
                abstract_type=integrand.argspec.annotations[name],
                concrete_type=get_concrete_type(field),
                root_arg_name=name,
            )
            for name, field in field_args.items()
        }

    def _get_field_info(self, node: ast.expr):
        """Returns the field info attached to ``node``, or ``None`` if it does not denote a known field"""
        field_info = self._field_nodes.get(node)
        if field_info is None and isinstance(node, ast.Name):
            field_info = self._field_symbols.get(node.id)

        return field_info

    def visit_Call(self, call: ast.Call):
        """
        Visits a call node, after its children, and dispatches it if it is one of:
          - the shortcut syntax ``f(x...)`` where ``f`` is a known field symbol,
          - an operator call ``op(field, x, ...)`` whose first argument is a known field,
          - a call to another integrand.
        """
        call = self.generic_visit(call)

        callee = getattr(call.func, "id", None)
        if callee in self._field_symbols:
            # Shortcut for evaluating fields as f(x...)
            field_info = self._field_symbols[callee]

            # Replace with default call operator
            default_operator = field_info.abstract_type.call_operator

            self._process_operator_call(call, callee, default_operator, field_info)

            return call

        func, _ = _resolve_path(self._integrand.func, call.func)

        if isinstance(func, Operator) and call.args:
            # Evaluating operators as op(field, x, ...)
            field_info = self._get_field_info(call.args[0])
            if field_info is not None:
                self._process_operator_call(call, func, func, field_info)

                if func.field_result:
                    # The operator itself yields a field; remember it for this call node
                    # so that it can be assigned or passed to further operators
                    res = func.field_result(field_info.field)
                    self._field_nodes[call] = IntegrandVisitor.FieldInfo(
                        field=res[0],
                        abstract_type=res[1],
                        concrete_type=res[2],
                        root_arg_name=f"{field_info.root_arg_name}.{func.name}",
                    )

        if isinstance(func, Integrand):
            callee_field_args = self._get_callee_field_args(func, call.args)
            self._process_integrand_call(call, func, callee_field_args)

        return call

    def visit_Assign(self, node: ast.Assign):
        """
        Records the assignment target as an alias when the assigned value is a known field.

        Raises:
            NotImplementedError: if a field is assigned to anything but a single simple variable
        """
        node = self.generic_visit(node)

        # Check if we're assigning a field
        src_field_info = self._get_field_info(node.value)
        if src_field_info is not None:
            if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
                raise NotImplementedError("warp.fem Fields and Domains may only be assigned to simple variables")

            self._field_symbols[node.targets[0].id] = src_field_info

        return node

    def _get_callee_field_args(self, callee: Integrand, args: List[ast.AST]):
        """
        Maps the field-like call-site arguments to the Field/Domain parameters of ``callee``.

        Both sequences are consumed in order: the i-th call-site argument that is a known
        field is bound to the i-th parameter of ``callee`` annotated as Field or Domain.

        Raises:
            TypeError: if fewer field-like arguments are passed than ``callee`` expects,
                or if the abstract type of a passed field does not match the parameter annotation
        """
        # Get field types for call site arguments
        call_site_field_args: List[IntegrandVisitor.FieldInfo] = []
        for arg in args:
            field_info = self._get_field_info(arg)
            if field_info is not None:
                call_site_field_args.append(field_info)

        # Reversed so that pop() hands the arguments out in call order
        call_site_field_args.reverse()

        # Pass to callee in same order
        callee_field_args = {}
        for arg in callee.argspec.args:
            arg_type = callee.argspec.annotations[arg]
            if arg_type in (Field, Domain):
                if not call_site_field_args:
                    # Report an explicit error rather than a bare IndexError from pop()
                    raise TypeError(
                        f"Missing {arg_type.__name__} for argument '{arg}' of '{callee.name}': "
                        "no matching positional field or domain was passed at the call site"
                    )
                passed_field_info = call_site_field_args.pop()
                if passed_field_info.abstract_type != arg_type:
                    raise TypeError(
                        f"Attempting to pass a {passed_field_info.abstract_type.__name__} to argument '{arg}' of '{callee.name}' expecting a {arg_type.__name__}"
                    )
                callee_field_args[arg] = passed_field_info

        return callee_field_args
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
class IntegrandOperatorParser(IntegrandVisitor):
    """
    Visitor reporting every operator applied to a field, without generating any code.

    The source of the integrand, and recursively of every integrand it calls, is parsed
    and ``callback(field_info, operator)`` is invoked for each operator call encountered.
    """

    def __init__(
        self,
        integrand: Integrand,
        field_info: Dict[str, IntegrandVisitor.FieldInfo],
        callback: Optional[Callable],
    ):
        """
        Args:
            integrand: integrand to parse
            field_info: field information keyed by argument name
            callback: callable invoked as ``callback(field_info, operator)`` for each operator call,
                or ``None`` to only traverse the integrand
        """
        super().__init__(integrand, field_info)
        self._operator_callback = callback

    def _process_operator_call(
        self, call: ast.Call, callee: Union[str, Operator], operator: Operator, field_info: IntegrandVisitor.FieldInfo
    ):
        # `apply` defaults the callback to None; in that case there is nothing to report
        if self._operator_callback is not None:
            self._operator_callback(field_info, operator)

    def _process_integrand_call(
        self, call: ast.Call, callee: Integrand, callee_field_args: Dict[str, IntegrandVisitor.FieldInfo]
    ):
        # Recurse into the called integrand with the field arguments resolved at the call site
        callee_parser = IntegrandOperatorParser(callee, callee_field_args, callback=self._operator_callback)
        callee_parser._apply()

    def _apply(self):
        """Parses the source code of the integrand function and visits the resulting tree"""
        source = textwrap.dedent(inspect.getsource(self._integrand.func))
        tree = ast.parse(source)
        self.visit(tree)

    @staticmethod
    def apply(
        integrand: Integrand, field_args: Dict[str, FieldLike], operator_callback: Optional[Callable] = None
    ) -> None:
        """
        Parses ``integrand`` and reports the operators applied to its fields.

        Args:
            integrand: integrand to parse
            field_args: field-like arguments of the integrand, keyed by argument name
            operator_callback: callable invoked as ``operator_callback(field_info, operator)``
                for each operator call; if ``None`` the integrand is traversed without reporting
        """
        field_info = IntegrandVisitor._build_field_info(integrand, field_args)
        IntegrandOperatorParser(integrand, field_info, callback=operator_callback)._apply()
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
class IntegrandTransformer(IntegrandVisitor):
    """
    Visitor rewriting operator and integrand calls so that they target the function
    specialized for the fields actually passed to the integrand.
    """

    def _process_operator_call(
        self, call: ast.Call, callee: Union[str, Operator], operator: Operator, field_info: IntegrandVisitor.FieldInfo
    ):
        """
        Rewrites ``call`` to invoke the implementation of ``operator`` for the field type.

        Args:
            call: call node, modified in place
            callee: the operator itself for the ``op(field, ...)`` syntax, or the name of
                the field variable for the shortcut ``field(...)`` syntax
            operator: operator to resolve
            field_info: information about the field the operator is applied to

        Raises:
            TypeError: if the operator cannot be resolved to a Warp function for this field
        """
        field = field_info.field

        try:
            # Retrieve the function pointer corresponding to the operator implementation for the field type
            pointer = operator.resolver(field)
            if not isinstance(pointer, wp.context.Function):
                raise NotImplementedError(operator.resolver.__name__)

        except (AttributeError, NotImplementedError) as e:
            raise TypeError(
                f"Operator {operator.func.__name__} is not defined for {field_info.abstract_type.__name__} {field.name}"
            ) from e

        # Update the ast Call node to use the new function pointer
        call.func = ast.Attribute(value=call.func, attr=pointer.key, ctx=ast.Load())

        # Save the pointer as an attribute that can be accessed from the calling scope
        # For usual operator call syntax, we can use the operator itself, but for the
        # shortcut default operator syntax, we store it on the callee's concrete type
        if isinstance(callee, Operator):
            setattr(callee, pointer.key, pointer)
        else:
            setattr(field_info.concrete_type, pointer.key, pointer)

            # also insert callee as first argument
            call.args = [ast.Name(id=callee, ctx=ast.Load())] + call.args

    def _process_integrand_call(
        self, call: ast.Call, callee: Integrand, callee_field_args: Dict[str, IntegrandVisitor.FieldInfo]
    ):
        """Specializes the called integrand for the passed fields and redirects ``call`` to the result"""
        transformer = IntegrandTransformer(callee, callee_field_args)
        key = transformer._apply().key
        call.func = ast.Attribute(
            value=call.func,
            attr=key,
            ctx=ast.Load(),
        )

    def _apply(self) -> wp.Function:
        """
        Generates the specialized function for the integrand.

        Field arguments are re-annotated with their concrete types, the function is obtained
        from ``cache.get_integrand_function`` with this transformer registered as code transformer,
        and the result is stored as an attribute of the integrand under the function key.

        Returns:
            The specialized function
        """
        # Transform field evaluation calls
        field_info = self._field_symbols

        # Specialize field argument types
        argspec = self._integrand.argspec
        annotations = argspec.annotations.copy()
        annotations.update({name: f.concrete_type for name, f in field_info.items()})

        suffix = "_".join([f.field.name for f in field_info.values()])
        func = cache.get_integrand_function(
            integrand=self._integrand,
            suffix=suffix,
            annotations=annotations,
            code_transformers=[self],
        )

        # Make the function reachable as `<integrand>.<key>` from rewritten call sites
        setattr(self._integrand, func.key, func)

        return func

    @staticmethod
    def apply(integrand: Integrand, field_args: Dict[str, FieldLike]) -> wp.Function:
        """
        Specializes ``integrand`` for the given field arguments.

        Args:
            integrand: integrand to specialize
            field_args: field-like arguments of the integrand, keyed by argument name

        Returns:
            The specialized function
        """
        field_info = IntegrandVisitor._build_field_info(integrand, field_args)
        return IntegrandTransformer(integrand, field_info)._apply()
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
class IntegrandArguments(NamedTuple):
    """Arguments of an integrand, grouped by the role they play in its signature."""

    # Field-like objects (or geometry domains) keyed by integrand argument name
    field_args: Dict[str, Union[FieldLike, GeometryDomain]]
    # Other arguments keyed by name -- presumably plain values; TODO confirm against the parser
    value_args: Dict[str, Any]
    # Name of the Domain-typed argument.
    # NOTE(review): `_parse_integrand_arguments` initialises the four names below to None,
    # so they are presumably None when the integrand has no such argument -- confirm
    domain_name: Optional[str]
    # Name of the Sample-typed argument
    sample_name: Optional[str]
    # Name of the argument bound to a TestField
    test_name: Optional[str]
    # Name of the argument bound to a TrialField
    trial_name: Optional[str]
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def _parse_integrand_arguments(
    integrand: Integrand,
    fields: Dict[str, FieldLike],
) -> IntegrandArguments:
    """Sort the arguments of ``integrand`` by kind, based on their annotations.

    Args:
        integrand: Integrand whose ``argspec`` is inspected.
        fields: Fields supplied by the caller, keyed by argument name.

    Returns:
        An ``IntegrandArguments`` tuple. ``field_args`` holds the supplied field
        of every ``Field`` argument, ``value_args`` the annotated type of every
        other regular argument, and the ``*_name`` entries the names of the
        domain, sample, test and trial arguments (``None`` when absent).

    Raises:
        ValueError: If a field is missing or is not a ``FieldLike``, if there is
            more than one test or trial field, or if an entry of ``fields``
            targets a ``Domain``, ``Sample`` or value argument.
        SyntaxError: If the integrand declares more than one ``Domain`` or more
            than one ``Sample`` argument.
    """
    # parse argument types
    field_args = {}
    value_args = {}

    domain_name = None
    sample_name = None
    test_name = None
    trial_name = None

    argspec = integrand.argspec
    for arg in argspec.args:
        arg_type = argspec.annotations[arg]
        if arg_type == Field:
            try:
                field = fields[arg]
            except KeyError as err:
                raise ValueError(f"Missing field for argument '{arg}' of integrand '{integrand.name}'") from err
            if not isinstance(field, FieldLike):
                raise ValueError(f"Passed field argument '{arg}' is not a proper Field")
            # At most one test and one trial field are allowed; remember which argument holds each
            if isinstance(field, TestField):
                if test_name is not None:
                    raise ValueError(f"More than one test field argument: '{test_name}' and '{arg}'")
                test_name = arg
            elif isinstance(field, TrialField):
                if trial_name is not None:
                    raise ValueError(f"More than one trial field argument: '{trial_name}' and '{arg}'")
                trial_name = arg
            field_args[arg] = field
        elif arg_type == Domain:
            if domain_name is not None:
                raise SyntaxError(f"Integrand '{integrand.name}' must have at most one argument of type Domain")
            if arg in fields:
                raise ValueError(
                    f"Domain argument '{arg}' of '{integrand.name}' will be automatically populated and must not be passed as a field argument."
                )
            domain_name = arg
        elif arg_type == Sample:
            if sample_name is not None:
                raise SyntaxError(f"Integrand '{integrand.name}' must have at most one argument of type Sample")
            if arg in fields:
                raise ValueError(
                    f"Sample argument '{arg}' of '{integrand.name}' will be automatically populated and must not be passed as a field argument."
                )
            sample_name = arg
        else:
            if arg in fields:
                raise ValueError(
                    f"Cannot pass a field argument to '{arg}' of '{integrand.name}' which is not of type 'Field'"
                )
            # Value arguments are recorded by their annotated type, not by a value
            value_args[arg] = arg_type

    return IntegrandArguments(field_args, value_args, domain_name, sample_name, test_name, trial_name)
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def _check_field_compat(integrand: Integrand, arguments: IntegrandArguments, domain: GeometryDomain):
    """Verify that geometry fields are compatible with the integration domain.

    Only ``GeometryField`` arguments are checked, and only when ``domain`` is not
    ``None``. ``integrand`` is accepted but not used by the checks.

    Raises:
        ValueError: If a field's geometry or element kind differs from the domain's.
    """
    for arg_name, candidate in arguments.field_args.items():
        if not isinstance(candidate, GeometryField) or domain is None:
            continue

        if candidate.geometry != domain.geometry:
            raise ValueError(f"Field '{arg_name}' must be defined on the same geometry as the integration domain")

        if candidate.element_kind != domain.element_kind:
            raise ValueError(
                f"Field '{arg_name}' is not defined on the same kind of elements (cells or sides) as the integration domain. Maybe a forgotten `.trace()`?"
            )
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
def _find_integrand_operators(integrand: Integrand, field_args: Dict[str, FieldLike]):
    """Populate ``integrand.operators`` if it has not been computed yet.

    The result maps the root argument name of each field to the set of operators
    reported for it by ``IntegrandOperatorParser``.
    """
    # Integrands operator dictionary does not depend on concrete field type,
    # so only needs to be built once per integrand
    if integrand.operators is not None:
        return

    found = {}

    def record_operator(field: IntegrandVisitor.FieldInfo, op: Operator):
        # Group operators by the root argument they apply to
        found.setdefault(field.root_arg_name, set()).add(op)

    IntegrandOperatorParser.apply(integrand, field_args, operator_callback=record_operator)

    integrand.operators = found
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def _notify_operator_usage(
    integrand: Integrand,
    field_args: Dict[str, FieldLike],
):
    """Tell each supplied field which operators the integrand applies to it.

    Entries of ``integrand.operators`` whose argument name is not a key of
    ``field_args`` are ignored.
    """
    for arg_name, used_operators in integrand.operators.items():
        if arg_name not in field_args:
            continue
        field_args[arg_name].notify_operator_usage(used_operators)
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
def _gen_field_struct(field_args: Dict[str, FieldLike]):
    """Create the struct type that carries the evaluation arguments of all fields.

    A ``Fields`` class is assembled with one attribute per entry of ``field_args``,
    skipping ``GeometryDomain`` entries. Each attribute is annotated with the
    field's ``EvalArg`` and initialized with an ``EvalArg()`` instance.

    Args:
        field_args: Field arguments of the integrand, keyed by argument name.

    Returns:
        The result of ``cache.get_struct`` for the assembled class. The suffix
        joins every ``<name>_<EvalArg.cls.__qualname__>`` pair.
    """

    class Fields:
        pass

    annotations = get_annotations(Fields)

    for name, arg in field_args.items():
        if isinstance(arg, GeometryDomain):
            continue
        setattr(Fields, name, arg.EvalArg())
        annotations[name] = arg.EvalArg

    # The previous `except AttributeError` fallback wrote to `Fields.__dict__`, a read-only
    # mappingproxy, so it could only raise AttributeError again. Assign directly instead.
    Fields.__annotations__ = annotations

    suffix = "_".join(f"{name}_{arg_struct.cls.__qualname__}" for name, arg_struct in annotations.items())

    return cache.get_struct(Fields, suffix=suffix)
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
def _get_trial_arg():
|
|
442
|
+
pass
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def _get_test_arg():
|
|
446
|
+
pass
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
class PassFieldArgsToIntegrand(ast.NodeTransformer):
    """AST transformer that fills in the integrand call of a generated kernel.

    Calls to the function named ``func_name`` get their argument list replaced by
    expressions reading the kernel's domain, sample, fields-struct and
    values-struct variables. Calls to ``_get_test_arg`` and ``_get_trial_arg`` are
    replaced by the matching attribute of the fields struct.
    """

    def __init__(
        self,
        arg_names: List[str],
        parsed_args: IntegrandArguments,
        integrand_func: wp.Function,
        func_name: str = "integrand_func",
        fields_var_name: str = "fields",
        values_var_name: str = "values",
        domain_var_name: str = "domain_arg",
        sample_var_name: str = "sample",
        field_wrappers_attr: str = "_field_wrappers",
    ):
        # Integrand argument names, in call order
        self._arg_names = arg_names

        # Classification of the integrand arguments
        self._field_args = parsed_args.field_args
        self._value_args = parsed_args.value_args
        self._domain_name = parsed_args.domain_name
        self._sample_name = parsed_args.sample_name
        self._test_name = parsed_args.test_name
        self._trial_name = parsed_args.trial_name

        # Names used inside the kernel source
        self._func_name = func_name
        self._fields_var_name = fields_var_name
        self._values_var_name = values_var_name
        self._domain_var_name = domain_var_name
        self._sample_var_name = sample_var_name
        self._field_wrappers_attr = field_wrappers_attr

        self._register_integrand_field_wrappers(integrand_func, parsed_args.field_args)

    class _FieldWrappers:
        """Plain attribute holder for the per-argument field wrappers."""

    def _register_integrand_field_wrappers(self, integrand_func: wp.Function, fields: Dict[str, FieldLike]):
        """Attach each field's ``ElementEvalArg`` to ``integrand_func``, keyed by argument name."""
        # Mechanism to pass the geometry argument only once to the root kernel
        # Field wrappers are used to forward it to all fields in nested integrand calls
        wrappers = PassFieldArgsToIntegrand._FieldWrappers()
        for arg_name, field in fields.items():
            if isinstance(field, FieldLike):
                setattr(wrappers, arg_name, field.ElementEvalArg)
        setattr(integrand_func, self._field_wrappers_attr, wrappers)

    def _fields_attribute(self, attr_name):
        """Return the AST node for ``<fields>.<attr_name>``."""
        return ast.Attribute(
            value=ast.Name(id=self._fields_var_name, ctx=ast.Load()),
            attr=attr_name,
            ctx=ast.Load(),
        )

    def _integrand_argument(self, arg):
        """Return the AST expression passed to the integrand for argument ``arg``."""
        if arg == self._domain_name:
            return ast.Name(id=self._domain_var_name, ctx=ast.Load())

        if arg == self._sample_name:
            return ast.Name(id=self._sample_var_name, ctx=ast.Load())

        if arg in self._field_args:
            # <func>.<wrappers>.<arg>(<domain>, <fields>.<arg>)
            wrapper = ast.Attribute(
                value=ast.Attribute(
                    value=ast.Name(id=self._func_name, ctx=ast.Load()),
                    attr=self._field_wrappers_attr,
                    ctx=ast.Load(),
                ),
                attr=arg,
                ctx=ast.Load(),
            )
            return ast.Call(
                func=wrapper,
                args=[
                    ast.Name(id=self._domain_var_name, ctx=ast.Load()),
                    self._fields_attribute(arg),
                ],
                keywords=[],
            )

        if arg in self._value_args:
            return ast.Attribute(
                value=ast.Name(id=self._values_var_name, ctx=ast.Load()),
                attr=arg,
                ctx=ast.Load(),
            )

        raise RuntimeError(f"Unhandled argument {arg}")

    def visit_Call(self, call: ast.Call):
        call = self.generic_visit(call)

        # Only calls through a bare name can match; anything else yields None here
        callee = getattr(call.func, "id", None)

        if callee == self._func_name:
            # Replace function arguments with our generated structs
            call.args.clear()
            for arg in self._arg_names:
                call.args.append(self._integrand_argument(arg))
            return call

        if callee == _get_test_arg.__name__:
            return self._fields_attribute(self._test_name)

        if callee == _get_trial_arg.__name__:
            return self._fields_attribute(self._trial_name)

        return call
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
def _combined_kernel_options(integrand_options: Optional[Dict[str, Any]], call_site_options: Optional[Dict[str, Any]]):
|
|
560
|
+
if integrand_options is None:
|
|
561
|
+
return {} if call_site_options is None else call_site_options
|
|
562
|
+
|
|
563
|
+
options = integrand_options.copy()
|
|
564
|
+
if call_site_options is not None:
|
|
565
|
+
options.update(call_site_options)
|
|
566
|
+
return options
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
def get_integrate_constant_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    accumulate_dtype,
):
    """Return the kernel function used when there is neither a test nor a trial field.

    The kernel accumulates the weighted integrand values into ``result[0]``.
    The ``integrand_func(sample, fields, values)`` call below is a placeholder:
    ``PassFieldArgsToIntegrand`` rewrites its argument list when the kernel is built.
    """

    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        qp_element_index_arg: quadrature.ElementIndexArg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=accumulate_dtype),
    ):
        # The launch index identifies a quadrature evaluation point
        qp_eval_index = wp.tid()
        domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
        # Nothing to integrate for evaluation points without an element
        if domain_element_index == NULL_ELEMENT_INDEX:
            return

        element_index = domain.element_index(domain_index_arg, domain_element_index)

        qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)

        # No test or trial function in a constant form
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
        vol = domain.element_measure(domain_arg, sample)

        val = integrand_func(sample, fields, values)

        # All threads contribute to the same entry, hence the atomic add
        wp.atomic_add(result, 0, accumulate_dtype(qp_weight * vol * val))

    return integrate_kernel_fn
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
def get_integrate_linear_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """Return the kernel function for a linear form integrated with a quadrature.

    For each (node, test dof) pair of the launch, the kernel sums the weighted
    integrand over the quadrature points of every element listed for that node
    by the test space restriction, then adds the sum to ``result[node, dof]``.
    The ``integrand_func(sample, fields, values)`` call is a placeholder whose
    arguments are rewritten by ``PassFieldArgsToIntegrand``.
    """

    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array2d(dtype=output_dtype),
    ):
        local_node_index, test_dof = wp.tid()
        node_index = test.space_restriction.node_partition_index(test_arg, local_node_index)
        element_beg, element_end = test.space_restriction.node_element_range(test_arg, node_index)

        # Linear form: no trial function
        trial_dof_index = NULL_DOF_INDEX

        val_sum = accumulate_dtype(0.0)

        # Loop over the elements associated with this node
        for n in range(element_beg, element_end):
            node_element_index = test.space_restriction.node_element_index(test_arg, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)

            qp_point_count = quadrature.point_count(
                domain_arg, qp_arg, node_element_index.domain_element_index, element_index
            )
            for k in range(qp_point_count):
                qp_index = quadrature.point_index(
                    domain_arg, qp_arg, node_element_index.domain_element_index, element_index, k
                )
                qp_coords = quadrature.point_coords(
                    domain_arg, qp_arg, node_element_index.domain_element_index, element_index, k
                )
                qp_weight = quadrature.point_weight(
                    domain_arg, qp_arg, node_element_index.domain_element_index, element_index, k
                )

                vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))

                sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(qp_weight * vol * val)

        # Accumulation happens in accumulate_dtype; convert once when writing out
        result[node_index, test_dof] += output_dtype(val_sum)

    return integrate_kernel_fn
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
def get_integrate_linear_nodal_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """Return the kernel function for a linear form integrated at the test space nodes.

    No quadrature is involved: the integrand is evaluated at the node coordinates
    and weighted with the test space's node quadrature weight. ``_get_test_arg()``
    and the arguments of ``integrand_func(...)`` are placeholders rewritten by
    ``PassFieldArgsToIntegrand``.
    """

    def integrate_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_restriction_arg: test.space_restriction.NodeArg,
        test_topo_arg: test.space.topology.TopologyArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array2d(dtype=output_dtype),
    ):
        local_node_index, dof = wp.tid()

        partition_node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
        element_beg, element_end = test.space_restriction.node_element_range(test_restriction_arg, partition_node_index)

        # Linear form: no trial function
        trial_dof_index = NULL_DOF_INDEX

        val_sum = accumulate_dtype(0.0)

        for n in range(element_beg, element_end):
            node_element_index = test.space_restriction.node_element_index(test_restriction_arg, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            # The node index is looked up once, from the first element of the range
            if n == element_beg:
                node_index = test.space.topology.element_node_index(
                    domain_arg, test_topo_arg, element_index, node_element_index.node_index_in_element
                )

            coords = test.space.node_coords_in_element(
                domain_arg,
                _get_test_arg(),
                element_index,
                node_element_index.node_index_in_element,
            )

            # Elements for which the node coordinates are flagged OUTSIDE do not contribute
            if coords[0] != OUTSIDE:
                node_weight = test.space.node_quadrature_weight(
                    domain_arg,
                    _get_test_arg(),
                    element_index,
                    node_element_index.node_index_in_element,
                )

                test_dof_index = DofIndex(node_element_index.node_index_in_element, dof)

                # The node plays the role of the quadrature point in the sample
                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(node_weight * vol * val)

        result[partition_node_index, dof] += output_dtype(val_sum)

    return integrate_kernel_fn
|
|
738
|
+
|
|
739
|
+
|
|
740
|
+
def get_integrate_linear_local_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: LocalTestField,
):
    """Return the kernel function for a linear form with a ``LocalTestField``.

    Each thread handles one (evaluation point, taylor dof, test dof) triple and
    writes its weighted integrand value to the matching entry of ``result``.
    Nothing is summed in this kernel. The arguments of ``integrand_func(...)``
    are placeholders rewritten by ``PassFieldArgsToIntegrand``.
    """

    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        qp_element_index_arg: quadrature.ElementIndexArg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array3d(dtype=float),
    ):
        qp_eval_index, taylor_dof, test_dof = wp.tid()
        domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)

        # Evaluation points without an element leave their result entry untouched
        if domain_element_index == NULL_ELEMENT_INDEX:
            return

        element_index = domain.element_index(domain_index_arg, domain_element_index)

        qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)

        vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))

        trial_dof_index = NULL_DOF_INDEX
        # Here the first DofIndex component is the taylor dof rather than a node index
        test_dof_index = DofIndex(taylor_dof, test_dof)

        sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
        val = integrand_func(sample, fields, values)
        result[qp_eval_index, taylor_dof, test_dof] = qp_weight * vol * val

    return integrate_kernel_fn
|
|
779
|
+
|
|
780
|
+
|
|
781
|
+
def get_integrate_bilinear_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    trial: TrialField,
    output_dtype,
    accumulate_dtype,
):
    """Return the kernel function for a bilinear form integrated with a quadrature.

    The kernel fills triplet arrays (row index, column index, value block). One
    block is written per (node-element entry, trial node) pair, at offset
    ``element * MAX_NODES_PER_ELEMENT + trial_node``. The arguments of
    ``integrand_func(...)`` are placeholders rewritten by ``PassFieldArgsToIntegrand``.
    """
    MAX_NODES_PER_ELEMENT = trial.space.topology.MAX_NODES_PER_ELEMENT

    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        trial_partition_arg: trial.space_partition.PartitionArg,
        trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
        fields: FieldStruct,
        values: ValueStruct,
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=output_dtype),
    ):
        test_local_node_index, trial_node, test_dof, trial_dof = wp.tid()

        test_node_index = test.space_restriction.node_partition_index(test_arg, test_local_node_index)
        element_beg, element_end = test.space_restriction.node_element_range(test_arg, test_node_index)

        trial_dof_index = DofIndex(trial_node, trial_dof)

        # `element` indexes the node-element range of the test node, not the domain elements
        for element in range(element_beg, element_end):
            test_element_index = test.space_restriction.node_element_index(test_arg, element)
            element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)

            element_trial_node_count = trial.space.topology.element_node_count(
                domain_arg, trial_topology_arg, element_index
            )
            # Trial nodes beyond the element's node count get zero quadrature points,
            # so the loop below is skipped and a zero block is written
            qp_point_count = wp.where(
                trial_node < element_trial_node_count,
                quadrature.point_count(domain_arg, qp_arg, test_element_index.domain_element_index, element_index),
                0,
            )

            test_dof_index = DofIndex(
                test_element_index.node_index_in_element,
                test_dof,
            )

            val_sum = accumulate_dtype(0.0)

            for k in range(qp_point_count):
                qp_index = quadrature.point_index(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )
                coords = quadrature.point_coords(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )

                qp_weight = quadrature.point_weight(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )
                vol = domain.element_measure(domain_arg, make_free_sample(element_index, coords))

                sample = Sample(
                    element_index,
                    coords,
                    qp_index,
                    qp_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                val = integrand_func(sample, fields, values)
                val_sum += accumulate_dtype(qp_weight * vol * val)

            block_offset = element * MAX_NODES_PER_ELEMENT + trial_node
            triplet_values[block_offset, test_dof, trial_dof] = output_dtype(val_sum)

            # Set row and column indices
            # (only the thread with test_dof == 0 and trial_dof == 0 writes them, once per block)
            if test_dof == 0 and trial_dof == 0:
                if trial_node < element_trial_node_count:
                    trial_node_index = trial.space_partition.partition_node_index(
                        trial_partition_arg,
                        trial.space.topology.element_node_index(
                            domain_arg, trial_topology_arg, element_index, trial_node
                        ),
                    )
                else:
                    trial_node_index = NULL_NODE_INDEX  # will get ignored when converting to bsr
                triplet_rows[block_offset] = test_node_index
                triplet_cols[block_offset] = trial_node_index

    return integrate_kernel_fn
|
|
876
|
+
|
|
877
|
+
|
|
878
|
+
def get_integrate_bilinear_nodal_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """Return the kernel function for a bilinear form integrated at the test space nodes.

    Test and trial dof indices both use the same node, and the row and column
    index written for each block are both the node's partition index, so only
    diagonal blocks are produced. ``_get_test_arg()`` and the arguments of
    ``integrand_func(...)`` are placeholders rewritten by ``PassFieldArgsToIntegrand``.
    """

    def integrate_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_restriction_arg: test.space_restriction.NodeArg,
        test_topo_arg: test.space.topology.TopologyArg,
        fields: FieldStruct,
        values: ValueStruct,
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=output_dtype),
    ):
        local_node_index, test_dof, trial_dof = wp.tid()

        partition_node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
        element_beg, element_end = test.space_restriction.node_element_range(test_restriction_arg, partition_node_index)

        val_sum = accumulate_dtype(0.0)

        for n in range(element_beg, element_end):
            node_element_index = test.space_restriction.node_element_index(test_restriction_arg, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            # The node index is looked up once, from the first element of the range
            if n == element_beg:
                node_index = test.space.topology.element_node_index(
                    domain_arg, test_topo_arg, element_index, node_element_index.node_index_in_element
                )

            coords = test.space.node_coords_in_element(
                domain_arg,
                _get_test_arg(),
                element_index,
                node_element_index.node_index_in_element,
            )

            # Elements for which the node coordinates are flagged OUTSIDE do not contribute
            if coords[0] != OUTSIDE:
                node_weight = test.space.node_quadrature_weight(
                    domain_arg,
                    _get_test_arg(),
                    element_index,
                    node_element_index.node_index_in_element,
                )

                # Same node on the test and trial side
                test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)
                trial_dof_index = DofIndex(node_element_index.node_index_in_element, trial_dof)

                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(node_weight * vol * val)

        # One block per node, with identical row and column index
        triplet_values[local_node_index, test_dof, trial_dof] = output_dtype(val_sum)
        triplet_rows[local_node_index] = partition_node_index
        triplet_cols[local_node_index] = partition_node_index

    return integrate_kernel_fn
|
|
950
|
+
|
|
951
|
+
|
|
952
|
+
def get_integrate_bilinear_local_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: LocalTestField,
    trial: LocalTrialField,
):
    """Return the kernel function for a bilinear form with local test and trial fields.

    Each thread handles one (evaluation point, test dof, trial dof, trial taylor
    dof) tuple and loops over the test taylor dofs. Each value is written to
    ``result`` at a flattened taylor index,
    ``test_taylor_dof * TRIAL_TAYLOR_DOF_COUNT + trial_taylor_dof``. The arguments
    of ``integrand_func(...)`` are placeholders rewritten by ``PassFieldArgsToIntegrand``.
    """
    TEST_TAYLOR_DOF_COUNT = test.TAYLOR_DOF_COUNT
    TRIAL_TAYLOR_DOF_COUNT = trial.TAYLOR_DOF_COUNT

    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        qp_element_index_arg: quadrature.ElementIndexArg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array4d(dtype=float),
    ):
        qp_eval_index, test_dof, trial_dof, trial_taylor_dof = wp.tid()

        domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
        # Evaluation points without an element leave their result entries untouched
        if domain_element_index == NULL_ELEMENT_INDEX:
            return

        element_index = domain.element_index(domain_index_arg, domain_element_index)

        qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)

        vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
        # Weight shared by all test taylor dofs at this point
        qp_vol = vol * qp_weight

        trial_dof_index = DofIndex(trial_taylor_dof, trial_dof)

        for test_taylor_dof in range(TEST_TAYLOR_DOF_COUNT):
            # Flattened (test taylor dof, trial taylor dof) index
            taylor_dof = test_taylor_dof * TRIAL_TAYLOR_DOF_COUNT + trial_taylor_dof

            test_dof_index = DofIndex(test_taylor_dof, test_dof)

            sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
            val = integrand_func(sample, fields, values)
            result[qp_eval_index, test_dof, trial_dof, taylor_dof] = qp_vol * val

    return integrate_kernel_fn
|
|
1000
|
+
|
|
1001
|
+
|
|
1002
|
+
def _generate_integrate_kernel(
    integrand: Integrand,
    domain: GeometryDomain,
    quadrature: Optional[Quadrature],
    arguments: IntegrandArguments,
    test: Optional[TestField],
    trial: Optional[TrialField],
    output_dtype: type,
    accumulate_dtype: type,
    kernel_options: Optional[Dict[str, Any]] = None,
) -> "tuple[wp.Kernel, wp.codegen.Struct, wp.codegen.Struct]":
    """Fetch or build the integration kernel for ``integrand`` and these arguments.

    Args:
        integrand: Integrand to integrate.
        domain: Integration domain.
        quadrature: Quadrature to use; ``None`` selects the nodal kernels.
        arguments: Parsed integrand arguments.
        test: Test field, or ``None``.
        trial: Trial field, or ``None``.
        output_dtype: Output type; only its scalar type is used.
        accumulate_dtype: Type in which contributions are accumulated.
        kernel_options: Options forwarded to ``cache.get_integrand_kernel``.

    Returns:
        The tuple ``(kernel, FieldStruct, ValueStruct)``. The return annotation is
        quoted so that it is not evaluated when the function is defined.
    """
    output_dtype = wp.types.type_scalar_type(output_dtype)

    FieldStruct = _gen_field_struct(arguments.field_args)
    ValueStruct = cache.get_argument_struct(arguments.value_args)

    # Done on every call, including cache hits
    _notify_operator_usage(integrand, arguments.field_args)

    # Check if kernel exist in cache
    field_names = "_".join(f"{k}{f.name}" for k, f in arguments.field_args.items())
    kernel_suffix = f"_itg_{wp.types.type_typestr(output_dtype)}{wp.types.type_typestr(accumulate_dtype)}_{field_names}"

    if quadrature is not None:
        kernel_suffix += quadrature.name

    kernel = cache.get_integrand_kernel(integrand=integrand, suffix=kernel_suffix, kernel_options=kernel_options)
    if kernel is not None:
        return kernel, FieldStruct, ValueStruct

    # Not found in cache, transform integrand and generate kernel
    _check_field_compat(integrand, arguments, domain)

    integrand_func = IntegrandTransformer.apply(integrand, arguments.field_args)

    nodal = quadrature is None

    # Pick the kernel template: constant / linear / bilinear, each with nodal and local variants
    if test is None and trial is None:
        integrate_kernel_fn = get_integrate_constant_kernel(
            integrand_func,
            domain,
            quadrature,
            FieldStruct,
            ValueStruct,
            accumulate_dtype=accumulate_dtype,
        )
    elif trial is None and nodal:
        integrate_kernel_fn = get_integrate_linear_nodal_kernel(
            integrand_func,
            domain,
            FieldStruct,
            ValueStruct,
            test=test,
            output_dtype=output_dtype,
            accumulate_dtype=accumulate_dtype,
        )
    elif trial is None and isinstance(test, LocalTestField):
        integrate_kernel_fn = get_integrate_linear_local_kernel(
            integrand_func,
            domain,
            quadrature,
            FieldStruct,
            ValueStruct,
            test=test,
        )
    elif trial is None:
        integrate_kernel_fn = get_integrate_linear_kernel(
            integrand_func,
            domain,
            quadrature,
            FieldStruct,
            ValueStruct,
            test=test,
            output_dtype=output_dtype,
            accumulate_dtype=accumulate_dtype,
        )
    elif nodal:
        integrate_kernel_fn = get_integrate_bilinear_nodal_kernel(
            integrand_func,
            domain,
            FieldStruct,
            ValueStruct,
            test=test,
            output_dtype=output_dtype,
            accumulate_dtype=accumulate_dtype,
        )
    elif isinstance(test, LocalTestField):
        integrate_kernel_fn = get_integrate_bilinear_local_kernel(
            integrand_func,
            domain,
            quadrature,
            FieldStruct,
            ValueStruct,
            test=test,
            trial=trial,
        )
    else:
        integrate_kernel_fn = get_integrate_bilinear_kernel(
            integrand_func,
            domain,
            quadrature,
            FieldStruct,
            ValueStruct,
            test=test,
            trial=trial,
            output_dtype=output_dtype,
            accumulate_dtype=accumulate_dtype,
        )

    # The transformer rewrites the placeholder integrand call inside the kernel template
    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        kernel_fn=integrate_kernel_fn,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
        code_transformers=[
            PassFieldArgsToIntegrand(
                arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
            )
        ],
    )

    return kernel, FieldStruct, ValueStruct
|
|
1125
|
+
|
|
1126
|
+
|
|
1127
|
+
def _launch_integrate_kernel(
    integrand: Integrand,
    kernel: wp.Kernel,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    domain: GeometryDomain,
    quadrature: Quadrature,
    test: Optional[TestField],
    trial: Optional[TrialField],
    fields: Dict[str, FieldLike],
    values: Dict[str, Any],
    accumulate_dtype: type,
    temporary_store: Optional[cache.TemporaryStore],
    output_dtype: type,
    output: Optional[Union[wp.array, BsrMatrix]],
    add_to_output: bool,
    bsr_options: Optional[Dict[str, Any]],
    device,
):
    """Launch a previously generated integration kernel and collect its result.

    The kind of form is selected from ``test`` and ``trial``:

    - both ``None``: constant form, accumulated into a single-entry array;
    - ``trial`` is ``None``: linear form, accumulated into a per-node array;
    - otherwise: bilinear form, assembled as triplets and converted to a BSR matrix.

    A ``quadrature`` of ``None`` selects the nodal code paths of the linear and bilinear forms.

    Args:
        integrand: Integrand being integrated; its name is used when populating ``ValueStruct``.
        kernel: Integration kernel to launch.
        FieldStruct: Struct type holding the evaluation arguments of the fields.
        ValueStruct: Struct type holding the extra values passed to the integrand.
        domain: Integration domain.
        quadrature: Quadrature formula, or ``None`` for nodal integration.
        test: Test field, or ``None`` for constant forms.
        trial: Trial field, or ``None`` for constant and linear forms.
        fields: Field arguments keyed by integrand parameter name; ``GeometryDomain`` entries are skipped.
        values: Value arguments keyed by integrand parameter name.
        accumulate_dtype: Scalar type used for accumulation.
        temporary_store: Pool from which temporary arrays are borrowed.
        output_dtype: Scalar type of the result when ``output`` is not provided.
        output: Optional destination array or sparse matrix.
        add_to_output: If True, add to ``output`` rather than overwriting it.
        bsr_options: Extra keyword arguments forwarded to ``bsr_set_from_triplets``.
        device: Device on which to launch the kernels.

    Returns:
        For constant forms, ``output`` if provided, otherwise the first entry of the accumulation
        array converted through ``numpy()``. For linear forms, ``output`` if provided, otherwise the
        detached temporary array. For bilinear forms, the resulting sparse matrix.

    Raises:
        RuntimeError: If ``output`` has an incompatible size, shape or dtype.
    """
    # Set-up launch arguments
    domain_elt_arg = domain.element_arg_value(device=device)
    domain_elt_index_arg = domain.element_index_arg_value(device=device)

    # Only the non-nodal code paths consume the quadrature argument
    qp_arg = quadrature.arg_value(device=device) if quadrature is not None else None

    field_arg_values = FieldStruct()
    for k, v in fields.items():
        if not isinstance(v, GeometryDomain):
            setattr(field_arg_values, k, v.eval_arg_value(device=device))

    value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)

    # Constant form
    if test is None and trial is None:
        if output is not None and output.dtype == accumulate_dtype:
            # Matching dtype: accumulate directly into the user-provided array
            if output.size < 1:
                raise RuntimeError("Output array must be of size at least 1")
            accumulate_array = output
        else:
            accumulate_temporary = cache.borrow_temporary(
                shape=(1,),
                device=device,
                dtype=accumulate_dtype,
                temporary_store=temporary_store,
                requires_grad=output is not None and output.requires_grad,
            )
            accumulate_array = accumulate_temporary.array

        # Keep existing content only when adding directly into the output array
        if output is not accumulate_array or not add_to_output:
            accumulate_array.zero_()

        wp.launch(
            kernel=kernel,
            dim=quadrature.evaluation_point_count(),
            inputs=[
                qp_arg,
                quadrature.element_index_arg_value(device),
                domain_elt_arg,
                domain_elt_index_arg,
                field_arg_values,
                value_struct_values,
                accumulate_array,
            ],
            device=device,
        )

        if output is accumulate_array:
            return output
        if output is None:
            return accumulate_array.numpy()[0]

        if add_to_output:
            # accumulate dtype is distinct from output dtype
            array_axpy(x=accumulate_array, y=output)
        else:
            array_cast(in_array=accumulate_array, out_array=output)
        return output

    test_arg = test.space_restriction.node_arg(device=device)
    nodal = quadrature is None

    # Linear form
    if trial is None:
        # If an output array is provided with the correct type, accumulate directly into it
        # Otherwise, grab a temporary array
        if output is None:
            if type_length(output_dtype) == test.node_dof_count:
                output_shape = (test.space_partition.node_count(),)
            elif type_length(output_dtype) == 1:
                output_shape = (test.space_partition.node_count(), test.node_dof_count)
            else:
                raise RuntimeError(
                    f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
                )

            output_temporary = cache.borrow_temporary(
                temporary_store=temporary_store,
                shape=output_shape,
                dtype=output_dtype,
                device=device,
            )

            output = output_temporary.array

        else:
            output_temporary = None

            if output.shape[0] < test.space_partition.node_count():
                raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")

            output_dtype = output.dtype
            if type_length(output_dtype) != test.node_dof_count:
                if type_length(output_dtype) != 1:
                    raise RuntimeError(
                        f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
                    )
                # A scalar array must be 2d with one column per node dof. `or` (rather than `and`)
                # so that 1d arrays are rejected without indexing a missing dimension, and 2d arrays
                # with a wrong column count are actually caught.
                if output.ndim != 2 or output.shape[1] != test.node_dof_count:
                    raise RuntimeError(
                        f"Incompatible output array shape, last dimension must be of size {test.node_dof_count}"
                    )

        # Launch the integration on the kernel on a 2d scalar view of the actual array
        if not add_to_output:
            output.zero_()

        def as_2d_array(array):
            # Reinterpret the same memory as a (node_count, node_dof_count) scalar array;
            # the gradient array, if any, is wrapped the same way.
            return wp.array(
                data=None,
                ptr=array.ptr,
                capacity=array.capacity,
                device=array.device,
                shape=(test.space_partition.node_count(), test.node_dof_count),
                dtype=wp.types.type_scalar_type(output_dtype),
                grad=None if array.grad is None else as_2d_array(array.grad),
            )

        output_view = output if output.ndim == 2 else as_2d_array(output)

        if nodal:
            wp.launch(
                kernel=kernel,
                dim=(test.space_restriction.node_count(), test.node_dof_count),
                inputs=[
                    domain_elt_arg,
                    domain_elt_index_arg,
                    test_arg,
                    test.space.topology.topo_arg_value(device),
                    field_arg_values,
                    value_struct_values,
                    output_view,
                ],
                device=device,
            )
        elif isinstance(test, LocalTestField):
            # Two passes: evaluate the form at quadrature points, then dispatch to nodes
            local_result = cache.borrow_temporary(
                temporary_store=temporary_store,
                device=device,
                requires_grad=output.requires_grad,
                shape=(quadrature.evaluation_point_count(), test.TAYLOR_DOF_COUNT, test.value_dof_count),
                dtype=float,
            )

            wp.launch(
                kernel=kernel,
                dim=local_result.array.shape,
                inputs=[
                    qp_arg,
                    quadrature.element_index_arg_value(device),
                    domain_elt_arg,
                    domain_elt_index_arg,
                    field_arg_values,
                    value_struct_values,
                    local_result.array,
                ],
                device=device,
            )

            dispatch_kernel = make_linear_dispatch_kernel(test, quadrature, accumulate_dtype)
            wp.launch(
                kernel=dispatch_kernel,
                dim=(test.space_restriction.node_count(), test.node_dof_count),
                inputs=[
                    qp_arg,
                    domain_elt_arg,
                    domain_elt_index_arg,
                    test_arg,
                    test.global_field.eval_arg_value(device),
                    local_result.array,
                    output_view,
                ],
                device=device,
            )

            local_result.release()

        else:
            wp.launch(
                kernel=kernel,
                dim=(test.space_restriction.node_count(), test.node_dof_count),
                inputs=[
                    qp_arg,
                    domain_elt_arg,
                    domain_elt_index_arg,
                    test_arg,
                    field_arg_values,
                    value_struct_values,
                    output_view,
                ],
                device=device,
            )

        if output_temporary is not None:
            return output_temporary.detach()

        return output

    # Bilinear form

    # Validate the destination matrix before borrowing temporaries or launching any kernel
    if output is not None:
        if output.nrow != test.space_partition.node_count() or output.ncol != trial.space_partition.node_count():
            raise RuntimeError(
                f"Output matrix must have {test.space_partition.node_count()} rows and {trial.space_partition.node_count()} columns of blocks"
            )

    if test.node_dof_count == 1 and trial.node_dof_count == 1:
        block_type = output_dtype
    else:
        block_type = cache.cached_mat_type(shape=(test.node_dof_count, trial.node_dof_count), dtype=output_dtype)

    if nodal:
        nnz = test.space_restriction.node_count()
    else:
        nnz = test.space_restriction.total_node_element_count() * trial.space.topology.MAX_NODES_PER_ELEMENT

    triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_values_temp = cache.borrow_temporary(
        temporary_store,
        shape=(
            nnz,
            test.node_dof_count,
            trial.node_dof_count,
        ),
        dtype=output_dtype,
        device=device,
    )
    triplet_cols = triplet_cols_temp.array
    triplet_rows = triplet_rows_temp.array
    triplet_values = triplet_values_temp.array

    triplet_values.zero_()

    if nodal:
        wp.launch(
            kernel=kernel,
            dim=triplet_values.shape,
            inputs=[
                domain_elt_arg,
                domain_elt_index_arg,
                test_arg,
                test.space.topology.topo_arg_value(device),
                field_arg_values,
                value_struct_values,
                triplet_rows,
                triplet_cols,
                triplet_values,
            ],
            device=device,
        )
    elif isinstance(test, LocalTestField):
        # Two passes: evaluate the form at quadrature points, then dispatch to node pairs
        local_result = cache.borrow_temporary(
            temporary_store=temporary_store,
            device=device,
            requires_grad=False,
            shape=(
                quadrature.evaluation_point_count(),
                test.value_dof_count,
                trial.value_dof_count,
                test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT,
            ),
            dtype=float,
        )

        wp.launch(
            kernel=kernel,
            dim=(
                quadrature.evaluation_point_count(),
                test.value_dof_count,
                trial.value_dof_count,
                trial.TAYLOR_DOF_COUNT,
            ),
            inputs=[
                qp_arg,
                quadrature.element_index_arg_value(device),
                domain_elt_arg,
                domain_elt_index_arg,
                field_arg_values,
                value_struct_values,
                local_result.array,
            ],
            device=device,
        )

        # View the same memory with the last axis regrouped into vectors of trial Taylor dofs
        vec_array_shape = (*local_result.array.shape[:-1], test.TAYLOR_DOF_COUNT)
        vec_array_dtype = cache.cached_vec_type(length=trial.TAYLOR_DOF_COUNT, dtype=float)
        local_result_as_vec = wp.array(
            data=None,
            ptr=local_result.array.ptr,
            capacity=local_result.array.capacity,
            device=local_result.array.device,
            shape=vec_array_shape,
            dtype=vec_array_dtype,
        )

        dispatch_kernel = make_bilinear_dispatch_kernel(test, trial, quadrature, accumulate_dtype)

        trial_partition_arg = trial.space_partition.partition_arg_value(device)
        trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
        wp.launch(
            kernel=dispatch_kernel,
            dim=(
                test.space_restriction.node_count(),
                test.node_dof_count,
                trial.node_dof_count,
                trial.space.topology.MAX_NODES_PER_ELEMENT,
            ),
            inputs=[
                qp_arg,
                domain_elt_arg,
                domain_elt_index_arg,
                test_arg,
                test.global_field.eval_arg_value(device),
                trial_partition_arg,
                trial_topology_arg,
                trial.global_field.eval_arg_value(device),
                local_result_as_vec,
                triplet_rows,
                triplet_cols,
                triplet_values,
            ],
            device=device,
        )

        local_result.release()

    else:
        trial_partition_arg = trial.space_partition.partition_arg_value(device)
        trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
        wp.launch(
            kernel=kernel,
            dim=(
                test.space_restriction.node_count(),
                trial.space.topology.MAX_NODES_PER_ELEMENT,
                test.node_dof_count,
                trial.node_dof_count,
            ),
            inputs=[
                qp_arg,
                domain_elt_arg,
                domain_elt_index_arg,
                test_arg,
                trial_partition_arg,
                trial_topology_arg,
                field_arg_values,
                value_struct_values,
                triplet_rows,
                triplet_cols,
                triplet_values,
            ],
            device=device,
        )

    # When adding, assemble into a fresh matrix first so the existing content is preserved
    if output is None or add_to_output:
        bsr_result = bsr_zeros(
            rows_of_blocks=test.space_partition.node_count(),
            cols_of_blocks=trial.space_partition.node_count(),
            block_type=block_type,
            device=device,
        )
    else:
        bsr_result = output

    bsr_set_from_triplets(bsr_result, triplet_rows, triplet_cols, triplet_values, **(bsr_options or {}))

    # Do not wait for garbage collection
    triplet_values_temp.release()
    triplet_rows_temp.release()
    triplet_cols_temp.release()

    if add_to_output:
        output += bsr_result
    else:
        output = bsr_result

    return output
|
|
1523
|
+
|
|
1524
|
+
|
|
1525
|
+
def _pick_assembly_strategy(
|
|
1526
|
+
assembly: Optional[str], nodal: bool, operators: Dict[str, Set[Operator]], arguments: IntegrandArguments
|
|
1527
|
+
):
|
|
1528
|
+
if assembly is not None:
|
|
1529
|
+
if assembly not in ("generic", "nodal", "dispatch"):
|
|
1530
|
+
raise ValueError(f"Invalid assembly strategy'{assembly}'")
|
|
1531
|
+
return assembly
|
|
1532
|
+
elif nodal:
|
|
1533
|
+
return "nodal"
|
|
1534
|
+
|
|
1535
|
+
test_operators = operators.get(arguments.test_name, {})
|
|
1536
|
+
trial_operators = operators.get(arguments.trial_name, {})
|
|
1537
|
+
uses_at_node = at_node in test_operators or at_node in trial_operators
|
|
1538
|
+
|
|
1539
|
+
return "generic" if uses_at_node else "dispatch"
|
|
1540
|
+
|
|
1541
|
+
|
|
1542
|
+
def integrate(
    integrand: Integrand,
    domain: Optional[GeometryDomain] = None,
    quadrature: Optional[Quadrature] = None,
    nodal: bool = False,
    fields: Optional[Dict[str, FieldLike]] = None,
    values: Optional[Dict[str, Any]] = None,
    accumulate_dtype: type = wp.float64,
    output_dtype: Optional[type] = None,
    output: Optional[Union[BsrMatrix, wp.array]] = None,
    device=None,
    temporary_store: Optional[cache.TemporaryStore] = None,
    kernel_options: Optional[Dict[str, Any]] = None,
    assembly: Optional[str] = None,
    add: bool = False,
    bsr_options: Optional[Dict[str, Any]] = None,
):
    """
    Integrates a constant, linear or bilinear form, and returns a scalar, array, or sparse matrix, respectively.

    Args:
        integrand: Form to be integrated, must have :func:`integrand` decorator
        domain: Integration domain. If None, deduced from the quadrature if provided, otherwise from the test field.
        quadrature: Quadrature formula. If None, deduced from domain and fields degree.
        nodal: Deprecated. Use the equivalent assembly="nodal" instead.
        fields: Discrete, test, and trial fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names.
        values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
        temporary_store: shared pool from which to allocate temporary arrays
        accumulate_dtype: Scalar type to be used for accumulating integration samples
        output: Sparse matrix or warp array into which to store the result of the integration
        output_dtype: Scalar type for returned results if `output` is not provided. If None, defaults to `accumulate_dtype`
        device: Device on which to perform the integration
        kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)
        assembly: Specifies the strategy for assembling the integrated vector or matrix:
            - "nodal": For linear or bilinear forms, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
            - "generic": Single-pass integration and shape-function evaluation. Makes no assumption about the integrand's content, but may lead to many redundant computations.
            - "dispatch": For linear or bilinear forms, first evaluate the form at quadrature points then dispatch to nodes in a second pass. More efficient for integrands that are expensive to evaluate. Incompatible with `at_node` operator on test or trial functions.
            - `None` (default): Automatically picks a suitable assembly strategy (either "generic" or "dispatch")
        add: If True and `output` is provided, add the integration result to `output` instead of replacing its content
        bsr_options: Additional options to be passed to the sparse matrix construction algorithm. See :func:`warp.sparse.bsr_set_from_triplets()`

    Returns:
        The value returned by the kernel launch: a scalar, array or sparse matrix depending on the kind of form.

    Raises:
        ValueError: If `integrand` is not an :class:`Integrand`, if a trial field is given without a test field,
            if domains are missing or incompatible, if `add` is set without `output`, or if the nodal
            assembly requirements are not met.
        NotImplementedError: If the integration domain differs from the test field domain.
    """
    if fields is None:
        fields = {}

    if values is None:
        values = {}

    if not isinstance(integrand, Integrand):
        raise ValueError("integrand must be tagged with @warp.fem.integrand decorator")

    arguments = _parse_integrand_arguments(integrand, fields)

    test = None
    if arguments.test_name:
        test = arguments.field_args[arguments.test_name]
    trial = None
    if arguments.trial_name:
        if test is None:
            raise ValueError("A trial field cannot be provided without a test field")
        trial = arguments.field_args[arguments.trial_name]
        if test.domain != trial.domain:
            raise ValueError("Incompatible test and trial domains")

    # Deduce the domain from the quadrature first, then from the test field
    if domain is None:
        if quadrature is not None:
            domain = quadrature.domain
        elif test is not None:
            domain = test.domain

    if domain is None:
        raise ValueError("Must provide at least one of domain, quadrature, or test field")
    if test is not None and domain != test.domain:
        raise NotImplementedError("Mixing integration and test domain is not supported yet")

    if add and output is None:
        raise ValueError("An 'output' array or matrix needs to be provided for add=True")

    if arguments.domain_name is not None:
        arguments.field_args[arguments.domain_name] = domain

    _find_integrand_operators(integrand, arguments.field_args)

    assembly = _pick_assembly_strategy(assembly, nodal, arguments=arguments, operators=integrand.operators)

    if assembly == "dispatch":
        # The dispatch strategy operates on local wrappers of the test and trial fields
        if test is not None:
            test = LocalTestField(test)
            arguments.field_args[arguments.test_name] = test
        if trial is not None:
            trial = LocalTrialField(trial)
            arguments.field_args[arguments.trial_name] = trial

    if assembly == "nodal":
        if quadrature is not None:
            raise ValueError("Cannot specify quadrature for nodal integration")

        if test is None:
            raise ValueError("Nodal integration requires specifying a test function")

        if trial is not None and test.space_partition != trial.space_partition:
            raise ValueError(
                "Bilinear nodal integration requires test and trial to be defined on the same function space"
            )
    else:
        if quadrature is None:
            # Default quadrature order is the sum of the degrees of all provided fields
            order = sum(field.degree for field in fields.values())
            quadrature = RegularQuadrature(domain=domain, order=order)
        elif domain != quadrature.domain:
            raise ValueError("Incompatible integration and quadrature domain")

    # Canonicalize types
    accumulate_dtype = wp.types.type_to_warp(accumulate_dtype)
    if output is not None:
        # A provided output dictates the output scalar type
        if isinstance(output, BsrMatrix):
            output_dtype = output.scalar_type
        else:
            output_dtype = output.dtype
    elif output_dtype is None:
        output_dtype = accumulate_dtype
    else:
        output_dtype = wp.types.type_to_warp(output_dtype)

    kernel, FieldStruct, ValueStruct = _generate_integrate_kernel(
        integrand=integrand,
        domain=domain,
        quadrature=quadrature,
        arguments=arguments,
        test=test,
        trial=trial,
        accumulate_dtype=accumulate_dtype,
        output_dtype=output_dtype,
        kernel_options=kernel_options,
    )

    return _launch_integrate_kernel(
        integrand=integrand,
        kernel=kernel,
        FieldStruct=FieldStruct,
        ValueStruct=ValueStruct,
        domain=domain,
        quadrature=quadrature,
        test=test,
        trial=trial,
        fields=arguments.field_args,
        values=values,
        accumulate_dtype=accumulate_dtype,
        temporary_store=temporary_store,
        output_dtype=output_dtype,
        output=output,
        add_to_output=add,
        bsr_options=bsr_options,
        device=device,
    )
|
|
1697
|
+
|
|
1698
|
+
|
|
1699
|
+
def get_interpolate_to_field_function(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    dest: FieldRestriction,
):
    """Build the function evaluating an integrand at one node of a destination field.

    The returned function loops over the elements associated with a node of ``dest`` and returns
    the pair ``(val_sum, vol_sum)``: the integrand values weighted by the element measure, and the
    sum of those measures. Elements for which the node coordinates are flagged ``OUTSIDE`` are skipped.

    Args:
        integrand_func: Function evaluating the integrand at a sample.
        domain: Domain over which the destination field is restricted.
        FieldStruct: Struct type holding the evaluation arguments of the fields.
        ValueStruct: Struct type holding the extra values passed to the integrand.
        dest: Destination field restriction.

    Returns:
        The undecorated Python function ``interpolate_to_field_fn``.
    """
    value_type = dest.space.dtype

    def interpolate_to_field_fn(
        local_node_index: int,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
        element_beg, element_end = dest.space_restriction.node_element_range(dest_node_arg, partition_node_index)

        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX
        node_weight = 1.0

        # Volume-weighted average across elements
        # Superfluous if the interpolated function is continuous, but helpful for visualizing discontinuous spaces

        val_sum = value_type(0.0)
        vol_sum = float(0.0)

        for n in range(element_beg, element_end):
            node_element_index = dest.space_restriction.node_element_index(dest_node_arg, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            # The node index is only looked up once, from the first element of the range
            if n == element_beg:
                node_index = dest.space.topology.element_node_index(
                    domain_arg, dest_eval_arg.topology_arg, element_index, node_element_index.node_index_in_element
                )

            coords = dest.space.node_coords_in_element(
                domain_arg,
                dest_eval_arg.space_arg,
                element_index,
                node_element_index.node_index_in_element,
            )

            if coords[0] != OUTSIDE:
                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                vol_sum += vol
                val_sum += vol * val

        return val_sum, vol_sum

    return interpolate_to_field_fn
|
|
1764
|
+
|
|
1765
|
+
|
|
1766
|
+
def get_interpolate_to_field_kernel(
    interpolate_to_field_fn: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    dest: FieldRestriction,
):
    """Build the kernel function writing interpolated node values into a destination field.

    For each node, the returned function calls ``interpolate_to_field_fn`` and, when the returned
    volume sum is strictly positive, stores ``val_sum / vol_sum`` through ``dest.field.set_node_value``.

    Args:
        interpolate_to_field_fn: Function returning ``(val_sum, vol_sum)`` for a local node index.
        domain: Domain over which the destination field is restricted.
        FieldStruct: Struct type holding the evaluation arguments of the fields.
        ValueStruct: Struct type holding the extra values passed to the integrand.
        dest: Destination field restriction.

    Returns:
        The undecorated Python function ``interpolate_to_field_kernel_fn``.
    """

    @wp.func
    def _find_node_in_element(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        partition_node_index: int,
    ):
        # Returns the first (element, node-in-element) pair whose node coordinates are not OUTSIDE
        element_beg, element_end = dest.space_restriction.node_element_range(dest_node_arg, partition_node_index)

        for n in range(element_beg, element_end):
            node_element_index = dest.space_restriction.node_element_index(dest_node_arg, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
            coords = dest.space.node_coords_in_element(
                domain_arg,
                dest_eval_arg.space_arg,
                element_index,
                node_element_index.node_index_in_element,
            )
            if coords[0] != OUTSIDE:
                return element_index, node_element_index.node_index_in_element

        return NULL_ELEMENT_INDEX, NULL_NODE_INDEX

    def interpolate_to_field_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        local_node_index = wp.tid()

        val_sum, vol_sum = interpolate_to_field_fn(
            local_node_index, domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, fields, values
        )

        if vol_sum > 0.0:
            partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)

            # Grab first element containing node; there must be at least one since vol_sum != 0
            element_index, node_index_in_element = _find_node_in_element(
                domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, partition_node_index
            )
            dest.field.set_node_value(
                domain_arg,
                dest_eval_arg,
                element_index,
                node_index_in_element,
                partition_node_index,
                val_sum / vol_sum,
            )

    return interpolate_to_field_kernel_fn
|
|
1828
|
+
|
|
1829
|
+
|
|
1830
|
+
def get_interpolate_at_quadrature_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    value_type: type,
):
    """Build the kernel function evaluating an integrand at every quadrature evaluation point.

    Two variants are defined. The "non-valued" variant calls the integrand and discards its result;
    the other one stores the result in ``result`` at the quadrature point index.

    Args:
        integrand_func: Function evaluating the integrand at a sample.
        domain: Domain used to map domain element indices to element indices.
        quadrature: Quadrature providing point coordinates, weights and indices.
        FieldStruct: Struct type holding the evaluation arguments of the fields.
        ValueStruct: Struct type holding the extra values passed to the integrand.
        value_type: Element type of the ``result`` array, or ``None`` to select the non-valued variant.

    Returns:
        ``interpolate_at_quadrature_nonvalued_kernel_fn`` if ``value_type`` is ``None``,
        otherwise ``interpolate_at_quadrature_kernel_fn``.
    """

    def interpolate_at_quadrature_nonvalued_kernel_fn(
        qp_arg: quadrature.Arg,
        qp_element_index_arg: quadrature.ElementIndexArg,
        domain_arg: quadrature.domain.ElementArg,
        domain_index_arg: quadrature.domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=float),
    ):
        qp_eval_index = wp.tid()
        domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
        if domain_element_index == NULL_ELEMENT_INDEX:
            return

        element_index = domain.element_index(domain_index_arg, domain_element_index)

        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)

        sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
        # The integrand's return value, if any, is not stored in this variant
        integrand_func(sample, fields, values)

    def interpolate_at_quadrature_kernel_fn(
        qp_arg: quadrature.Arg,
        qp_element_index_arg: quadrature.ElementIndexArg,
        domain_arg: quadrature.domain.ElementArg,
        domain_index_arg: quadrature.domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=value_type),
    ):
        qp_eval_index = wp.tid()
        domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
        if domain_element_index == NULL_ELEMENT_INDEX:
            return

        element_index = domain.element_index(domain_index_arg, domain_element_index)

        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)

        sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
        result[qp_index] = integrand_func(sample, fields, values)

    return interpolate_at_quadrature_nonvalued_kernel_fn if value_type is None else interpolate_at_quadrature_kernel_fn
|
|
1891
|
+
|
|
1892
|
+
|
|
1893
|
+
def get_interpolate_jacobian_at_quadrature_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    trial: TrialField,
    value_size: int,
    value_type: type,
):
    """Build the kernel function evaluating an integrand with a trial field at quadrature points.

    The returned function is plain Python source meant to be compiled as a Warp kernel
    (in this file it is passed as ``kernel_fn`` to ``cache.get_integrand_kernel``).
    It writes its results as (row, column, block value) triplets.

    Args:
        integrand_func: Function called as ``integrand_func(sample, fields, values)`` inside the kernel.
        domain: Provides the ``ElementArg`` / ``ElementIndexArg`` kernel argument types and ``element_index``.
        quadrature: Provides the quadrature kernel argument types and point queries.
        FieldStruct: Type of the ``fields`` kernel argument.
        ValueStruct: Type of the ``values`` kernel argument.
        trial: Trial field; its space topology and space partition are queried for node indices.
        value_size: Number of coefficients extracted from each integrand value (rows of each block).
        value_type: Scalar type of the ``triplet_values`` array.

    Returns:
        The undecorated ``interpolate_jacobian_kernel_fn`` function.
    """
    # Captured by the nested kernel function below
    MAX_NODES_PER_ELEMENT = trial.space.topology.MAX_NODES_PER_ELEMENT
    VALUE_SIZE = wp.constant(value_size)

    # NOTE(review): the local names `sample`, `fields`, `values`, `domain_arg`, `domain_index_arg`
    # and the callee name `integrand_func` are presumably relied upon by the code transformer applied
    # when this function is turned into a kernel -- confirm before renaming anything in this body.
    def interpolate_jacobian_kernel_fn(
        qp_arg: quadrature.Arg,
        qp_element_index_arg: quadrature.ElementIndexArg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        trial_partition_arg: trial.space_partition.PartitionArg,
        trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
        fields: FieldStruct,
        values: ValueStruct,
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=value_type),
    ):
        # One thread per (evaluation point, trial node slot, trial dof)
        qp_eval_index, trial_node, trial_dof = wp.tid()
        domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)

        # Nothing to evaluate for this thread
        if domain_element_index == NULL_ELEMENT_INDEX:
            return

        element_index = domain.element_index(domain_index_arg, domain_element_index)
        # Skip point slots beyond the number of quadrature points of this element
        if qp >= quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index):
            return

        element_trial_node_count = trial.space.topology.element_node_count(
            domain_arg, trial_topology_arg, element_index
        )

        qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
        qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)

        # Each quadrature point owns MAX_NODES_PER_ELEMENT consecutive triplet blocks, one per node slot
        block_offset = qp_index * MAX_NODES_PER_ELEMENT + trial_node

        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = DofIndex(trial_node, trial_dof)

        sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
        val = integrand_func(sample, fields, values)

        # Column `trial_dof` of the block receives the VALUE_SIZE coefficients of the integrand value
        for k in range(VALUE_SIZE):
            triplet_values[block_offset, k, trial_dof] = basis_coefficient(val, k)

        # Row/column indices are shared by all dofs of a block, so only the first dof thread writes them
        if trial_dof == 0:
            if trial_node < element_trial_node_count:
                trial_node_index = trial.space_partition.partition_node_index(
                    trial_partition_arg,
                    trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
                )
            else:
                trial_node_index = NULL_NODE_INDEX  # will get ignored when converting to bsr
            triplet_rows[block_offset] = qp_index
            triplet_cols[block_offset] = trial_node_index

    return interpolate_jacobian_kernel_fn
|
|
1960
|
+
|
|
1961
|
+
|
|
1962
|
+
def get_interpolate_free_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    value_type: type,
):
    """Build the kernel function evaluating an integrand at sample points not tied to any element.

    Two variants are defined; both build a ``Sample`` with a null element index, ``OUTSIDE``
    coordinates and a weight of ``1 / dim``, so only ``qp_index`` distinguishes the samples.

    Args:
        integrand_func: Function called as ``integrand_func(sample, fields, values)`` inside the kernel.
        domain: Provides the ``ElementArg`` kernel argument type.
        FieldStruct: Type of the ``fields`` kernel argument.
        ValueStruct: Type of the ``values`` kernel argument.
        value_type: Element type of the ``result`` array, or ``None`` to discard the integrand value.

    Returns:
        ``interpolate_free_nonvalued_kernel_fn`` if ``value_type`` is ``None``,
        otherwise ``interpolate_free_kernel_fn`` (both undecorated functions).
    """

    # NOTE(review): the local names `sample`, `fields`, `values`, `domain_arg` and the callee name
    # `integrand_func` are presumably relied upon by the code transformer applied when these functions
    # are turned into kernels -- confirm before renaming anything in the bodies below.

    # Variant that calls the integrand for its side effects only; `result` is declared but never written
    def interpolate_free_nonvalued_kernel_fn(
        dim: int,
        domain_arg: domain.ElementArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=float),
    ):
        qp_index = wp.tid()
        qp_weight = 1.0 / float(dim)  # uniform weight over the `dim` samples
        element_index = NULL_ELEMENT_INDEX
        coords = Coords(OUTSIDE)

        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
        integrand_func(sample, fields, values)

    # Variant that stores the integrand value at index `qp_index` of `result`
    def interpolate_free_kernel_fn(
        dim: int,
        domain_arg: domain.ElementArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=value_type),
    ):
        qp_index = wp.tid()
        qp_weight = 1.0 / float(dim)  # uniform weight over the `dim` samples
        element_index = NULL_ELEMENT_INDEX
        coords = Coords(OUTSIDE)

        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)

        result[qp_index] = integrand_func(sample, fields, values)

    return interpolate_free_nonvalued_kernel_fn if value_type is None else interpolate_free_kernel_fn
|
|
2007
|
+
|
|
2008
|
+
|
|
2009
|
+
def _generate_interpolate_kernel(
    integrand: Integrand,
    domain: GeometryDomain,
    dest: Optional[Union[FieldLike, wp.array]],
    quadrature: Optional[Quadrature],
    arguments: IntegrandArguments,
    kernel_options: Optional[Dict[str, Any]] = None,
) -> "tuple[wp.Kernel, wp.codegen.Struct, wp.codegen.Struct]":
    """Fetch from the cache, or generate, the kernel interpolating ``integrand``.

    The kind of kernel depends on the destination:

    - ``dest`` is a :class:`FieldRestriction`: interpolation to the nodes of the field;
    - ``quadrature`` is provided: interpolation at quadrature points, as a Jacobian
      if the integrand has a trial field argument, as plain values otherwise;
    - otherwise: interpolation at free sample points.

    Args:
        integrand: Integrand to interpolate.
        domain: Interpolation domain.
        dest: Interpolation destination, or ``None`` to discard the integrand values.
        quadrature: Quadrature formula defining the sample points, if any.
        arguments: Parsed integrand arguments.
        kernel_options: Options forwarded to the kernel cache / builder.

    Returns:
        A ``(kernel, FieldStruct, ValueStruct)`` tuple: the kernel and the struct types
        of its ``fields`` and ``values`` arguments.
    """
    # Generate field struct
    FieldStruct = _gen_field_struct(arguments.field_args)
    ValueStruct = cache.get_argument_struct(arguments.value_args)

    _notify_operator_usage(integrand, arguments.field_args)

    # Check if kernel exist in cache
    field_names = "_".join(f"{k}{f.name}" for k, f in arguments.field_args.items())
    if isinstance(dest, FieldRestriction):
        kernel_suffix = f"_itp_{field_names}_{dest.domain.name}_{dest.space_restriction.space_partition.name}"
    else:
        # Explicit None checks: the truthiness of an array-like destination may depend on its length,
        # and an empty destination must still select the kernel matching its dtype
        dest_dtype = dest.dtype if dest is not None else None
        type_str = wp.types.get_type_code(dest_dtype) if dest_dtype is not None else ""
        if quadrature is None:
            kernel_suffix = f"_itp_{field_names}_{domain.name}_{type_str}"
        else:
            kernel_suffix = f"_itp_{field_names}_{domain.name}_{quadrature.name}_{type_str}"

    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
    )
    if kernel is not None:
        return kernel, FieldStruct, ValueStruct

    # Not found in cache, transform integrand and generate kernel
    _check_field_compat(integrand, arguments, domain)

    integrand_func = IntegrandTransformer.apply(integrand, arguments.field_args)

    # Generate interpolation kernel
    if isinstance(dest, FieldRestriction):
        # need to split into kernel + function for differentiability
        interpolate_fn = get_interpolate_to_field_function(
            integrand_func,
            domain,
            dest=dest,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )

        interpolate_fn = cache.get_integrand_function(
            integrand=integrand,
            func=interpolate_fn,
            suffix=kernel_suffix,
            code_transformers=[
                PassFieldArgsToIntegrand(
                    arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
                )
            ],
        )

        interpolate_kernel_fn = get_interpolate_to_field_kernel(
            interpolate_fn,
            domain,
            dest=dest,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )
    elif quadrature is not None:
        if arguments.trial_name:
            # Jacobian with respect to the trial field, `dest` is expected to be a block-sparse matrix
            trial = arguments.field_args[arguments.trial_name]
            interpolate_kernel_fn = get_interpolate_jacobian_at_quadrature_kernel(
                integrand_func,
                domain=domain,
                quadrature=quadrature,
                FieldStruct=FieldStruct,
                ValueStruct=ValueStruct,
                trial=trial,
                value_size=dest.block_shape[0],
                value_type=dest.scalar_type,
            )
        else:
            interpolate_kernel_fn = get_interpolate_at_quadrature_kernel(
                integrand_func,
                domain=domain,
                quadrature=quadrature,
                value_type=dest_dtype,
                FieldStruct=FieldStruct,
                ValueStruct=ValueStruct,
            )
    else:
        interpolate_kernel_fn = get_interpolate_free_kernel(
            integrand_func,
            domain=domain,
            value_type=dest_dtype,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )

    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        kernel_fn=interpolate_kernel_fn,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
        code_transformers=[
            PassFieldArgsToIntegrand(
                arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
            )
        ],
    )

    return kernel, FieldStruct, ValueStruct
|
|
2121
|
+
|
|
2122
|
+
|
|
2123
|
+
def _launch_interpolate_kernel(
    integrand: Integrand,
    kernel: wp.Kernel,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    domain: GeometryDomain,
    dest: Optional[Union[FieldRestriction, wp.array]],
    quadrature: Optional[Quadrature],
    dim: int,
    trial: Optional[TrialField],
    fields: Dict[str, FieldLike],
    values: Dict[str, Any],
    temporary_store: Optional[cache.TemporaryStore],
    bsr_options: Optional[Dict[str, Any]],
    device,
) -> None:
    """Launch an interpolation kernel produced by ``_generate_interpolate_kernel``.

    Args:
        integrand: Integrand being interpolated, its name is used when populating the value struct.
        kernel: Kernel to launch.
        FieldStruct: Struct type of the kernel ``fields`` argument.
        ValueStruct: Struct type of the kernel ``values`` argument.
        domain: Interpolation domain.
        dest: Interpolation destination (field restriction, array, sparse matrix, or ``None``).
        quadrature: Quadrature formula defining the sample points, if any.
        dim: Number of samples, used when ``dest`` is not a field restriction and ``quadrature`` is ``None``.
        trial: Trial field, if the interpolation result is a Jacobian matrix.
        fields: Field arguments of the integrand, keyed by parameter name.
        values: Value arguments of the integrand, keyed by parameter name.
        temporary_store: Shared pool from which to borrow temporary arrays.
        bsr_options: Additional options forwarded to ``bsr_set_from_triplets``.
        device: Device on which to launch the kernel.

    Raises:
        RuntimeError: If ``trial`` is provided and the shape of ``dest`` or of its blocks
            does not match the quadrature and trial space.
    """
    # Set-up launch arguments
    elt_arg = domain.element_arg_value(device=device)
    elt_index_arg = domain.element_index_arg_value(device=device)

    field_arg_values = FieldStruct()
    for k, v in fields.items():
        # the domain argument, if any, is not a member of the field struct
        if not isinstance(v, GeometryDomain):
            setattr(field_arg_values, k, v.eval_arg_value(device=device))

    value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)

    if isinstance(dest, FieldRestriction):
        # Interpolation to the nodes of a discrete field
        dest_node_arg = dest.space_restriction.node_arg(device=device)
        dest_eval_arg = dest.field.eval_arg_value(device=device)

        wp.launch(
            kernel=kernel,
            dim=dest.space_restriction.node_count(),
            inputs=[
                elt_arg,
                elt_index_arg,
                dest_node_arg,
                dest_eval_arg,
                field_arg_values,
                value_struct_values,
            ],
            device=device,
        )
        return

    if quadrature is None:
        # Interpolation at free sample points
        wp.launch(
            kernel=kernel,
            dim=dim,
            inputs=[dim, elt_arg, field_arg_values, value_struct_values, dest],
            device=device,
        )
        return

    qp_arg = quadrature.arg_value(device)
    qp_element_index_arg = quadrature.element_index_arg_value(device)
    if trial is None:
        # Interpolation of values at quadrature points
        wp.launch(
            kernel=kernel,
            dim=quadrature.evaluation_point_count(),
            inputs=[qp_arg, qp_element_index_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
            device=device,
        )
        return

    # Interpolation of the Jacobian at quadrature points: one triplet block per (point, node slot)
    nnz = quadrature.total_point_count() * trial.space.topology.MAX_NODES_PER_ELEMENT

    if dest.nrow != quadrature.total_point_count() or dest.ncol != trial.space_partition.node_count():
        raise RuntimeError(
            f"'dest' matrix must have {quadrature.total_point_count()} rows and {trial.space_partition.node_count()} columns of blocks"
        )
    if dest.block_shape[1] != trial.node_dof_count:
        # a bare string cannot be raised; wrap the message in the same exception type as the check above
        raise RuntimeError(f"'dest' matrix blocks must have {trial.node_dof_count} columns")

    triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_values_temp = cache.borrow_temporary(
        temporary_store,
        dtype=dest.scalar_type,
        shape=(nnz, *dest.block_shape),
        device=device,
    )
    triplet_cols = triplet_cols_temp.array
    triplet_rows = triplet_rows_temp.array
    triplet_values = triplet_values_temp.array
    # rows left at -1 mark triplets that the kernel did not write
    triplet_rows.fill_(-1)
    triplet_values.zero_()

    trial_partition_arg = trial.space_partition.partition_arg_value(device)
    trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)

    wp.launch(
        kernel=kernel,
        dim=(quadrature.evaluation_point_count(), trial.space.topology.MAX_NODES_PER_ELEMENT, trial.node_dof_count),
        inputs=[
            qp_arg,
            qp_element_index_arg,
            elt_arg,
            elt_index_arg,
            trial_partition_arg,
            trial_topology_arg,
            field_arg_values,
            value_struct_values,
            triplet_rows,
            triplet_cols,
            triplet_values,
        ],
        device=device,
    )

    bsr_set_from_triplets(dest, triplet_rows, triplet_cols, triplet_values, **(bsr_options or {}))
|
|
2235
|
+
|
|
2236
|
+
|
|
2237
|
+
# Integrand evaluating its `field` argument at the sample `s`.
# `interpolate` substitutes it (binding the "field" key) when it is given a field instead of an integrand.
@integrand
def _identity_field(field: Field, s: Sample):
    return field(s)
|
|
2240
|
+
|
|
2241
|
+
|
|
2242
|
+
def interpolate(
    integrand: Union[Integrand, FieldLike],
    dest: Optional[Union[DiscreteField, FieldRestriction, wp.array]] = None,
    quadrature: Optional[Quadrature] = None,
    dim: int = 0,
    domain: Optional[Domain] = None,
    fields: Optional[Dict[str, FieldLike]] = None,
    values: Optional[Dict[str, Any]] = None,
    device=None,
    kernel_options: Optional[Dict[str, Any]] = None,
    temporary_store: Optional[cache.TemporaryStore] = None,
    bsr_options: Optional[Dict[str, Any]] = None,
):
    """
    Interpolates a function at a finite set of sample points and optionally assigns the result to a discrete field or a raw warp array.

    Args:
        integrand: Function to be interpolated: either a function with :func:`warp.fem.integrand` decorator or a field
        dest: Where to store the interpolation result. Can be either

            - a :class:`DiscreteField`, or restriction of a discrete field to a domain (from :func:`make_restriction`). In this case, interpolation will be performed at each node.
            - a normal warp ``array``, or ``None``. In this case, the interpolation samples will be determined by the `quadrature` or `dim` arguments, in that order.
        quadrature: Quadrature formula defining the interpolation samples if `dest` is not a discrete field or field restriction.
        dim: Number of interpolation samples if `dest` is not a discrete field or restriction and `quadrature` is ``None``.
            In this case, the ``Sample`` passed to the `integrand` will be invalid, but the sample point index ``s.qp_index`` can be used to define custom interpolation logic.
        domain: Interpolation domain, only used if `dest` is not a field restriction and `quadrature` is ``None``
        fields: Discrete fields to be passed to the integrand. Keys in the dictionary must match integrand parameters names.
        values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
        device: Device on which to perform the interpolation
        kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)
        temporary_store: shared pool from which to allocate temporary arrays
        bsr_options: Additional options to be passed to the sparse matrix construction algorithm. See :func:`warp.sparse.bsr_set_from_triplets()`

    Raises:
        ValueError: If `integrand` is neither a field nor decorated with :func:`warp.fem.integrand`,
            if it uses a test field, if it uses a trial field without a `quadrature` and a
            :class:`warp.sparse.BsrMatrix` `dest`, or if no interpolation domain can be determined.
    """

    if isinstance(integrand, FieldLike):
        # Interpolating a field directly: evaluate it through the identity integrand
        fields = {"field": integrand}
        values = {}
        integrand = _identity_field

    if fields is None:
        fields = {}

    if values is None:
        values = {}

    if not isinstance(integrand, Integrand):
        raise ValueError("integrand must be tagged with @integrand decorator")

    arguments = _parse_integrand_arguments(integrand, fields)
    if arguments.test_name:
        raise ValueError(f"Test field '{arguments.test_name}' may not be used for interpolation")
    if arguments.trial_name and (quadrature is None or not isinstance(dest, BsrMatrix)):
        raise ValueError(
            f"Interpolation using trial field '{arguments.trial_name}' requires 'quadrature' to be provided and 'dest' to be a `warp.sparse.BsrMatrix`"
        )

    if isinstance(dest, DiscreteField):
        dest = make_restriction(dest, domain=domain)

    # The destination restriction or the quadrature, when available, dictate the domain
    if isinstance(dest, FieldRestriction):
        domain = dest.domain
    elif quadrature is not None:
        domain = quadrature.domain

    if domain is None:
        # Fail early with an explicit message rather than on a later attribute access on None
        raise ValueError(
            "'domain' must be provided when 'dest' is not a discrete field or field restriction and 'quadrature' is None"
        )

    if arguments.domain_name:
        arguments.field_args[arguments.domain_name] = domain

    _find_integrand_operators(integrand, arguments.field_args)

    kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
        integrand=integrand,
        domain=domain,
        dest=dest,
        quadrature=quadrature,
        arguments=arguments,
        kernel_options=kernel_options,
    )

    return _launch_interpolate_kernel(
        integrand=integrand,
        kernel=kernel,
        FieldStruct=FieldStruct,
        ValueStruct=ValueStruct,
        domain=domain,
        dest=dest,
        quadrature=quadrature,
        dim=dim,
        trial=fields.get(arguments.trial_name),
        fields=arguments.field_args,
        values=values,
        temporary_store=temporary_store,
        bsr_options=bsr_options,
        device=device,
    )
|