warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +139 -0
- warp/__init__.pyi +1 -0
- warp/autograd.py +1142 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +557 -0
- warp/build_dll.py +405 -0
- warp/builtins.py +6855 -0
- warp/codegen.py +3969 -0
- warp/config.py +158 -0
- warp/constants.py +57 -0
- warp/context.py +6812 -0
- warp/dlpack.py +462 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +194 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +193 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +423 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +178 -0
- warp/examples/fem/example_convection_diffusion_dg.py +204 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +220 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_magnetostatics.py +240 -0
- warp/examples/fem/example_mixed_elasticity.py +291 -0
- warp/examples/fem/example_navier_stokes.py +261 -0
- warp/examples/fem/example_nonconforming_contact.py +298 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +352 -0
- warp/examples/fem/utils.py +1000 -0
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +266 -0
- warp/examples/optim/example_cloth_throw.py +228 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_drone.py +870 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +182 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
- warp/examples/optim/example_softbody_properties.py +400 -0
- warp/examples/optim/example_spring_cage.py +245 -0
- warp/examples/optim/example_trajectory.py +227 -0
- warp/examples/sim/example_cartpole.py +143 -0
- warp/examples/sim/example_cloth.py +225 -0
- warp/examples/sim/example_cloth_self_contact.py +322 -0
- warp/examples/sim/example_granular.py +130 -0
- warp/examples/sim/example_granular_collision_sdf.py +202 -0
- warp/examples/sim/example_jacobian_ik.py +244 -0
- warp/examples/sim/example_particle_chain.py +124 -0
- warp/examples/sim/example_quadruped.py +203 -0
- warp/examples/sim/example_rigid_chain.py +203 -0
- warp/examples/sim/example_rigid_contact.py +195 -0
- warp/examples/sim/example_rigid_force.py +133 -0
- warp/examples/sim/example_rigid_gyroscopic.py +115 -0
- warp/examples/sim/example_rigid_soft_contact.py +140 -0
- warp/examples/sim/example_soft_body.py +196 -0
- warp/examples/tile/example_tile_cholesky.py +87 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mlp.py +383 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/examples/tile/example_tile_walker.py +327 -0
- warp/fabric.py +355 -0
- warp/fem/__init__.py +106 -0
- warp/fem/adaptivity.py +508 -0
- warp/fem/cache.py +572 -0
- warp/fem/dirichlet.py +202 -0
- warp/fem/domain.py +411 -0
- warp/fem/field/__init__.py +125 -0
- warp/fem/field/field.py +619 -0
- warp/fem/field/nodal_field.py +326 -0
- warp/fem/field/restriction.py +37 -0
- warp/fem/field/virtual.py +848 -0
- warp/fem/geometry/__init__.py +32 -0
- warp/fem/geometry/adaptive_nanogrid.py +857 -0
- warp/fem/geometry/closest_point.py +84 -0
- warp/fem/geometry/deformed_geometry.py +221 -0
- warp/fem/geometry/element.py +776 -0
- warp/fem/geometry/geometry.py +362 -0
- warp/fem/geometry/grid_2d.py +392 -0
- warp/fem/geometry/grid_3d.py +452 -0
- warp/fem/geometry/hexmesh.py +911 -0
- warp/fem/geometry/nanogrid.py +571 -0
- warp/fem/geometry/partition.py +389 -0
- warp/fem/geometry/quadmesh.py +663 -0
- warp/fem/geometry/tetmesh.py +855 -0
- warp/fem/geometry/trimesh.py +806 -0
- warp/fem/integrate.py +2335 -0
- warp/fem/linalg.py +419 -0
- warp/fem/operator.py +293 -0
- warp/fem/polynomial.py +229 -0
- warp/fem/quadrature/__init__.py +17 -0
- warp/fem/quadrature/pic_quadrature.py +299 -0
- warp/fem/quadrature/quadrature.py +591 -0
- warp/fem/space/__init__.py +228 -0
- warp/fem/space/basis_function_space.py +468 -0
- warp/fem/space/basis_space.py +667 -0
- warp/fem/space/dof_mapper.py +251 -0
- warp/fem/space/function_space.py +309 -0
- warp/fem/space/grid_2d_function_space.py +177 -0
- warp/fem/space/grid_3d_function_space.py +227 -0
- warp/fem/space/hexmesh_function_space.py +257 -0
- warp/fem/space/nanogrid_function_space.py +201 -0
- warp/fem/space/partition.py +367 -0
- warp/fem/space/quadmesh_function_space.py +223 -0
- warp/fem/space/restriction.py +179 -0
- warp/fem/space/shape/__init__.py +143 -0
- warp/fem/space/shape/cube_shape_function.py +1105 -0
- warp/fem/space/shape/shape_function.py +133 -0
- warp/fem/space/shape/square_shape_function.py +926 -0
- warp/fem/space/shape/tet_shape_function.py +834 -0
- warp/fem/space/shape/triangle_shape_function.py +672 -0
- warp/fem/space/tetmesh_function_space.py +271 -0
- warp/fem/space/topology.py +424 -0
- warp/fem/space/trimesh_function_space.py +194 -0
- warp/fem/types.py +99 -0
- warp/fem/utils.py +420 -0
- warp/jax.py +187 -0
- warp/jax_experimental/__init__.py +16 -0
- warp/jax_experimental/custom_call.py +351 -0
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +244 -0
- warp/native/array.h +1145 -0
- warp/native/builtin.h +1800 -0
- warp/native/bvh.cpp +492 -0
- warp/native/bvh.cu +791 -0
- warp/native/bvh.h +554 -0
- warp/native/clang/clang.cpp +536 -0
- warp/native/coloring.cpp +613 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +362 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +646 -0
- warp/native/cuda_util.h +307 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +1878 -0
- warp/native/fabric.h +245 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +87 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1230 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +339 -0
- warp/native/marching.cpp +19 -0
- warp/native/marching.cu +514 -0
- warp/native/marching.h +19 -0
- warp/native/mat.h +2220 -0
- warp/native/mathdx.cpp +87 -0
- warp/native/matnn.h +343 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +404 -0
- warp/native/mesh.h +1980 -0
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1371 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +139 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +364 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +53 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +277 -0
- warp/native/sort.h +33 -0
- warp/native/sparse.cpp +378 -0
- warp/native/sparse.cu +524 -0
- warp/native/spatial.h +657 -0
- warp/native/svd.h +702 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +2584 -0
- warp/native/tile_reduce.h +264 -0
- warp/native/vec.h +1426 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +67 -0
- warp/native/volume.h +969 -0
- warp/native/volume_builder.cu +477 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1082 -0
- warp/native/warp.cu +3636 -0
- warp/native/warp.h +381 -0
- warp/optim/__init__.py +17 -0
- warp/optim/adam.py +163 -0
- warp/optim/linear.py +1137 -0
- warp/optim/sgd.py +112 -0
- warp/paddle.py +407 -0
- warp/render/__init__.py +18 -0
- warp/render/render_opengl.py +3518 -0
- warp/render/render_usd.py +784 -0
- warp/render/utils.py +160 -0
- warp/sim/__init__.py +65 -0
- warp/sim/articulation.py +793 -0
- warp/sim/collide.py +2395 -0
- warp/sim/graph_coloring.py +300 -0
- warp/sim/import_mjcf.py +790 -0
- warp/sim/import_snu.py +227 -0
- warp/sim/import_urdf.py +579 -0
- warp/sim/import_usd.py +894 -0
- warp/sim/inertia.py +324 -0
- warp/sim/integrator.py +242 -0
- warp/sim/integrator_euler.py +1997 -0
- warp/sim/integrator_featherstone.py +2101 -0
- warp/sim/integrator_vbd.py +2048 -0
- warp/sim/integrator_xpbd.py +3292 -0
- warp/sim/model.py +4791 -0
- warp/sim/particles.py +121 -0
- warp/sim/render.py +427 -0
- warp/sim/utils.py +428 -0
- warp/sparse.py +2057 -0
- warp/stubs.py +3333 -0
- warp/tape.py +1203 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +634 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +200 -0
- warp/tests/geometry/test_hash_grid.py +221 -0
- warp/tests/geometry/test_marching_cubes.py +74 -0
- warp/tests/geometry/test_mesh.py +316 -0
- warp/tests/geometry/test_mesh_query_aabb.py +399 -0
- warp/tests/geometry/test_mesh_query_point.py +932 -0
- warp/tests/geometry/test_mesh_query_ray.py +311 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +729 -0
- warp/tests/interop/test_jax.py +371 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/sim/__init__.py +0 -0
- warp/tests/sim/disabled_kinematics.py +244 -0
- warp/tests/sim/flaky_test_sim_grad.py +290 -0
- warp/tests/sim/test_collision.py +604 -0
- warp/tests/sim/test_coloring.py +258 -0
- warp/tests/sim/test_model.py +224 -0
- warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
- warp/tests/sim/test_sim_kinematics.py +98 -0
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +163 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +2972 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +250 -0
- warp/tests/test_atomic.py +153 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +1298 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +810 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +252 -0
- warp/tests/test_context.py +42 -0
- warp/tests/test_copy.py +238 -0
- warp/tests/test_ctypes.py +638 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +97 -0
- warp/tests/test_examples.py +482 -0
- warp/tests/test_fabricarray.py +996 -0
- warp/tests/test_fast_math.py +74 -0
- warp/tests/test_fem.py +2003 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +454 -0
- warp/tests/test_future_annotations.py +98 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +73 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +193 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_mat.py +2089 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +178 -0
- warp/tests/test_mlp.py +282 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +44 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +339 -0
- warp/tests/test_quat.py +2315 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +302 -0
- warp/tests/test_rounding.py +185 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +105 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +582 -0
- warp/tests/test_spatial.py +2229 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +592 -0
- warp/tests/test_struct.py +734 -0
- warp/tests/test_tape.py +204 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +145 -0
- warp/tests/test_types.py +562 -0
- warp/tests/test_utils.py +588 -0
- warp/tests/test_vec.py +1487 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +780 -0
- warp/tests/tile/test_tile_load.py +407 -0
- warp/tests/tile/test_tile_mathdx.py +208 -0
- warp/tests/tile/test_tile_mlp.py +402 -0
- warp/tests/tile/test_tile_reduce.py +447 -0
- warp/tests/tile/test_tile_shared_memory.py +247 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +427 -0
- warp/tests/unittest_utils.py +468 -0
- warp/tests/walkthrough_debug.py +93 -0
- warp/thirdparty/__init__.py +0 -0
- warp/thirdparty/appdirs.py +598 -0
- warp/thirdparty/dlpack.py +145 -0
- warp/thirdparty/unittest_parallel.py +570 -0
- warp/torch.py +391 -0
- warp/types.py +5230 -0
- warp/utils.py +1137 -0
- warp_lang-1.7.0.dist-info/METADATA +516 -0
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- warp_lang-1.7.0.dist-info/WHEEL +5 -0
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp_lang-1.7.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,591 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from typing import Any, Optional
|
|
17
|
+
|
|
18
|
+
import warp as wp
|
|
19
|
+
from warp.fem import cache
|
|
20
|
+
from warp.fem.domain import GeometryDomain
|
|
21
|
+
from warp.fem.geometry import Element
|
|
22
|
+
from warp.fem.space import FunctionSpace
|
|
23
|
+
from warp.fem.types import NULL_ELEMENT_INDEX, Coords, ElementIndex, QuadraturePointIndex
|
|
24
|
+
|
|
25
|
+
from ..polynomial import Polynomial
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@wp.struct
class QuadraturePointElementIndex:
    """Locates a quadrature evaluation point inside the domain.

    Pairs the domain element that owns the point with the point's local index among
    that element's quadrature points. A ``domain_element_index`` equal to
    ``NULL_ELEMENT_INDEX`` denotes an evaluation point not owned by any element.
    """

    # Index of the owning element, relative to the domain (not the underlying geometry)
    domain_element_index: ElementIndex
    # Local index of the point among the owning element's quadrature points
    qp_index_in_element: int
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class Quadrature:
    """Interface class for quadrature rules

    Concrete rules are expected to provide a ``name`` attribute or property (used by
    :meth:`__str__` and as a suffix for generated kernels) and to override the point
    query functions below with Warp device functions.
    """

    @wp.struct
    class Arg:
        """Structure containing arguments to be passed to device functions"""

        pass

    def __init__(self, domain: GeometryDomain):
        self._domain = domain

    @property
    def domain(self):
        """Domain over which this quadrature is defined"""
        return self._domain

    def arg_value(self, device) -> "Arg":
        """
        Value of the argument to be passed to device

        The base implementation returns an empty :class:`Quadrature.Arg`.
        """
        arg = Quadrature.Arg()
        return arg

    def total_point_count(self):
        """Number of unique quadrature points that can be indexed by this rule.

        Returns a number such that `point_index()` is always smaller than this number.
        """
        raise NotImplementedError()

    def evaluation_point_count(self):
        """Number of quadrature points that need to be evaluated, mostly for internal purposes.

        If the indexing scheme is sparse, or if a quadrature point is shared among multiple elements
        (e.g., nodal quadrature), `evaluation_point_count` may be different than `total_point_count()`.
        Returns a number such that `point_evaluation_index()` is always smaller than this number.

        The base implementation returns `total_point_count()`.
        """
        return self.total_point_count()

    def max_points_per_element(self):
        """Maximum number of points per element if known, or ``None`` otherwise"""
        return None

    @staticmethod
    def point_count(
        elt_arg: "GeometryDomain.ElementArg",
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        geo_element_index: ElementIndex,
    ):
        """Number of quadrature points for a given element"""
        raise NotImplementedError()

    @staticmethod
    def point_coords(
        elt_arg: "GeometryDomain.ElementArg",
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        geo_element_index: ElementIndex,
        element_qp_index: int,
    ):
        """Coordinates in element of the element's qp_index'th quadrature point"""
        raise NotImplementedError()

    @staticmethod
    def point_weight(
        elt_arg: "GeometryDomain.ElementArg",
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        geo_element_index: ElementIndex,
        element_qp_index: int,
    ):
        """Weight of the element's qp_index'th quadrature point"""
        raise NotImplementedError()

    @staticmethod
    def point_index(
        elt_arg: "GeometryDomain.ElementArg",
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        geo_element_index: ElementIndex,
        element_qp_index: int,
    ):
        """
        Global index of the element's qp_index'th quadrature point.
        May be shared among elements.
        This is what determines `qp_index` in integrands' `Sample` arguments.
        """
        raise NotImplementedError()

    @staticmethod
    def point_evaluation_index(
        elt_arg: "GeometryDomain.ElementArg",
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        geo_element_index: ElementIndex,
        element_qp_index: int,
    ):
        """Quadrature point index according to evaluation order.
        Quadrature points for distinct elements must have different evaluation indices.
        Mostly for internal/parallelization purposes.
        """
        raise NotImplementedError()

    def __str__(self) -> str:
        # `name` is not defined here; concrete quadrature classes provide it
        return self.name

    # By default cache the mapping from evaluation point indices to domain elements

    ElementIndexArg = wp.array(dtype=QuadraturePointElementIndex)

    @cache.cached_arg_value
    def element_index_arg_value(self, device):
        """Builds a map from quadrature point evaluation indices to their index in the element to which they belong

        One thread is launched per domain element; each writes an entry for every one of its
        quadrature points at that point's evaluation index. Entries that no element writes keep
        ``NULL_ELEMENT_INDEX`` as their domain element index.

        Returns:
            A ``wp.array`` of :class:`QuadraturePointElementIndex` with `evaluation_point_count()` entries.
        """

        @cache.dynamic_kernel(f"{self.name}{self.domain.name}")
        def quadrature_point_element_indices(
            qp_arg: self.Arg,
            domain_arg: self.domain.ElementArg,
            domain_index_arg: self.domain.ElementIndexArg,
            result: wp.array(dtype=QuadraturePointElementIndex),
        ):
            domain_element_index = wp.tid()
            element_index = self.domain.element_index(domain_index_arg, domain_element_index)

            qp_point_count = self.point_count(domain_arg, qp_arg, domain_element_index, element_index)
            for k in range(qp_point_count):
                qp_eval_index = self.point_evaluation_index(domain_arg, qp_arg, domain_element_index, element_index, k)
                result[qp_eval_index] = QuadraturePointElementIndex(domain_element_index, k)

        # Pre-fill with a null entry so evaluation indices not claimed by any element are recognizable
        null_qp_index = QuadraturePointElementIndex()
        null_qp_index.domain_element_index = NULL_ELEMENT_INDEX
        result = wp.full(
            value=null_qp_index,
            shape=(self.evaluation_point_count(),),
            dtype=QuadraturePointElementIndex,
            device=device,
        )
        wp.launch(
            quadrature_point_element_indices,
            device=result.device,
            dim=self.domain.element_count(),
            inputs=[
                self.arg_value(result.device),
                self.domain.element_arg_value(result.device),
                self.domain.element_index_arg_value(result.device),
                result,
            ],
        )

        return result

    @wp.func
    def evaluation_point_element_index(
        element_index_arg: wp.array(dtype=QuadraturePointElementIndex),
        qp_eval_index: QuadraturePointIndex,
    ):
        """Maps from quadrature point evaluation indices to their index in the element to which they belong
        If the quadrature point does not exist, should return NULL_ELEMENT_INDEX as the domain element index
        """

        element_index = element_index_arg[qp_eval_index]
        return element_index.domain_element_index, element_index.qp_index_in_element
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class _QuadratureWithRegularEvaluationPoints(Quadrature):
    """Helper subclass for quadrature formulas which use a uniform number of
    evaluation points per element. Avoids building an explicit mapping.

    With ``N`` evaluation points per element, point ``k`` of domain element ``e`` gets the
    evaluation index ``N * e + k``; the inverse mapping is recovered with integer division.
    """

    def __init__(self, domain: GeometryDomain, N: int):
        super().__init__(domain)
        self._EVALUATION_POINTS_PER_ELEMENT = N

        # The builders below read `self.name`, so subclasses must set whatever `name`
        # depends on before calling this constructor
        self.point_evaluation_index = self._make_regular_point_evaluation_index()
        self.evaluation_point_element_index = self._make_regular_evaluation_point_element_index()

    # The mapping is computed arithmetically, so the "element index" argument is just the empty base argument
    ElementIndexArg = Quadrature.Arg
    element_index_arg_value = Quadrature.arg_value

    def evaluation_point_count(self):
        """Number of evaluation points: ``N`` for each element of the domain"""
        return self.domain.element_count() * self._EVALUATION_POINTS_PER_ELEMENT

    def _make_regular_point_evaluation_index(self):
        """Builds the device function returning ``N * domain_element_index + qp_index``"""
        N = self._EVALUATION_POINTS_PER_ELEMENT

        @cache.dynamic_func(suffix=f"{self.name}")
        def evaluation_point_index(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
            qp_index: int,
        ):
            return N * domain_element_index + qp_index

        return evaluation_point_index

    def _make_regular_evaluation_point_element_index(self):
        """Builds the inverse device function, mapping an evaluation index back to
        ``(domain_element_index, index_in_element)``.

        It only depends on ``N``, hence the ``N``-based cache suffix.
        """
        N = self._EVALUATION_POINTS_PER_ELEMENT

        @cache.dynamic_func(suffix=f"{N}")
        def quadrature_evaluation_point_element_index(
            qp_arg: Quadrature.Arg,
            qp_index: QuadraturePointIndex,
        ):
            domain_element_index = qp_index // N
            index_in_element = qp_index - domain_element_index * N
            return domain_element_index, index_in_element

        return quadrature_evaluation_point_element_index
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
class RegularQuadrature(_QuadratureWithRegularEvaluationPoints):
|
|
246
|
+
"""Regular quadrature formula, using a constant set of quadrature points per element"""
|
|
247
|
+
|
|
248
|
+
@wp.struct
|
|
249
|
+
class Arg:
|
|
250
|
+
# Quadrature points and weights used to be passed as Warp constants,
|
|
251
|
+
# but this tended to incur register spilling for high point counts
|
|
252
|
+
points: wp.array(dtype=Coords)
|
|
253
|
+
weights: wp.array(dtype=float)
|
|
254
|
+
|
|
255
|
+
# Cache common formulas so we do dot have to do h2d transfer for each call
|
|
256
|
+
class CachedFormula:
|
|
257
|
+
_cache = {}
|
|
258
|
+
|
|
259
|
+
def __init__(self, element: Element, order: int, family: Polynomial):
|
|
260
|
+
self.points, self.weights = element.instantiate_quadrature(order, family)
|
|
261
|
+
self.count = wp.constant(len(self.points))
|
|
262
|
+
|
|
263
|
+
@cache.cached_arg_value
|
|
264
|
+
def arg_value(self, device):
|
|
265
|
+
arg = RegularQuadrature.Arg()
|
|
266
|
+
arg.points = wp.array(self.points, device=device, dtype=Coords)
|
|
267
|
+
arg.weights = wp.array(self.weights, device=device, dtype=float)
|
|
268
|
+
return arg
|
|
269
|
+
|
|
270
|
+
@staticmethod
|
|
271
|
+
def get(element: Element, order: int, family: Polynomial):
|
|
272
|
+
key = (element.__class__.__name__, order, family)
|
|
273
|
+
try:
|
|
274
|
+
return RegularQuadrature.CachedFormula._cache[key]
|
|
275
|
+
except KeyError:
|
|
276
|
+
quadrature = RegularQuadrature.CachedFormula(element, order, family)
|
|
277
|
+
RegularQuadrature.CachedFormula._cache[key] = quadrature
|
|
278
|
+
return quadrature
|
|
279
|
+
|
|
280
|
+
def __init__(
|
|
281
|
+
self,
|
|
282
|
+
domain: GeometryDomain,
|
|
283
|
+
order: int,
|
|
284
|
+
family: Polynomial = None,
|
|
285
|
+
):
|
|
286
|
+
self._formula = RegularQuadrature.CachedFormula.get(domain.reference_element(), order, family)
|
|
287
|
+
self.family = family
|
|
288
|
+
self.order = order
|
|
289
|
+
|
|
290
|
+
super().__init__(domain, self._formula.count)
|
|
291
|
+
|
|
292
|
+
self.point_count = self._make_point_count()
|
|
293
|
+
self.point_index = self._make_point_index()
|
|
294
|
+
self.point_coords = self._make_point_coords()
|
|
295
|
+
self.point_weight = self._make_point_weight()
|
|
296
|
+
|
|
297
|
+
@property
|
|
298
|
+
def name(self):
|
|
299
|
+
return f"{self.__class__.__name__}_{self.domain.name}_{self.family}_{self.order}"
|
|
300
|
+
|
|
301
|
+
def total_point_count(self):
|
|
302
|
+
return self._formula.count * self.domain.element_count()
|
|
303
|
+
|
|
304
|
+
def max_points_per_element(self):
|
|
305
|
+
return self._formula.count
|
|
306
|
+
|
|
307
|
+
@property
|
|
308
|
+
def points(self):
|
|
309
|
+
return self._formula.points
|
|
310
|
+
|
|
311
|
+
@property
|
|
312
|
+
def weights(self):
|
|
313
|
+
return self._formula.weights
|
|
314
|
+
|
|
315
|
+
def arg_value(self, device):
|
|
316
|
+
return self._formula.arg_value(device)
|
|
317
|
+
|
|
318
|
+
def _make_point_count(self):
|
|
319
|
+
N = self._formula.count
|
|
320
|
+
|
|
321
|
+
@cache.dynamic_func(suffix=self.name)
|
|
322
|
+
def point_count(
|
|
323
|
+
elt_arg: self.domain.ElementArg,
|
|
324
|
+
qp_arg: self.Arg,
|
|
325
|
+
domain_element_index: ElementIndex,
|
|
326
|
+
element_index: ElementIndex,
|
|
327
|
+
):
|
|
328
|
+
return N
|
|
329
|
+
|
|
330
|
+
return point_count
|
|
331
|
+
|
|
332
|
+
def _make_point_coords(self):
|
|
333
|
+
@cache.dynamic_func(suffix=self.name)
|
|
334
|
+
def point_coords(
|
|
335
|
+
elt_arg: self.domain.ElementArg,
|
|
336
|
+
qp_arg: self.Arg,
|
|
337
|
+
domain_element_index: ElementIndex,
|
|
338
|
+
element_index: ElementIndex,
|
|
339
|
+
qp_index: int,
|
|
340
|
+
):
|
|
341
|
+
return qp_arg.points[qp_index]
|
|
342
|
+
|
|
343
|
+
return point_coords
|
|
344
|
+
|
|
345
|
+
def _make_point_weight(self):
    """Build the device function returning the weight of a quadrature point.

    The weight is read from ``qp_arg.weights`` using only the local point index,
    since all elements share the same reference weights.
    """
    element_arg_type = self.domain.ElementArg
    quadrature_arg_type = self.Arg

    @cache.dynamic_func(suffix=self.name)
    def point_weight(
        elt_arg: element_arg_type,
        qp_arg: quadrature_arg_type,
        domain_element_index: ElementIndex,
        element_index: ElementIndex,
        qp_index: int,
    ):
        return qp_arg.weights[qp_index]

    return point_weight
|
|
357
|
+
|
|
358
|
+
def _make_point_index(self):
    """Build the device function returning the global index of a quadrature point.

    Points are numbered contiguously per domain element:
    ``points_per_element * domain_element_index + qp_index``.
    """
    points_per_element = self._formula.count
    element_arg_type = self.domain.ElementArg
    quadrature_arg_type = self.Arg

    @cache.dynamic_func(suffix=self.name)
    def point_index(
        elt_arg: element_arg_type,
        qp_arg: quadrature_arg_type,
        domain_element_index: ElementIndex,
        element_index: ElementIndex,
        qp_index: int,
    ):
        return points_per_element * domain_element_index + qp_index

    return point_index
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
class NodalQuadrature(Quadrature):
    """Quadrature using space node points as quadrature points.

    Note that in contrast to the `nodal=True` flag for :func:`integrate`, using this quadrature does not imply
    any assumption about orthogonality of shape functions, and is thus safe to use for arbitrary integrands.

    Args:
        domain: Domain of definition of the quadrature formula
        space: Function space whose element nodes are used as quadrature points; coordinates and weights
            come from the space's ``node_coords_in_element`` and ``node_quadrature_weight`` functions.
    """

    def __init__(self, domain: Optional[GeometryDomain], space: FunctionSpace):
        self._space = space

        super().__init__(domain)

        # The argument struct must exist before the device functions below are built,
        # since their ``qp_arg`` annotation is ``self.Arg``.
        self.Arg = self._make_arg()

        self.point_count = self._make_point_count()
        self.point_index = self._make_point_index()
        self.point_coords = self._make_point_coords()
        self.point_weight = self._make_point_weight()
        self.point_evaluation_index = self._make_point_evaluation_index()

    @property
    def name(self):
        # The generated device functions are built with ``cache.dynamic_func(suffix=self.name)`` while their
        # ``elt_arg`` type is ``self.domain.ElementArg``. Including the domain name (as RegularQuadrature does)
        # keeps quadratures that share a space but not a domain from sharing the same suffix.
        return f"{self.__class__.__name__}_{self.domain.name}_{self._space.name}"

    def total_point_count(self):
        """Total number of quadrature points: one per node of the function space."""
        return self._space.node_count()

    def max_points_per_element(self):
        """Upper bound on the number of quadrature points (nodes) in a single element."""
        return self._space.topology.MAX_NODES_PER_ELEMENT

    def _make_arg(self):
        """Build the device-side argument struct holding the space and topology arguments."""

        @cache.dynamic_struct(suffix=self.name)
        class Arg:
            space_arg: self._space.SpaceArg
            topo_arg: self._space.topology.TopologyArg

        return Arg

    @cache.cached_arg_value
    def arg_value(self, device):
        """Fill the argument struct with the space and topology arguments for ``device``."""
        arg = self.Arg()
        arg.space_arg = self._space.space_arg_value(device)
        arg.topo_arg = self._space.topology.topo_arg_value(device)
        return arg

    def _make_point_count(self):
        """Build the device function returning the number of nodes (quadrature points) of an element."""

        @cache.dynamic_func(suffix=self.name)
        def point_count(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
        ):
            # Node counts are queried per geometry element (``element_index``)
            return self._space.topology.element_node_count(elt_arg, qp_arg.topo_arg, element_index)

        return point_count

    def _make_point_coords(self):
        """Build the device function returning the element coordinates of a node."""

        @cache.dynamic_func(suffix=self.name)
        def point_coords(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
            qp_index: int,
        ):
            return self._space.node_coords_in_element(elt_arg, qp_arg.space_arg, element_index, qp_index)

        return point_coords

    def _make_point_weight(self):
        """Build the device function returning the quadrature weight associated with a node."""

        @cache.dynamic_func(suffix=self.name)
        def point_weight(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
            qp_index: int,
        ):
            return self._space.node_quadrature_weight(elt_arg, qp_arg.space_arg, element_index, qp_index)

        return point_weight

    def _make_point_index(self):
        """Build the device function returning the global index of a quadrature point.

        The global index is the space node index, so points shared between elements get the same index.
        """

        @cache.dynamic_func(suffix=self.name)
        def point_index(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
            qp_index: int,
        ):
            return self._space.topology.element_node_index(elt_arg, qp_arg.topo_arg, element_index, qp_index)

        return point_index

    def evaluation_point_count(self):
        """Number of evaluation slots: ``MAX_NODES_PER_ELEMENT`` per domain element."""
        return self.domain.element_count() * self._space.topology.MAX_NODES_PER_ELEMENT

    def _make_point_evaluation_index(self):
        """Build the device function returning the evaluation slot of a (domain element, node) pair.

        Unlike ``point_index``, this numbering is not shared between elements; each domain element
        owns a contiguous block of ``MAX_NODES_PER_ELEMENT`` slots.
        """
        N = self._space.topology.MAX_NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def evaluation_point_index(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
            qp_index: int,
        ):
            return N * domain_element_index + qp_index

        return evaluation_point_index
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
class ExplicitQuadrature(_QuadratureWithRegularEvaluationPoints):
    """Quadrature using explicit per-cell points and weights.

    The number of quadrature points per cell is assumed to be constant and deduced from the shape of the points and weights arrays.
    Quadrature points may be provided for either the whole geometry or just the domain's elements.
    If the number of rows matches both ``domain.geometry_element_count()`` and ``domain.element_count()``,
    rows are interpreted as indexed by geometry element.

    Args:
        domain: Domain of definition of the quadrature formula
        points: 2d array of shape ``(domain.element_count(), points_per_cell)`` or ``(domain.geometry_element_count(), points_per_cell)`` containing the coordinates of each quadrature point.
        weights: 2d array of shape ``(domain.element_count(), points_per_cell)`` or ``(domain.geometry_element_count(), points_per_cell)`` containing the weight for each quadrature point.

    Raises:
        ValueError: if ``points`` and ``weights`` do not have the same shape, or are not two-dimensional
        NotImplementedError: if the number of rows matches neither the domain nor the geometry element count

    See also: :class:`PicQuadrature`
    """

    @wp.struct
    class Arg:
        points_per_cell: int  # number of quadrature points per row (second array dimension)
        points: wp.array2d(dtype=Coords)
        weights: wp.array2d(dtype=float)

    def __init__(self, domain: GeometryDomain, points: "wp.array2d(dtype=Coords)", weights: "wp.array2d(dtype=float)"):
        if points.shape != weights.shape:
            raise ValueError("Points and weights arrays must have the same shape")
        # Validate rank up front; otherwise a 1d input fails later with an unhelpful IndexError on ``shape[1]``
        if len(points.shape) != 2:
            raise ValueError("Points and weights arrays must be two-dimensional")

        # Rows indexed by geometry element take precedence when both element counts coincide
        whole_geo = points.shape[0] == domain.geometry_element_count()
        if whole_geo:
            self.point_index = ExplicitQuadrature._point_index_geo
            self.point_coords = ExplicitQuadrature._point_coords_geo
            self.point_weight = ExplicitQuadrature._point_weight_geo
        elif points.shape[0] == domain.element_count():
            self.point_index = ExplicitQuadrature._point_index_domain
            self.point_coords = ExplicitQuadrature._point_coords_domain
            self.point_weight = ExplicitQuadrature._point_weight_domain
        else:
            raise NotImplementedError(
                "The number of rows of points and weights must match the element count of either the domain or the geometry"
            )

        self._points_per_cell = points.shape[1]

        self._whole_geo = whole_geo

        super().__init__(domain, self._points_per_cell)
        self._points = points
        self._weights = weights

    @property
    def name(self):
        return f"{self.__class__.__name__}_{self._whole_geo}"

    def total_point_count(self):
        """Total number of quadrature points, i.e. the number of entries of the weights array."""
        return self._weights.size

    def max_points_per_element(self):
        """Number of quadrature points per row of the points array."""
        return self._points_per_cell

    @cache.cached_arg_value
    def arg_value(self, device):
        """Fill the argument struct, copying the points and weights arrays to ``device``."""
        arg = self.Arg()
        arg.points_per_cell = self._points_per_cell
        arg.points = self._points.to(device)
        arg.weights = self._weights.to(device)

        return arg

    @wp.func
    def point_count(elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex):
        return qp_arg.points.shape[1]

    # Variants used when rows are indexed by domain element (``domain_element_index``)

    @wp.func
    def _point_coords_domain(
        elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
    ):
        return qp_arg.points[domain_element_index, qp_index]

    @wp.func
    def _point_weight_domain(
        elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
    ):
        return qp_arg.weights[domain_element_index, qp_index]

    @wp.func
    def _point_index_domain(
        elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
    ):
        # Flat row-major position of the point in the points/weights arrays
        return qp_arg.points_per_cell * domain_element_index + qp_index

    # Variants used when rows are indexed by geometry element (``element_index``)

    @wp.func
    def _point_coords_geo(
        elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
    ):
        return qp_arg.points[element_index, qp_index]

    @wp.func
    def _point_weight_geo(
        elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
    ):
        return qp_arg.weights[element_index, qp_index]

    @wp.func
    def _point_index_geo(
        elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
    ):
        # Flat row-major position of the point in the points/weights arrays
        return qp_arg.points_per_cell * element_index + qp_index
|