warp-lang 1.6.1__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +21 -7
- warp/autograd.py +14 -6
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +424 -6
- warp/build_dll.py +20 -20
- warp/builtins.py +467 -368
- warp/codegen.py +193 -125
- warp/config.py +56 -12
- warp/constants.py +14 -6
- warp/context.py +524 -277
- warp/dlpack.py +22 -12
- warp/examples/__init__.py +14 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_api.py +14 -6
- warp/examples/benchmarks/benchmark_cloth.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
- warp/examples/benchmarks/benchmark_gemm.py +82 -48
- warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
- warp/examples/benchmarks/benchmark_launches.py +14 -6
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +14 -6
- warp/examples/core/example_cupy.py +14 -6
- warp/examples/core/example_dem.py +14 -6
- warp/examples/core/example_fluid.py +14 -6
- warp/examples/core/example_graph_capture.py +14 -6
- warp/examples/core/example_marching_cubes.py +14 -6
- warp/examples/core/example_mesh.py +14 -6
- warp/examples/core/example_mesh_intersect.py +14 -6
- warp/examples/core/example_nvdb.py +14 -6
- warp/examples/core/example_raycast.py +14 -6
- warp/examples/core/example_raymarch.py +14 -6
- warp/examples/core/example_render_opengl.py +14 -6
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +14 -6
- warp/examples/core/example_torch.py +14 -6
- warp/examples/core/example_wave.py +14 -6
- warp/examples/fem/example_adaptive_grid.py +14 -6
- warp/examples/fem/example_apic_fluid.py +15 -7
- warp/examples/fem/example_burgers.py +16 -8
- warp/examples/fem/example_convection_diffusion.py +14 -6
- warp/examples/fem/example_convection_diffusion_dg.py +14 -6
- warp/examples/fem/example_deformed_geometry.py +15 -7
- warp/examples/fem/example_diffusion.py +14 -6
- warp/examples/fem/example_diffusion_3d.py +14 -6
- warp/examples/fem/example_diffusion_mgpu.py +14 -6
- warp/examples/fem/example_distortion_energy.py +15 -7
- warp/examples/fem/example_magnetostatics.py +20 -12
- warp/examples/fem/example_mixed_elasticity.py +14 -6
- warp/examples/fem/example_navier_stokes.py +14 -6
- warp/examples/fem/example_nonconforming_contact.py +14 -6
- warp/examples/fem/example_stokes.py +14 -6
- warp/examples/fem/example_stokes_transfer.py +14 -6
- warp/examples/fem/example_streamlines.py +14 -6
- warp/examples/fem/utils.py +24 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +14 -6
- warp/examples/optim/example_cloth_throw.py +14 -6
- warp/examples/optim/example_diffray.py +14 -6
- warp/examples/optim/example_drone.py +14 -6
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +14 -6
- warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
- warp/examples/optim/example_softbody_properties.py +14 -6
- warp/examples/optim/example_spring_cage.py +14 -6
- warp/examples/optim/example_trajectory.py +14 -6
- warp/examples/sim/example_cartpole.py +14 -6
- warp/examples/sim/example_cloth.py +14 -6
- warp/examples/sim/example_cloth_self_contact.py +14 -6
- warp/examples/sim/example_granular.py +14 -6
- warp/examples/sim/example_granular_collision_sdf.py +14 -6
- warp/examples/sim/example_jacobian_ik.py +14 -6
- warp/examples/sim/example_particle_chain.py +14 -6
- warp/examples/sim/example_quadruped.py +14 -6
- warp/examples/sim/example_rigid_chain.py +14 -6
- warp/examples/sim/example_rigid_contact.py +14 -6
- warp/examples/sim/example_rigid_force.py +14 -6
- warp/examples/sim/example_rigid_gyroscopic.py +14 -6
- warp/examples/sim/example_rigid_soft_contact.py +14 -6
- warp/examples/sim/example_soft_body.py +14 -6
- warp/examples/tile/example_tile_cholesky.py +14 -6
- warp/examples/tile/example_tile_convolution.py +14 -6
- warp/examples/tile/example_tile_fft.py +14 -6
- warp/examples/tile/example_tile_filtering.py +14 -6
- warp/examples/tile/example_tile_matmul.py +16 -10
- warp/examples/tile/example_tile_mlp.py +14 -6
- warp/examples/tile/example_tile_nbody.py +14 -6
- warp/examples/tile/example_tile_walker.py +14 -6
- warp/fabric.py +15 -0
- warp/fem/__init__.py +26 -1
- warp/fem/adaptivity.py +19 -4
- warp/fem/cache.py +15 -0
- warp/fem/dirichlet.py +15 -0
- warp/fem/domain.py +15 -0
- warp/fem/field/__init__.py +15 -0
- warp/fem/field/field.py +15 -0
- warp/fem/field/nodal_field.py +37 -68
- warp/fem/field/restriction.py +15 -0
- warp/fem/field/virtual.py +77 -23
- warp/fem/geometry/__init__.py +15 -0
- warp/fem/geometry/adaptive_nanogrid.py +24 -10
- warp/fem/geometry/closest_point.py +16 -1
- warp/fem/geometry/deformed_geometry.py +20 -2
- warp/fem/geometry/element.py +15 -0
- warp/fem/geometry/geometry.py +20 -0
- warp/fem/geometry/grid_2d.py +27 -12
- warp/fem/geometry/grid_3d.py +27 -15
- warp/fem/geometry/hexmesh.py +20 -7
- warp/fem/geometry/nanogrid.py +24 -11
- warp/fem/geometry/partition.py +15 -0
- warp/fem/geometry/quadmesh.py +28 -13
- warp/fem/geometry/tetmesh.py +18 -4
- warp/fem/geometry/trimesh.py +18 -8
- warp/fem/integrate.py +277 -93
- warp/fem/linalg.py +20 -5
- warp/fem/operator.py +15 -0
- warp/fem/polynomial.py +15 -0
- warp/fem/quadrature/__init__.py +15 -0
- warp/fem/quadrature/pic_quadrature.py +52 -22
- warp/fem/quadrature/quadrature.py +209 -25
- warp/fem/space/__init__.py +16 -1
- warp/fem/space/basis_function_space.py +19 -2
- warp/fem/space/basis_space.py +40 -18
- warp/fem/space/dof_mapper.py +15 -0
- warp/fem/space/function_space.py +15 -0
- warp/fem/space/grid_2d_function_space.py +15 -0
- warp/fem/space/grid_3d_function_space.py +15 -0
- warp/fem/space/hexmesh_function_space.py +17 -2
- warp/fem/space/nanogrid_function_space.py +15 -0
- warp/fem/space/partition.py +21 -2
- warp/fem/space/quadmesh_function_space.py +23 -8
- warp/fem/space/restriction.py +15 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +38 -23
- warp/fem/space/shape/shape_function.py +15 -0
- warp/fem/space/shape/square_shape_function.py +27 -12
- warp/fem/space/shape/tet_shape_function.py +15 -0
- warp/fem/space/shape/triangle_shape_function.py +16 -1
- warp/fem/space/tetmesh_function_space.py +18 -3
- warp/fem/space/topology.py +15 -0
- warp/fem/space/trimesh_function_space.py +17 -2
- warp/fem/types.py +15 -0
- warp/fem/utils.py +27 -6
- warp/jax.py +28 -7
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +103 -6
- warp/native/array.h +28 -6
- warp/native/builtin.h +44 -9
- warp/native/bvh.cpp +18 -7
- warp/native/bvh.cu +57 -20
- warp/native/bvh.h +17 -7
- warp/native/clang/clang.cpp +45 -9
- warp/native/coloring.cpp +15 -6
- warp/native/crt.cpp +15 -6
- warp/native/crt.h +15 -6
- warp/native/cuda_crt.h +15 -6
- warp/native/cuda_util.cpp +29 -6
- warp/native/cuda_util.h +17 -6
- warp/native/error.cpp +15 -6
- warp/native/error.h +15 -6
- warp/native/exports.h +85 -63
- warp/native/fabric.h +15 -6
- warp/native/hashgrid.cpp +15 -6
- warp/native/hashgrid.cu +15 -6
- warp/native/hashgrid.h +15 -6
- warp/native/initializer_array.h +15 -6
- warp/native/intersect.h +41 -32
- warp/native/intersect_adj.h +48 -39
- warp/native/intersect_tri.h +17 -0
- warp/native/marching.cpp +16 -0
- warp/native/marching.cu +16 -7
- warp/native/marching.h +17 -0
- warp/native/mat.h +528 -15
- warp/native/mathdx.cpp +15 -6
- warp/native/matnn.h +15 -6
- warp/native/mesh.cpp +15 -6
- warp/native/mesh.cu +15 -6
- warp/native/mesh.h +25 -16
- warp/native/noise.h +15 -6
- warp/native/quat.h +114 -17
- warp/native/rand.h +21 -6
- warp/native/range.h +15 -6
- warp/native/reduce.cpp +15 -6
- warp/native/reduce.cu +15 -6
- warp/native/runlength_encode.cpp +15 -6
- warp/native/runlength_encode.cu +15 -6
- warp/native/scan.cpp +15 -6
- warp/native/scan.cu +15 -6
- warp/native/scan.h +15 -6
- warp/native/solid_angle.h +17 -0
- warp/native/sort.cpp +137 -65
- warp/native/sort.cu +167 -21
- warp/native/sort.h +23 -7
- warp/native/sparse.cpp +58 -28
- warp/native/sparse.cu +67 -23
- warp/native/spatial.h +15 -6
- warp/native/svd.h +131 -6
- warp/native/temp_buffer.h +15 -6
- warp/native/tile.h +316 -111
- warp/native/tile_reduce.h +61 -9
- warp/native/vec.h +83 -13
- warp/native/volume.cpp +100 -119
- warp/native/volume.cu +15 -6
- warp/native/volume.h +15 -6
- warp/native/volume_builder.cu +40 -16
- warp/native/volume_builder.h +21 -6
- warp/native/volume_impl.h +15 -6
- warp/native/warp.cpp +20 -12
- warp/native/warp.cu +114 -16
- warp/native/warp.h +34 -16
- warp/optim/__init__.py +14 -6
- warp/optim/adam.py +14 -6
- warp/optim/linear.py +25 -10
- warp/optim/sgd.py +14 -6
- warp/paddle.py +14 -6
- warp/render/__init__.py +14 -6
- warp/render/render_opengl.py +14 -6
- warp/render/render_usd.py +14 -6
- warp/render/utils.py +14 -6
- warp/sim/__init__.py +14 -7
- warp/sim/articulation.py +18 -10
- warp/sim/collide.py +35 -16
- warp/sim/graph_coloring.py +14 -6
- warp/sim/import_mjcf.py +463 -162
- warp/sim/import_snu.py +14 -7
- warp/sim/import_urdf.py +46 -18
- warp/sim/import_usd.py +14 -7
- warp/sim/inertia.py +14 -6
- warp/sim/integrator.py +14 -6
- warp/sim/integrator_euler.py +19 -11
- warp/sim/integrator_featherstone.py +17 -16
- warp/sim/integrator_vbd.py +222 -8
- warp/sim/integrator_xpbd.py +19 -11
- warp/sim/model.py +56 -19
- warp/sim/particles.py +14 -6
- warp/sim/render.py +14 -6
- warp/sim/utils.py +17 -2
- warp/sparse.py +657 -555
- warp/stubs.py +231 -19
- warp/tape.py +14 -6
- warp/tests/aux_test_class_kernel.py +14 -6
- warp/tests/aux_test_compile_consts_dummy.py +14 -6
- warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
- warp/tests/aux_test_dependent.py +14 -6
- warp/tests/aux_test_grad_customs.py +14 -6
- warp/tests/aux_test_instancing_gc.py +14 -6
- warp/tests/aux_test_module_unload.py +14 -6
- warp/tests/aux_test_name_clash1.py +14 -6
- warp/tests/aux_test_name_clash2.py +14 -6
- warp/tests/aux_test_unresolved_func.py +14 -6
- warp/tests/aux_test_unresolved_symbol.py +14 -6
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_async.py → cuda/test_async.py} +14 -6
- warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
- warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
- warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
- warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
- warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
- warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
- warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
- warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
- warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
- warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
- warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
- warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
- warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
- warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
- warp/tests/run_coverage_serial.py +14 -6
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
- warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
- warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
- warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
- warp/tests/{test_model.py → sim/test_model.py} +55 -7
- warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +14 -6
- warp/tests/test_arithmetic.py +14 -6
- warp/tests/test_array.py +14 -6
- warp/tests/test_array_reduce.py +14 -6
- warp/tests/test_assert.py +14 -6
- warp/tests/test_atomic.py +14 -6
- warp/tests/test_bool.py +15 -7
- warp/tests/test_builtins_resolution.py +14 -6
- warp/tests/test_closest_point_edge_edge.py +14 -6
- warp/tests/test_codegen.py +14 -6
- warp/tests/test_codegen_instancing.py +14 -6
- warp/tests/test_compile_consts.py +14 -6
- warp/tests/test_conditional.py +14 -6
- warp/tests/test_context.py +14 -6
- warp/tests/test_copy.py +14 -6
- warp/tests/test_ctypes.py +14 -6
- warp/tests/test_dense.py +14 -6
- warp/tests/test_devices.py +14 -6
- warp/tests/test_examples.py +42 -42
- warp/tests/test_fabricarray.py +14 -6
- warp/tests/test_fast_math.py +14 -6
- warp/tests/test_fem.py +37 -10
- warp/tests/test_fp16.py +14 -6
- warp/tests/test_func.py +14 -6
- warp/tests/test_future_annotations.py +14 -6
- warp/tests/test_generics.py +14 -6
- warp/tests/test_grad.py +14 -6
- warp/tests/test_grad_customs.py +14 -6
- warp/tests/test_grad_debug.py +14 -6
- warp/tests/test_implicit_init.py +14 -6
- warp/tests/test_import.py +14 -6
- warp/tests/test_indexedarray.py +14 -6
- warp/tests/test_intersect.py +14 -6
- warp/tests/test_iter.py +14 -6
- warp/tests/test_large.py +14 -6
- warp/tests/test_launch.py +14 -6
- warp/tests/test_lerp.py +14 -6
- warp/tests/test_linear_solvers.py +15 -11
- warp/tests/test_lvalue.py +14 -6
- warp/tests/test_mat.py +247 -85
- warp/tests/test_mat_lite.py +14 -6
- warp/tests/test_mat_scalar_ops.py +18 -10
- warp/tests/test_math.py +14 -6
- warp/tests/test_mlp.py +14 -6
- warp/tests/test_module_hashing.py +14 -6
- warp/tests/test_modules_lite.py +14 -6
- warp/tests/test_noise.py +14 -6
- warp/tests/test_operators.py +14 -6
- warp/tests/test_options.py +14 -6
- warp/tests/test_overwrite.py +15 -60
- warp/tests/test_print.py +14 -6
- warp/tests/test_quat.py +81 -52
- warp/tests/test_rand.py +58 -43
- warp/tests/test_reload.py +14 -6
- warp/tests/test_rounding.py +14 -6
- warp/tests/test_runlength_encode.py +14 -6
- warp/tests/test_scalar_ops.py +14 -6
- warp/tests/test_smoothstep.py +14 -6
- warp/tests/test_snippet.py +15 -0
- warp/tests/test_sparse.py +61 -12
- warp/tests/test_spatial.py +89 -6
- warp/tests/test_special_values.py +14 -6
- warp/tests/test_static.py +15 -7
- warp/tests/test_struct.py +14 -6
- warp/tests/test_tape.py +14 -6
- warp/tests/test_transient_module.py +14 -6
- warp/tests/test_triangle_closest_point.py +14 -6
- warp/tests/test_types.py +14 -6
- warp/tests/test_utils.py +98 -10
- warp/tests/test_vec.py +60 -40
- warp/tests/test_vec_lite.py +14 -6
- warp/tests/test_vec_scalar_ops.py +14 -6
- warp/tests/test_verify_fp.py +14 -6
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
- warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
- warp/tests/unittest_serial.py +15 -6
- warp/tests/unittest_suites.py +59 -65
- warp/tests/unittest_utils.py +16 -7
- warp/tests/walkthrough_debug.py +14 -6
- warp/thirdparty/unittest_parallel.py +15 -8
- warp/torch.py +14 -6
- warp/types.py +124 -664
- warp/utils.py +151 -78
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp/examples/optim/example_walker.py +0 -309
- warp/native/cutlass_gemm.cpp +0 -34
- warp/native/cutlass_gemm.cu +0 -373
- warp/tests/test_matmul.py +0 -503
- warp/tests/test_matmul_lite.py +0 -403
- warp/tests/test_vbd.py +0 -378
- warp/tests/unused_test_misc.py +0 -69
- warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
- warp_lang-1.6.1.dist-info/RECORD +0 -419
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py
CHANGED
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import ast
|
|
2
17
|
import inspect
|
|
3
18
|
import textwrap
|
|
@@ -19,7 +34,7 @@ from warp.fem.field import (
|
|
|
19
34
|
make_restriction,
|
|
20
35
|
)
|
|
21
36
|
from warp.fem.field.virtual import make_bilinear_dispatch_kernel, make_linear_dispatch_kernel
|
|
22
|
-
from warp.fem.linalg import array_axpy
|
|
37
|
+
from warp.fem.linalg import array_axpy, basis_coefficient
|
|
23
38
|
from warp.fem.operator import Integrand, Operator, at_node, integrand
|
|
24
39
|
from warp.fem.quadrature import Quadrature, RegularQuadrature
|
|
25
40
|
from warp.fem.types import (
|
|
@@ -478,7 +493,7 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
|
478
493
|
callee = getattr(call.func, "id", None)
|
|
479
494
|
|
|
480
495
|
if callee == self._func_name:
|
|
481
|
-
# Replace function arguments with
|
|
496
|
+
# Replace function arguments with our generated structs
|
|
482
497
|
call.args.clear()
|
|
483
498
|
for arg in self._arg_names:
|
|
484
499
|
if arg == self._domain_name:
|
|
@@ -561,33 +576,33 @@ def get_integrate_constant_kernel(
|
|
|
561
576
|
):
|
|
562
577
|
def integrate_kernel_fn(
|
|
563
578
|
qp_arg: quadrature.Arg,
|
|
579
|
+
qp_element_index_arg: quadrature.ElementIndexArg,
|
|
564
580
|
domain_arg: domain.ElementArg,
|
|
565
581
|
domain_index_arg: domain.ElementIndexArg,
|
|
566
582
|
fields: FieldStruct,
|
|
567
583
|
values: ValueStruct,
|
|
568
584
|
result: wp.array(dtype=accumulate_dtype),
|
|
569
585
|
):
|
|
570
|
-
|
|
586
|
+
qp_eval_index = wp.tid()
|
|
587
|
+
domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
|
|
588
|
+
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
589
|
+
return
|
|
590
|
+
|
|
571
591
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
572
|
-
|
|
592
|
+
|
|
593
|
+
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
594
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
595
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
573
596
|
|
|
574
597
|
test_dof_index = NULL_DOF_INDEX
|
|
575
598
|
trial_dof_index = NULL_DOF_INDEX
|
|
576
599
|
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
580
|
-
coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
581
|
-
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
582
|
-
|
|
583
|
-
sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
584
|
-
vol = domain.element_measure(domain_arg, sample)
|
|
600
|
+
sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
601
|
+
vol = domain.element_measure(domain_arg, sample)
|
|
585
602
|
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
elem_sum += accumulate_dtype(qp_weight * vol * val)
|
|
603
|
+
val = integrand_func(sample, fields, values)
|
|
589
604
|
|
|
590
|
-
wp.atomic_add(result, 0,
|
|
605
|
+
wp.atomic_add(result, 0, accumulate_dtype(qp_weight * vol * val))
|
|
591
606
|
|
|
592
607
|
return integrate_kernel_fn
|
|
593
608
|
|
|
@@ -730,35 +745,35 @@ def get_integrate_linear_local_kernel(
|
|
|
730
745
|
ValueStruct: wp.codegen.Struct,
|
|
731
746
|
test: LocalTestField,
|
|
732
747
|
):
|
|
733
|
-
TAYLOR_DOF_COUNT = test.TAYLOR_DOF_COUNT
|
|
734
|
-
|
|
735
748
|
def integrate_kernel_fn(
|
|
736
749
|
qp_arg: quadrature.Arg,
|
|
750
|
+
qp_element_index_arg: quadrature.ElementIndexArg,
|
|
737
751
|
domain_arg: domain.ElementArg,
|
|
738
752
|
domain_index_arg: domain.ElementIndexArg,
|
|
739
753
|
fields: FieldStruct,
|
|
740
754
|
values: ValueStruct,
|
|
741
755
|
result: wp.array3d(dtype=float),
|
|
742
756
|
):
|
|
743
|
-
|
|
744
|
-
|
|
757
|
+
qp_eval_index, taylor_dof, test_dof = wp.tid()
|
|
758
|
+
domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
|
|
745
759
|
|
|
746
|
-
|
|
747
|
-
|
|
760
|
+
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
761
|
+
return
|
|
748
762
|
|
|
749
|
-
|
|
750
|
-
for qp in range(qp_point_count):
|
|
751
|
-
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
752
|
-
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
753
|
-
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
763
|
+
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
754
764
|
|
|
755
|
-
|
|
765
|
+
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
766
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
767
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
756
768
|
|
|
757
|
-
|
|
769
|
+
vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
|
|
758
770
|
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
771
|
+
trial_dof_index = NULL_DOF_INDEX
|
|
772
|
+
test_dof_index = DofIndex(taylor_dof, test_dof)
|
|
773
|
+
|
|
774
|
+
sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
775
|
+
val = integrand_func(sample, fields, values)
|
|
776
|
+
result[qp_eval_index, taylor_dof, test_dof] = qp_weight * vol * val
|
|
762
777
|
|
|
763
778
|
return integrate_kernel_fn
|
|
764
779
|
|
|
@@ -803,10 +818,10 @@ def get_integrate_bilinear_kernel(
|
|
|
803
818
|
element_trial_node_count = trial.space.topology.element_node_count(
|
|
804
819
|
domain_arg, trial_topology_arg, element_index
|
|
805
820
|
)
|
|
806
|
-
qp_point_count = wp.
|
|
821
|
+
qp_point_count = wp.where(
|
|
807
822
|
trial_node < element_trial_node_count,
|
|
808
|
-
0,
|
|
809
823
|
quadrature.point_count(domain_arg, qp_arg, test_element_index.domain_element_index, element_index),
|
|
824
|
+
0,
|
|
810
825
|
)
|
|
811
826
|
|
|
812
827
|
test_dof_index = DofIndex(
|
|
@@ -948,36 +963,38 @@ def get_integrate_bilinear_local_kernel(
|
|
|
948
963
|
|
|
949
964
|
def integrate_kernel_fn(
|
|
950
965
|
qp_arg: quadrature.Arg,
|
|
966
|
+
qp_element_index_arg: quadrature.ElementIndexArg,
|
|
951
967
|
domain_arg: domain.ElementArg,
|
|
952
968
|
domain_index_arg: domain.ElementIndexArg,
|
|
953
969
|
fields: FieldStruct,
|
|
954
970
|
values: ValueStruct,
|
|
955
971
|
result: wp.array4d(dtype=float),
|
|
956
972
|
):
|
|
957
|
-
|
|
973
|
+
qp_eval_index, test_dof, trial_dof, trial_taylor_dof = wp.tid()
|
|
974
|
+
|
|
975
|
+
domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
|
|
976
|
+
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
977
|
+
return
|
|
978
|
+
|
|
958
979
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
959
980
|
|
|
960
|
-
|
|
961
|
-
|
|
981
|
+
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
982
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
983
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
962
984
|
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
966
|
-
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
967
|
-
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
985
|
+
vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
|
|
986
|
+
qp_vol = vol * qp_weight
|
|
968
987
|
|
|
969
|
-
|
|
970
|
-
qp_vol = vol * qp_weight
|
|
988
|
+
trial_dof_index = DofIndex(trial_taylor_dof, trial_dof)
|
|
971
989
|
|
|
972
|
-
|
|
973
|
-
|
|
990
|
+
for test_taylor_dof in range(TEST_TAYLOR_DOF_COUNT):
|
|
991
|
+
taylor_dof = test_taylor_dof * TRIAL_TAYLOR_DOF_COUNT + trial_taylor_dof
|
|
974
992
|
|
|
975
|
-
|
|
976
|
-
trial_dof_index = DofIndex(qp_index, trial_dof_offset + trial_taylor_dof)
|
|
993
|
+
test_dof_index = DofIndex(test_taylor_dof, test_dof)
|
|
977
994
|
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
995
|
+
sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
996
|
+
val = integrand_func(sample, fields, values)
|
|
997
|
+
result[qp_eval_index, test_dof, trial_dof, taylor_dof] = qp_vol * val
|
|
981
998
|
|
|
982
999
|
return integrate_kernel_fn
|
|
983
1000
|
|
|
@@ -1123,6 +1140,7 @@ def _launch_integrate_kernel(
|
|
|
1123
1140
|
output_dtype: type,
|
|
1124
1141
|
output: Optional[Union[wp.array, BsrMatrix]],
|
|
1125
1142
|
add_to_output: bool,
|
|
1143
|
+
bsr_options: Optional[Dict[str, Any]],
|
|
1126
1144
|
device,
|
|
1127
1145
|
):
|
|
1128
1146
|
# Set-up launch arguments
|
|
@@ -1160,9 +1178,10 @@ def _launch_integrate_kernel(
|
|
|
1160
1178
|
|
|
1161
1179
|
wp.launch(
|
|
1162
1180
|
kernel=kernel,
|
|
1163
|
-
dim=
|
|
1181
|
+
dim=quadrature.evaluation_point_count(),
|
|
1164
1182
|
inputs=[
|
|
1165
1183
|
qp_arg,
|
|
1184
|
+
quadrature.element_index_arg_value(device),
|
|
1166
1185
|
domain_elt_arg,
|
|
1167
1186
|
domain_elt_index_arg,
|
|
1168
1187
|
field_arg_values,
|
|
@@ -1264,15 +1283,16 @@ def _launch_integrate_kernel(
|
|
|
1264
1283
|
temporary_store=temporary_store,
|
|
1265
1284
|
device=device,
|
|
1266
1285
|
requires_grad=output.requires_grad,
|
|
1267
|
-
shape=(quadrature.
|
|
1286
|
+
shape=(quadrature.evaluation_point_count(), test.TAYLOR_DOF_COUNT, test.value_dof_count),
|
|
1268
1287
|
dtype=float,
|
|
1269
1288
|
)
|
|
1270
1289
|
|
|
1271
1290
|
wp.launch(
|
|
1272
1291
|
kernel=kernel,
|
|
1273
|
-
dim=
|
|
1292
|
+
dim=local_result.array.shape,
|
|
1274
1293
|
inputs=[
|
|
1275
1294
|
qp_arg,
|
|
1295
|
+
quadrature.element_index_arg_value(device),
|
|
1276
1296
|
domain_elt_arg,
|
|
1277
1297
|
domain_elt_index_arg,
|
|
1278
1298
|
field_arg_values,
|
|
@@ -1374,7 +1394,7 @@ def _launch_integrate_kernel(
|
|
|
1374
1394
|
device=device,
|
|
1375
1395
|
requires_grad=False,
|
|
1376
1396
|
shape=(
|
|
1377
|
-
quadrature.
|
|
1397
|
+
quadrature.evaluation_point_count(),
|
|
1378
1398
|
test.value_dof_count,
|
|
1379
1399
|
trial.value_dof_count,
|
|
1380
1400
|
test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT,
|
|
@@ -1384,9 +1404,15 @@ def _launch_integrate_kernel(
|
|
|
1384
1404
|
|
|
1385
1405
|
wp.launch(
|
|
1386
1406
|
kernel=kernel,
|
|
1387
|
-
dim=(
|
|
1407
|
+
dim=(
|
|
1408
|
+
quadrature.evaluation_point_count(),
|
|
1409
|
+
test.value_dof_count,
|
|
1410
|
+
trial.value_dof_count,
|
|
1411
|
+
trial.TAYLOR_DOF_COUNT,
|
|
1412
|
+
),
|
|
1388
1413
|
inputs=[
|
|
1389
1414
|
qp_arg,
|
|
1415
|
+
quadrature.element_index_arg_value(device),
|
|
1390
1416
|
domain_elt_arg,
|
|
1391
1417
|
domain_elt_index_arg,
|
|
1392
1418
|
field_arg_values,
|
|
@@ -1481,7 +1507,7 @@ def _launch_integrate_kernel(
|
|
|
1481
1507
|
else:
|
|
1482
1508
|
bsr_result = output
|
|
1483
1509
|
|
|
1484
|
-
bsr_set_from_triplets(bsr_result, triplet_rows, triplet_cols, triplet_values)
|
|
1510
|
+
bsr_set_from_triplets(bsr_result, triplet_rows, triplet_cols, triplet_values, **(bsr_options or {}))
|
|
1485
1511
|
|
|
1486
1512
|
# Do not wait for garbage collection
|
|
1487
1513
|
triplet_values_temp.release()
|
|
@@ -1526,8 +1552,9 @@ def integrate(
|
|
|
1526
1552
|
device=None,
|
|
1527
1553
|
temporary_store: Optional[cache.TemporaryStore] = None,
|
|
1528
1554
|
kernel_options: Optional[Dict[str, Any]] = None,
|
|
1529
|
-
assembly: str = None,
|
|
1555
|
+
assembly: Optional[str] = None,
|
|
1530
1556
|
add: bool = False,
|
|
1557
|
+
bsr_options: Optional[Dict[str, Any]] = None,
|
|
1531
1558
|
):
|
|
1532
1559
|
"""
|
|
1533
1560
|
Integrates a constant, linear or bilinear form, and returns a scalar, array, or sparse matrix, respectively.
|
|
@@ -1551,6 +1578,7 @@ def integrate(
|
|
|
1551
1578
|
- "dispatch": For linear or bilinear forms, first evaluate the form at quadrature points then dispatch to nodes in a second pass. More efficient for integrands that are expensive to evaluate. Incompatible with `at_node` operator on test or trial functions.
|
|
1552
1579
|
- `None` (default): Automatically picks a suitable assembly strategy (either "generic" or "dispatch")
|
|
1553
1580
|
add: If True and `output` is provided, add the integration result to `output` instead of replacing its content
|
|
1581
|
+
bsr_options: Additional options to be passed to the sparse matrix construction algorithm. See :func:`warp.sparse.bsr_set_from_triplets()`
|
|
1554
1582
|
"""
|
|
1555
1583
|
if fields is None:
|
|
1556
1584
|
fields = {}
|
|
@@ -1663,6 +1691,7 @@ def integrate(
|
|
|
1663
1691
|
output_dtype=output_dtype,
|
|
1664
1692
|
output=output,
|
|
1665
1693
|
add_to_output=add,
|
|
1694
|
+
bsr_options=bsr_options,
|
|
1666
1695
|
device=device,
|
|
1667
1696
|
)
|
|
1668
1697
|
|
|
@@ -1808,53 +1837,128 @@ def get_interpolate_at_quadrature_kernel(
|
|
|
1808
1837
|
):
|
|
1809
1838
|
def interpolate_at_quadrature_nonvalued_kernel_fn(
|
|
1810
1839
|
qp_arg: quadrature.Arg,
|
|
1840
|
+
qp_element_index_arg: quadrature.ElementIndexArg,
|
|
1811
1841
|
domain_arg: quadrature.domain.ElementArg,
|
|
1812
1842
|
domain_index_arg: quadrature.domain.ElementIndexArg,
|
|
1813
1843
|
fields: FieldStruct,
|
|
1814
1844
|
values: ValueStruct,
|
|
1815
1845
|
result: wp.array(dtype=float),
|
|
1816
1846
|
):
|
|
1817
|
-
|
|
1847
|
+
qp_eval_index = wp.tid()
|
|
1848
|
+
domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
|
|
1849
|
+
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
1850
|
+
return
|
|
1851
|
+
|
|
1818
1852
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
1819
1853
|
|
|
1820
1854
|
test_dof_index = NULL_DOF_INDEX
|
|
1821
1855
|
trial_dof_index = NULL_DOF_INDEX
|
|
1822
1856
|
|
|
1823
|
-
|
|
1824
|
-
|
|
1825
|
-
|
|
1826
|
-
coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
1827
|
-
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
1857
|
+
coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1858
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1859
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1828
1860
|
|
|
1829
|
-
|
|
1830
|
-
|
|
1861
|
+
sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
1862
|
+
integrand_func(sample, fields, values)
|
|
1831
1863
|
|
|
1832
1864
|
def interpolate_at_quadrature_kernel_fn(
|
|
1833
1865
|
qp_arg: quadrature.Arg,
|
|
1866
|
+
qp_element_index_arg: quadrature.ElementIndexArg,
|
|
1834
1867
|
domain_arg: quadrature.domain.ElementArg,
|
|
1835
1868
|
domain_index_arg: quadrature.domain.ElementIndexArg,
|
|
1836
1869
|
fields: FieldStruct,
|
|
1837
1870
|
values: ValueStruct,
|
|
1838
1871
|
result: wp.array(dtype=value_type),
|
|
1839
1872
|
):
|
|
1840
|
-
|
|
1873
|
+
qp_eval_index = wp.tid()
|
|
1874
|
+
domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
|
|
1875
|
+
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
1876
|
+
return
|
|
1877
|
+
|
|
1841
1878
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
1842
1879
|
|
|
1843
1880
|
test_dof_index = NULL_DOF_INDEX
|
|
1844
1881
|
trial_dof_index = NULL_DOF_INDEX
|
|
1845
1882
|
|
|
1846
|
-
|
|
1847
|
-
|
|
1848
|
-
|
|
1849
|
-
coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
1850
|
-
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
1883
|
+
coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1884
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1885
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1851
1886
|
|
|
1852
|
-
|
|
1853
|
-
|
|
1887
|
+
sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
1888
|
+
result[qp_index] = integrand_func(sample, fields, values)
|
|
1854
1889
|
|
|
1855
1890
|
return interpolate_at_quadrature_nonvalued_kernel_fn if value_type is None else interpolate_at_quadrature_kernel_fn
|
|
1856
1891
|
|
|
1857
1892
|
|
|
1893
|
+
def get_interpolate_jacobian_at_quadrature_kernel(
|
|
1894
|
+
integrand_func: wp.Function,
|
|
1895
|
+
domain: GeometryDomain,
|
|
1896
|
+
quadrature: Quadrature,
|
|
1897
|
+
FieldStruct: wp.codegen.Struct,
|
|
1898
|
+
ValueStruct: wp.codegen.Struct,
|
|
1899
|
+
trial: TrialField,
|
|
1900
|
+
value_size: int,
|
|
1901
|
+
value_type: type,
|
|
1902
|
+
):
|
|
1903
|
+
MAX_NODES_PER_ELEMENT = trial.space.topology.MAX_NODES_PER_ELEMENT
|
|
1904
|
+
VALUE_SIZE = wp.constant(value_size)
|
|
1905
|
+
|
|
1906
|
+
def interpolate_jacobian_kernel_fn(
|
|
1907
|
+
qp_arg: quadrature.Arg,
|
|
1908
|
+
qp_element_index_arg: quadrature.ElementIndexArg,
|
|
1909
|
+
domain_arg: domain.ElementArg,
|
|
1910
|
+
domain_index_arg: domain.ElementIndexArg,
|
|
1911
|
+
trial_partition_arg: trial.space_partition.PartitionArg,
|
|
1912
|
+
trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
|
|
1913
|
+
fields: FieldStruct,
|
|
1914
|
+
values: ValueStruct,
|
|
1915
|
+
triplet_rows: wp.array(dtype=int),
|
|
1916
|
+
triplet_cols: wp.array(dtype=int),
|
|
1917
|
+
triplet_values: wp.array3d(dtype=value_type),
|
|
1918
|
+
):
|
|
1919
|
+
qp_eval_index, trial_node, trial_dof = wp.tid()
|
|
1920
|
+
domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
|
|
1921
|
+
|
|
1922
|
+
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
1923
|
+
return
|
|
1924
|
+
|
|
1925
|
+
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
1926
|
+
if qp >= quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index):
|
|
1927
|
+
return
|
|
1928
|
+
|
|
1929
|
+
element_trial_node_count = trial.space.topology.element_node_count(
|
|
1930
|
+
domain_arg, trial_topology_arg, element_index
|
|
1931
|
+
)
|
|
1932
|
+
|
|
1933
|
+
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1934
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1935
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1936
|
+
|
|
1937
|
+
block_offset = qp_index * MAX_NODES_PER_ELEMENT + trial_node
|
|
1938
|
+
|
|
1939
|
+
test_dof_index = NULL_DOF_INDEX
|
|
1940
|
+
trial_dof_index = DofIndex(trial_node, trial_dof)
|
|
1941
|
+
|
|
1942
|
+
sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
1943
|
+
val = integrand_func(sample, fields, values)
|
|
1944
|
+
|
|
1945
|
+
for k in range(VALUE_SIZE):
|
|
1946
|
+
triplet_values[block_offset, k, trial_dof] = basis_coefficient(val, k)
|
|
1947
|
+
|
|
1948
|
+
if trial_dof == 0:
|
|
1949
|
+
if trial_node < element_trial_node_count:
|
|
1950
|
+
trial_node_index = trial.space_partition.partition_node_index(
|
|
1951
|
+
trial_partition_arg,
|
|
1952
|
+
trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
|
|
1953
|
+
)
|
|
1954
|
+
else:
|
|
1955
|
+
trial_node_index = NULL_NODE_INDEX # will get ignored when converting to bsr
|
|
1956
|
+
triplet_rows[block_offset] = qp_index
|
|
1957
|
+
triplet_cols[block_offset] = trial_node_index
|
|
1958
|
+
|
|
1959
|
+
return interpolate_jacobian_kernel_fn
|
|
1960
|
+
|
|
1961
|
+
|
|
1858
1962
|
def get_interpolate_free_kernel(
|
|
1859
1963
|
integrand_func: wp.Function,
|
|
1860
1964
|
domain: GeometryDomain,
|
|
@@ -1924,9 +2028,9 @@ def _generate_interpolate_kernel(
|
|
|
1924
2028
|
dest_dtype = dest.dtype if dest else None
|
|
1925
2029
|
type_str = wp.types.get_type_code(dest_dtype) if dest_dtype else ""
|
|
1926
2030
|
if quadrature is None:
|
|
1927
|
-
kernel_suffix = f"_itp_{field_names}_{type_str}"
|
|
2031
|
+
kernel_suffix = f"_itp_{field_names}_{domain.name}_{type_str}"
|
|
1928
2032
|
else:
|
|
1929
|
-
kernel_suffix = f"_itp_{field_names}_{quadrature.name}_{type_str}"
|
|
2033
|
+
kernel_suffix = f"_itp_{field_names}_{domain.name}_{quadrature.name}_{type_str}"
|
|
1930
2034
|
|
|
1931
2035
|
kernel = cache.get_integrand_kernel(
|
|
1932
2036
|
integrand=integrand,
|
|
@@ -1971,14 +2075,27 @@ def _generate_interpolate_kernel(
|
|
|
1971
2075
|
ValueStruct=ValueStruct,
|
|
1972
2076
|
)
|
|
1973
2077
|
elif quadrature is not None:
|
|
1974
|
-
|
|
1975
|
-
|
|
1976
|
-
|
|
1977
|
-
|
|
1978
|
-
|
|
1979
|
-
|
|
1980
|
-
|
|
1981
|
-
|
|
2078
|
+
if arguments.trial_name:
|
|
2079
|
+
trial = arguments.field_args[arguments.trial_name]
|
|
2080
|
+
interpolate_kernel_fn = get_interpolate_jacobian_at_quadrature_kernel(
|
|
2081
|
+
integrand_func,
|
|
2082
|
+
domain=domain,
|
|
2083
|
+
quadrature=quadrature,
|
|
2084
|
+
FieldStruct=FieldStruct,
|
|
2085
|
+
ValueStruct=ValueStruct,
|
|
2086
|
+
trial=trial,
|
|
2087
|
+
value_size=dest.block_shape[0],
|
|
2088
|
+
value_type=dest.scalar_type,
|
|
2089
|
+
)
|
|
2090
|
+
else:
|
|
2091
|
+
interpolate_kernel_fn = get_interpolate_at_quadrature_kernel(
|
|
2092
|
+
integrand_func,
|
|
2093
|
+
domain=domain,
|
|
2094
|
+
quadrature=quadrature,
|
|
2095
|
+
value_type=dest_dtype,
|
|
2096
|
+
FieldStruct=FieldStruct,
|
|
2097
|
+
ValueStruct=ValueStruct,
|
|
2098
|
+
)
|
|
1982
2099
|
else:
|
|
1983
2100
|
interpolate_kernel_fn = get_interpolate_free_kernel(
|
|
1984
2101
|
integrand_func,
|
|
@@ -2012,8 +2129,11 @@ def _launch_interpolate_kernel(
|
|
|
2012
2129
|
dest: Optional[Union[FieldRestriction, wp.array]],
|
|
2013
2130
|
quadrature: Optional[Quadrature],
|
|
2014
2131
|
dim: int,
|
|
2132
|
+
trial: Optional[TrialField],
|
|
2015
2133
|
fields: Dict[str, FieldLike],
|
|
2016
2134
|
values: Dict[str, Any],
|
|
2135
|
+
temporary_store: Optional[cache.TemporaryStore],
|
|
2136
|
+
bsr_options: Optional[Dict[str, Any]],
|
|
2017
2137
|
device,
|
|
2018
2138
|
) -> wp.Kernel:
|
|
2019
2139
|
# Set-up launch arguments
|
|
@@ -2044,21 +2164,74 @@ def _launch_interpolate_kernel(
|
|
|
2044
2164
|
],
|
|
2045
2165
|
device=device,
|
|
2046
2166
|
)
|
|
2047
|
-
|
|
2048
|
-
|
|
2167
|
+
return
|
|
2168
|
+
|
|
2169
|
+
if quadrature is None:
|
|
2049
2170
|
wp.launch(
|
|
2050
2171
|
kernel=kernel,
|
|
2051
|
-
dim=
|
|
2052
|
-
inputs=[
|
|
2172
|
+
dim=dim,
|
|
2173
|
+
inputs=[dim, elt_arg, field_arg_values, value_struct_values, dest],
|
|
2053
2174
|
device=device,
|
|
2054
2175
|
)
|
|
2055
|
-
|
|
2176
|
+
return
|
|
2177
|
+
|
|
2178
|
+
qp_arg = quadrature.arg_value(device)
|
|
2179
|
+
qp_element_index_arg = quadrature.element_index_arg_value(device)
|
|
2180
|
+
if trial is None:
|
|
2056
2181
|
wp.launch(
|
|
2057
2182
|
kernel=kernel,
|
|
2058
|
-
dim=
|
|
2059
|
-
inputs=[
|
|
2183
|
+
dim=quadrature.evaluation_point_count(),
|
|
2184
|
+
inputs=[qp_arg, qp_element_index_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
|
|
2060
2185
|
device=device,
|
|
2061
2186
|
)
|
|
2187
|
+
return
|
|
2188
|
+
|
|
2189
|
+
nnz = quadrature.total_point_count() * trial.space.topology.MAX_NODES_PER_ELEMENT
|
|
2190
|
+
|
|
2191
|
+
if dest.nrow != quadrature.total_point_count() or dest.ncol != trial.space_partition.node_count():
|
|
2192
|
+
raise RuntimeError(
|
|
2193
|
+
f"'dest' matrix must have {quadrature.total_point_count()} rows and {trial.space_partition.node_count()} columns of blocks"
|
|
2194
|
+
)
|
|
2195
|
+
if dest.block_shape[1] != trial.node_dof_count:
|
|
2196
|
+
raise f"'dest' matrix blocks must have {trial.node_dof_count} columns"
|
|
2197
|
+
|
|
2198
|
+
triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
|
|
2199
|
+
triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
|
|
2200
|
+
triplet_values_temp = cache.borrow_temporary(
|
|
2201
|
+
temporary_store,
|
|
2202
|
+
dtype=dest.scalar_type,
|
|
2203
|
+
shape=(nnz, *dest.block_shape),
|
|
2204
|
+
device=device,
|
|
2205
|
+
)
|
|
2206
|
+
triplet_cols = triplet_cols_temp.array
|
|
2207
|
+
triplet_rows = triplet_rows_temp.array
|
|
2208
|
+
triplet_values = triplet_values_temp.array
|
|
2209
|
+
triplet_rows.fill_(-1)
|
|
2210
|
+
triplet_values.zero_()
|
|
2211
|
+
|
|
2212
|
+
trial_partition_arg = trial.space_partition.partition_arg_value(device)
|
|
2213
|
+
trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
|
|
2214
|
+
|
|
2215
|
+
wp.launch(
|
|
2216
|
+
kernel=kernel,
|
|
2217
|
+
dim=(quadrature.evaluation_point_count(), trial.space.topology.MAX_NODES_PER_ELEMENT, trial.node_dof_count),
|
|
2218
|
+
inputs=[
|
|
2219
|
+
qp_arg,
|
|
2220
|
+
qp_element_index_arg,
|
|
2221
|
+
elt_arg,
|
|
2222
|
+
elt_index_arg,
|
|
2223
|
+
trial_partition_arg,
|
|
2224
|
+
trial_topology_arg,
|
|
2225
|
+
field_arg_values,
|
|
2226
|
+
value_struct_values,
|
|
2227
|
+
triplet_rows,
|
|
2228
|
+
triplet_cols,
|
|
2229
|
+
triplet_values,
|
|
2230
|
+
],
|
|
2231
|
+
device=device,
|
|
2232
|
+
)
|
|
2233
|
+
|
|
2234
|
+
bsr_set_from_triplets(dest, triplet_rows, triplet_cols, triplet_values, **(bsr_options or {}))
|
|
2062
2235
|
|
|
2063
2236
|
|
|
2064
2237
|
@integrand
|
|
@@ -2076,6 +2249,8 @@ def interpolate(
|
|
|
2076
2249
|
values: Optional[Dict[str, Any]] = None,
|
|
2077
2250
|
device=None,
|
|
2078
2251
|
kernel_options: Optional[Dict[str, Any]] = None,
|
|
2252
|
+
temporary_store: Optional[cache.TemporaryStore] = None,
|
|
2253
|
+
bsr_options: Optional[Dict[str, Any]] = None,
|
|
2079
2254
|
):
|
|
2080
2255
|
"""
|
|
2081
2256
|
Interpolates a function at a finite set of sample points and optionally assigns the result to a discrete field or a raw warp array.
|
|
@@ -2094,6 +2269,8 @@ def interpolate(
|
|
|
2094
2269
|
values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
|
|
2095
2270
|
device: Device on which to perform the interpolation
|
|
2096
2271
|
kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)
|
|
2272
|
+
temporary_store: shared pool from which to allocate temporary arrays
|
|
2273
|
+
bsr_options: Additional options to be passed to the sparse matrix construction algorithm. See :func:`warp.sparse.bsr_set_from_triplets()`
|
|
2097
2274
|
"""
|
|
2098
2275
|
|
|
2099
2276
|
if isinstance(integrand, FieldLike):
|
|
@@ -2111,8 +2288,12 @@ def interpolate(
|
|
|
2111
2288
|
raise ValueError("integrand must be tagged with @integrand decorator")
|
|
2112
2289
|
|
|
2113
2290
|
arguments = _parse_integrand_arguments(integrand, fields)
|
|
2114
|
-
if arguments.test_name
|
|
2115
|
-
raise ValueError("Test
|
|
2291
|
+
if arguments.test_name:
|
|
2292
|
+
raise ValueError(f"Test field '{arguments.test_name}' maybe not be used for interpolation")
|
|
2293
|
+
if arguments.trial_name and (quadrature is None or not isinstance(dest, BsrMatrix)):
|
|
2294
|
+
raise ValueError(
|
|
2295
|
+
f"Interpolation using trial field '{arguments.trial_name}' requires 'quadrature' to be provided and 'dest' to be a `warp.sparse.BsrMatrix`"
|
|
2296
|
+
)
|
|
2116
2297
|
|
|
2117
2298
|
if isinstance(dest, DiscreteField):
|
|
2118
2299
|
dest = make_restriction(dest, domain=domain)
|
|
@@ -2145,7 +2326,10 @@ def interpolate(
|
|
|
2145
2326
|
dest=dest,
|
|
2146
2327
|
quadrature=quadrature,
|
|
2147
2328
|
dim=dim,
|
|
2329
|
+
trial=fields.get(arguments.trial_name),
|
|
2148
2330
|
fields=arguments.field_args,
|
|
2149
2331
|
values=values,
|
|
2332
|
+
temporary_store=temporary_store,
|
|
2333
|
+
bsr_options=bsr_options,
|
|
2150
2334
|
device=device,
|
|
2151
2335
|
)
|