warp-lang 1.6.1__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +21 -7
- warp/autograd.py +14 -6
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +424 -6
- warp/build_dll.py +20 -20
- warp/builtins.py +467 -368
- warp/codegen.py +193 -125
- warp/config.py +56 -12
- warp/constants.py +14 -6
- warp/context.py +524 -277
- warp/dlpack.py +22 -12
- warp/examples/__init__.py +14 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_api.py +14 -6
- warp/examples/benchmarks/benchmark_cloth.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
- warp/examples/benchmarks/benchmark_gemm.py +82 -48
- warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
- warp/examples/benchmarks/benchmark_launches.py +14 -6
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +14 -6
- warp/examples/core/example_cupy.py +14 -6
- warp/examples/core/example_dem.py +14 -6
- warp/examples/core/example_fluid.py +14 -6
- warp/examples/core/example_graph_capture.py +14 -6
- warp/examples/core/example_marching_cubes.py +14 -6
- warp/examples/core/example_mesh.py +14 -6
- warp/examples/core/example_mesh_intersect.py +14 -6
- warp/examples/core/example_nvdb.py +14 -6
- warp/examples/core/example_raycast.py +14 -6
- warp/examples/core/example_raymarch.py +14 -6
- warp/examples/core/example_render_opengl.py +14 -6
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +14 -6
- warp/examples/core/example_torch.py +14 -6
- warp/examples/core/example_wave.py +14 -6
- warp/examples/fem/example_adaptive_grid.py +14 -6
- warp/examples/fem/example_apic_fluid.py +15 -7
- warp/examples/fem/example_burgers.py +16 -8
- warp/examples/fem/example_convection_diffusion.py +14 -6
- warp/examples/fem/example_convection_diffusion_dg.py +14 -6
- warp/examples/fem/example_deformed_geometry.py +15 -7
- warp/examples/fem/example_diffusion.py +14 -6
- warp/examples/fem/example_diffusion_3d.py +14 -6
- warp/examples/fem/example_diffusion_mgpu.py +14 -6
- warp/examples/fem/example_distortion_energy.py +15 -7
- warp/examples/fem/example_magnetostatics.py +20 -12
- warp/examples/fem/example_mixed_elasticity.py +14 -6
- warp/examples/fem/example_navier_stokes.py +14 -6
- warp/examples/fem/example_nonconforming_contact.py +14 -6
- warp/examples/fem/example_stokes.py +14 -6
- warp/examples/fem/example_stokes_transfer.py +14 -6
- warp/examples/fem/example_streamlines.py +14 -6
- warp/examples/fem/utils.py +24 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +14 -6
- warp/examples/optim/example_cloth_throw.py +14 -6
- warp/examples/optim/example_diffray.py +14 -6
- warp/examples/optim/example_drone.py +14 -6
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +14 -6
- warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
- warp/examples/optim/example_softbody_properties.py +14 -6
- warp/examples/optim/example_spring_cage.py +14 -6
- warp/examples/optim/example_trajectory.py +14 -6
- warp/examples/sim/example_cartpole.py +14 -6
- warp/examples/sim/example_cloth.py +14 -6
- warp/examples/sim/example_cloth_self_contact.py +14 -6
- warp/examples/sim/example_granular.py +14 -6
- warp/examples/sim/example_granular_collision_sdf.py +14 -6
- warp/examples/sim/example_jacobian_ik.py +14 -6
- warp/examples/sim/example_particle_chain.py +14 -6
- warp/examples/sim/example_quadruped.py +14 -6
- warp/examples/sim/example_rigid_chain.py +14 -6
- warp/examples/sim/example_rigid_contact.py +14 -6
- warp/examples/sim/example_rigid_force.py +14 -6
- warp/examples/sim/example_rigid_gyroscopic.py +14 -6
- warp/examples/sim/example_rigid_soft_contact.py +14 -6
- warp/examples/sim/example_soft_body.py +14 -6
- warp/examples/tile/example_tile_cholesky.py +14 -6
- warp/examples/tile/example_tile_convolution.py +14 -6
- warp/examples/tile/example_tile_fft.py +14 -6
- warp/examples/tile/example_tile_filtering.py +14 -6
- warp/examples/tile/example_tile_matmul.py +16 -10
- warp/examples/tile/example_tile_mlp.py +14 -6
- warp/examples/tile/example_tile_nbody.py +14 -6
- warp/examples/tile/example_tile_walker.py +14 -6
- warp/fabric.py +15 -0
- warp/fem/__init__.py +26 -1
- warp/fem/adaptivity.py +19 -4
- warp/fem/cache.py +15 -0
- warp/fem/dirichlet.py +15 -0
- warp/fem/domain.py +15 -0
- warp/fem/field/__init__.py +15 -0
- warp/fem/field/field.py +15 -0
- warp/fem/field/nodal_field.py +37 -68
- warp/fem/field/restriction.py +15 -0
- warp/fem/field/virtual.py +77 -23
- warp/fem/geometry/__init__.py +15 -0
- warp/fem/geometry/adaptive_nanogrid.py +24 -10
- warp/fem/geometry/closest_point.py +16 -1
- warp/fem/geometry/deformed_geometry.py +20 -2
- warp/fem/geometry/element.py +15 -0
- warp/fem/geometry/geometry.py +20 -0
- warp/fem/geometry/grid_2d.py +27 -12
- warp/fem/geometry/grid_3d.py +27 -15
- warp/fem/geometry/hexmesh.py +20 -7
- warp/fem/geometry/nanogrid.py +24 -11
- warp/fem/geometry/partition.py +15 -0
- warp/fem/geometry/quadmesh.py +28 -13
- warp/fem/geometry/tetmesh.py +18 -4
- warp/fem/geometry/trimesh.py +18 -8
- warp/fem/integrate.py +277 -93
- warp/fem/linalg.py +20 -5
- warp/fem/operator.py +15 -0
- warp/fem/polynomial.py +15 -0
- warp/fem/quadrature/__init__.py +15 -0
- warp/fem/quadrature/pic_quadrature.py +52 -22
- warp/fem/quadrature/quadrature.py +209 -25
- warp/fem/space/__init__.py +16 -1
- warp/fem/space/basis_function_space.py +19 -2
- warp/fem/space/basis_space.py +40 -18
- warp/fem/space/dof_mapper.py +15 -0
- warp/fem/space/function_space.py +15 -0
- warp/fem/space/grid_2d_function_space.py +15 -0
- warp/fem/space/grid_3d_function_space.py +15 -0
- warp/fem/space/hexmesh_function_space.py +17 -2
- warp/fem/space/nanogrid_function_space.py +15 -0
- warp/fem/space/partition.py +21 -2
- warp/fem/space/quadmesh_function_space.py +23 -8
- warp/fem/space/restriction.py +15 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +38 -23
- warp/fem/space/shape/shape_function.py +15 -0
- warp/fem/space/shape/square_shape_function.py +27 -12
- warp/fem/space/shape/tet_shape_function.py +15 -0
- warp/fem/space/shape/triangle_shape_function.py +16 -1
- warp/fem/space/tetmesh_function_space.py +18 -3
- warp/fem/space/topology.py +15 -0
- warp/fem/space/trimesh_function_space.py +17 -2
- warp/fem/types.py +15 -0
- warp/fem/utils.py +27 -6
- warp/jax.py +28 -7
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +103 -6
- warp/native/array.h +28 -6
- warp/native/builtin.h +44 -9
- warp/native/bvh.cpp +18 -7
- warp/native/bvh.cu +57 -20
- warp/native/bvh.h +17 -7
- warp/native/clang/clang.cpp +45 -9
- warp/native/coloring.cpp +15 -6
- warp/native/crt.cpp +15 -6
- warp/native/crt.h +15 -6
- warp/native/cuda_crt.h +15 -6
- warp/native/cuda_util.cpp +29 -6
- warp/native/cuda_util.h +17 -6
- warp/native/error.cpp +15 -6
- warp/native/error.h +15 -6
- warp/native/exports.h +85 -63
- warp/native/fabric.h +15 -6
- warp/native/hashgrid.cpp +15 -6
- warp/native/hashgrid.cu +15 -6
- warp/native/hashgrid.h +15 -6
- warp/native/initializer_array.h +15 -6
- warp/native/intersect.h +41 -32
- warp/native/intersect_adj.h +48 -39
- warp/native/intersect_tri.h +17 -0
- warp/native/marching.cpp +16 -0
- warp/native/marching.cu +16 -7
- warp/native/marching.h +17 -0
- warp/native/mat.h +528 -15
- warp/native/mathdx.cpp +15 -6
- warp/native/matnn.h +15 -6
- warp/native/mesh.cpp +15 -6
- warp/native/mesh.cu +15 -6
- warp/native/mesh.h +25 -16
- warp/native/noise.h +15 -6
- warp/native/quat.h +114 -17
- warp/native/rand.h +21 -6
- warp/native/range.h +15 -6
- warp/native/reduce.cpp +15 -6
- warp/native/reduce.cu +15 -6
- warp/native/runlength_encode.cpp +15 -6
- warp/native/runlength_encode.cu +15 -6
- warp/native/scan.cpp +15 -6
- warp/native/scan.cu +15 -6
- warp/native/scan.h +15 -6
- warp/native/solid_angle.h +17 -0
- warp/native/sort.cpp +137 -65
- warp/native/sort.cu +167 -21
- warp/native/sort.h +23 -7
- warp/native/sparse.cpp +58 -28
- warp/native/sparse.cu +67 -23
- warp/native/spatial.h +15 -6
- warp/native/svd.h +131 -6
- warp/native/temp_buffer.h +15 -6
- warp/native/tile.h +316 -111
- warp/native/tile_reduce.h +61 -9
- warp/native/vec.h +83 -13
- warp/native/volume.cpp +100 -119
- warp/native/volume.cu +15 -6
- warp/native/volume.h +15 -6
- warp/native/volume_builder.cu +40 -16
- warp/native/volume_builder.h +21 -6
- warp/native/volume_impl.h +15 -6
- warp/native/warp.cpp +20 -12
- warp/native/warp.cu +114 -16
- warp/native/warp.h +34 -16
- warp/optim/__init__.py +14 -6
- warp/optim/adam.py +14 -6
- warp/optim/linear.py +25 -10
- warp/optim/sgd.py +14 -6
- warp/paddle.py +14 -6
- warp/render/__init__.py +14 -6
- warp/render/render_opengl.py +14 -6
- warp/render/render_usd.py +14 -6
- warp/render/utils.py +14 -6
- warp/sim/__init__.py +14 -7
- warp/sim/articulation.py +18 -10
- warp/sim/collide.py +35 -16
- warp/sim/graph_coloring.py +14 -6
- warp/sim/import_mjcf.py +463 -162
- warp/sim/import_snu.py +14 -7
- warp/sim/import_urdf.py +46 -18
- warp/sim/import_usd.py +14 -7
- warp/sim/inertia.py +14 -6
- warp/sim/integrator.py +14 -6
- warp/sim/integrator_euler.py +19 -11
- warp/sim/integrator_featherstone.py +17 -16
- warp/sim/integrator_vbd.py +222 -8
- warp/sim/integrator_xpbd.py +19 -11
- warp/sim/model.py +56 -19
- warp/sim/particles.py +14 -6
- warp/sim/render.py +14 -6
- warp/sim/utils.py +17 -2
- warp/sparse.py +657 -555
- warp/stubs.py +231 -19
- warp/tape.py +14 -6
- warp/tests/aux_test_class_kernel.py +14 -6
- warp/tests/aux_test_compile_consts_dummy.py +14 -6
- warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
- warp/tests/aux_test_dependent.py +14 -6
- warp/tests/aux_test_grad_customs.py +14 -6
- warp/tests/aux_test_instancing_gc.py +14 -6
- warp/tests/aux_test_module_unload.py +14 -6
- warp/tests/aux_test_name_clash1.py +14 -6
- warp/tests/aux_test_name_clash2.py +14 -6
- warp/tests/aux_test_unresolved_func.py +14 -6
- warp/tests/aux_test_unresolved_symbol.py +14 -6
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_async.py → cuda/test_async.py} +14 -6
- warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
- warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
- warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
- warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
- warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
- warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
- warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
- warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
- warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
- warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
- warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
- warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
- warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
- warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
- warp/tests/run_coverage_serial.py +14 -6
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
- warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
- warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
- warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
- warp/tests/{test_model.py → sim/test_model.py} +55 -7
- warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +14 -6
- warp/tests/test_arithmetic.py +14 -6
- warp/tests/test_array.py +14 -6
- warp/tests/test_array_reduce.py +14 -6
- warp/tests/test_assert.py +14 -6
- warp/tests/test_atomic.py +14 -6
- warp/tests/test_bool.py +15 -7
- warp/tests/test_builtins_resolution.py +14 -6
- warp/tests/test_closest_point_edge_edge.py +14 -6
- warp/tests/test_codegen.py +14 -6
- warp/tests/test_codegen_instancing.py +14 -6
- warp/tests/test_compile_consts.py +14 -6
- warp/tests/test_conditional.py +14 -6
- warp/tests/test_context.py +14 -6
- warp/tests/test_copy.py +14 -6
- warp/tests/test_ctypes.py +14 -6
- warp/tests/test_dense.py +14 -6
- warp/tests/test_devices.py +14 -6
- warp/tests/test_examples.py +42 -42
- warp/tests/test_fabricarray.py +14 -6
- warp/tests/test_fast_math.py +14 -6
- warp/tests/test_fem.py +37 -10
- warp/tests/test_fp16.py +14 -6
- warp/tests/test_func.py +14 -6
- warp/tests/test_future_annotations.py +14 -6
- warp/tests/test_generics.py +14 -6
- warp/tests/test_grad.py +14 -6
- warp/tests/test_grad_customs.py +14 -6
- warp/tests/test_grad_debug.py +14 -6
- warp/tests/test_implicit_init.py +14 -6
- warp/tests/test_import.py +14 -6
- warp/tests/test_indexedarray.py +14 -6
- warp/tests/test_intersect.py +14 -6
- warp/tests/test_iter.py +14 -6
- warp/tests/test_large.py +14 -6
- warp/tests/test_launch.py +14 -6
- warp/tests/test_lerp.py +14 -6
- warp/tests/test_linear_solvers.py +15 -11
- warp/tests/test_lvalue.py +14 -6
- warp/tests/test_mat.py +247 -85
- warp/tests/test_mat_lite.py +14 -6
- warp/tests/test_mat_scalar_ops.py +18 -10
- warp/tests/test_math.py +14 -6
- warp/tests/test_mlp.py +14 -6
- warp/tests/test_module_hashing.py +14 -6
- warp/tests/test_modules_lite.py +14 -6
- warp/tests/test_noise.py +14 -6
- warp/tests/test_operators.py +14 -6
- warp/tests/test_options.py +14 -6
- warp/tests/test_overwrite.py +15 -60
- warp/tests/test_print.py +14 -6
- warp/tests/test_quat.py +81 -52
- warp/tests/test_rand.py +58 -43
- warp/tests/test_reload.py +14 -6
- warp/tests/test_rounding.py +14 -6
- warp/tests/test_runlength_encode.py +14 -6
- warp/tests/test_scalar_ops.py +14 -6
- warp/tests/test_smoothstep.py +14 -6
- warp/tests/test_snippet.py +15 -0
- warp/tests/test_sparse.py +61 -12
- warp/tests/test_spatial.py +89 -6
- warp/tests/test_special_values.py +14 -6
- warp/tests/test_static.py +15 -7
- warp/tests/test_struct.py +14 -6
- warp/tests/test_tape.py +14 -6
- warp/tests/test_transient_module.py +14 -6
- warp/tests/test_triangle_closest_point.py +14 -6
- warp/tests/test_types.py +14 -6
- warp/tests/test_utils.py +98 -10
- warp/tests/test_vec.py +60 -40
- warp/tests/test_vec_lite.py +14 -6
- warp/tests/test_vec_scalar_ops.py +14 -6
- warp/tests/test_verify_fp.py +14 -6
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
- warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
- warp/tests/unittest_serial.py +15 -6
- warp/tests/unittest_suites.py +59 -65
- warp/tests/unittest_utils.py +16 -7
- warp/tests/walkthrough_debug.py +14 -6
- warp/thirdparty/unittest_parallel.py +15 -8
- warp/torch.py +14 -6
- warp/types.py +124 -664
- warp/utils.py +151 -78
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp/examples/optim/example_walker.py +0 -309
- warp/native/cutlass_gemm.cpp +0 -34
- warp/native/cutlass_gemm.cu +0 -373
- warp/tests/test_matmul.py +0 -503
- warp/tests/test_matmul_lite.py +0 -403
- warp/tests/test_vbd.py +0 -378
- warp/tests/unused_test_misc.py +0 -69
- warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
- warp_lang-1.6.1.dist-info/RECORD +0 -419
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/fem/linalg.py
CHANGED
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Any
|
|
2
17
|
|
|
3
18
|
import warp as wp
|
|
@@ -157,11 +172,11 @@ def householder_qr_decomposition(A: Any):
|
|
|
157
172
|
|
|
158
173
|
for i in range(type(x).length):
|
|
159
174
|
for k in range(type(x).length):
|
|
160
|
-
x[k] = wp.
|
|
175
|
+
x[k] = wp.where(k < i, zero, A[k, i])
|
|
161
176
|
|
|
162
177
|
alpha = wp.length(x) * wp.sign(x[i])
|
|
163
178
|
x[i] += alpha
|
|
164
|
-
two_over_x_sq = wp.
|
|
179
|
+
two_over_x_sq = wp.where(alpha == zero, zero, two / wp.length_sq(x))
|
|
165
180
|
|
|
166
181
|
A -= wp.outer(two_over_x_sq * x, x * A)
|
|
167
182
|
Q -= wp.outer(Q * x, two_over_x_sq * x)
|
|
@@ -186,11 +201,11 @@ def householder_make_hessenberg(A: Any):
|
|
|
186
201
|
|
|
187
202
|
for i in range(1, type(x).length):
|
|
188
203
|
for k in range(type(x).length):
|
|
189
|
-
x[k] = wp.
|
|
204
|
+
x[k] = wp.where(k < i, zero, A[k, i - 1])
|
|
190
205
|
|
|
191
206
|
alpha = wp.length(x) * wp.sign(x[i])
|
|
192
207
|
x[i] += alpha
|
|
193
|
-
two_over_x_sq = wp.
|
|
208
|
+
two_over_x_sq = wp.where(alpha == zero, zero, two / wp.length_sq(x))
|
|
194
209
|
|
|
195
210
|
# apply on both sides
|
|
196
211
|
A -= wp.outer(two_over_x_sq * x, x * A)
|
|
@@ -211,7 +226,7 @@ def solve_triangular(R: Any, b: Any):
|
|
|
211
226
|
for i in range(b.length, 0, -1):
|
|
212
227
|
j = i - 1
|
|
213
228
|
r = b[j] - wp.dot(R[j], x)
|
|
214
|
-
x[j] = wp.
|
|
229
|
+
x[j] = wp.where(R[j, j] == zero, zero, r / R[j, j])
|
|
215
230
|
|
|
216
231
|
return x
|
|
217
232
|
|
warp/fem/operator.py
CHANGED
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Any, Callable, Dict, Optional, Set
|
|
2
17
|
|
|
3
18
|
import warp as wp
|
warp/fem/polynomial.py
CHANGED
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import math
|
|
2
17
|
from enum import Enum
|
|
3
18
|
|
warp/fem/quadrature/__init__.py
CHANGED
|
@@ -1,2 +1,17 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from .pic_quadrature import PicQuadrature
|
|
2
17
|
from .quadrature import ExplicitQuadrature, NodalQuadrature, Quadrature, RegularQuadrature
|
|
@@ -1,9 +1,24 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Any, Optional, Tuple, Union
|
|
2
17
|
|
|
3
18
|
import warp as wp
|
|
4
19
|
from warp.fem.cache import TemporaryStore, borrow_temporary, cached_arg_value, dynamic_kernel
|
|
5
20
|
from warp.fem.domain import GeometryDomain
|
|
6
|
-
from warp.fem.types import Coords, ElementIndex, make_free_sample
|
|
21
|
+
from warp.fem.types import NULL_ELEMENT_INDEX, Coords, ElementIndex, make_free_sample
|
|
7
22
|
from warp.fem.utils import compress_node_indices
|
|
8
23
|
|
|
9
24
|
from .quadrature import Quadrature
|
|
@@ -53,10 +68,10 @@ class PicQuadrature(Quadrature):
|
|
|
53
68
|
def domain(self, domain: GeometryDomain):
|
|
54
69
|
# Allow changing the quadrature domain as long as underlying geometry and element kind are the same
|
|
55
70
|
if self.domain is not None and (
|
|
56
|
-
domain.
|
|
71
|
+
domain.element_kind != self.domain.element_kind or domain.geometry.base != self.domain.geometry.base
|
|
57
72
|
):
|
|
58
73
|
raise RuntimeError(
|
|
59
|
-
"
|
|
74
|
+
"The new domain must use the same base geometry and kind of elements as the current one."
|
|
60
75
|
)
|
|
61
76
|
|
|
62
77
|
self._domain = domain
|
|
@@ -74,11 +89,11 @@ class PicQuadrature(Quadrature):
|
|
|
74
89
|
arg.cell_particle_offsets = self._cell_particle_offsets.array.to(device)
|
|
75
90
|
arg.cell_particle_indices = self._cell_particle_indices.array.to(device)
|
|
76
91
|
arg.particle_fraction = self._particle_fraction.to(device)
|
|
77
|
-
arg.particle_coords = self.
|
|
92
|
+
arg.particle_coords = self.particle_coords.to(device)
|
|
78
93
|
return arg
|
|
79
94
|
|
|
80
95
|
def total_point_count(self):
|
|
81
|
-
return self.
|
|
96
|
+
return self.particle_coords.shape[0]
|
|
82
97
|
|
|
83
98
|
def active_cell_count(self):
|
|
84
99
|
"""Number of cells containing at least one particle"""
|
|
@@ -121,6 +136,12 @@ class PicQuadrature(Quadrature):
|
|
|
121
136
|
particle_index = qp_arg.cell_particle_indices[qp_arg.cell_particle_offsets[element_index] + index]
|
|
122
137
|
return particle_index
|
|
123
138
|
|
|
139
|
+
@wp.func
|
|
140
|
+
def point_evaluation_index(
|
|
141
|
+
elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, index: int
|
|
142
|
+
):
|
|
143
|
+
return qp_arg.cell_particle_offsets[element_index] + index
|
|
144
|
+
|
|
124
145
|
def fill_element_mask(self, mask: "wp.array(dtype=int)"):
|
|
125
146
|
"""Fills a mask array such that all non-empty elements are set to 1, all empty elements to zero.
|
|
126
147
|
|
|
@@ -141,7 +162,7 @@ class PicQuadrature(Quadrature):
|
|
|
141
162
|
element_mask: wp.array(dtype=int),
|
|
142
163
|
):
|
|
143
164
|
i = wp.tid()
|
|
144
|
-
element_mask[i] = wp.
|
|
165
|
+
element_mask[i] = wp.where(element_particle_offsets[i] == element_particle_offsets[i + 1], 0, 1)
|
|
145
166
|
|
|
146
167
|
@wp.kernel
|
|
147
168
|
def _compute_uniform_fraction(
|
|
@@ -152,9 +173,11 @@ class PicQuadrature(Quadrature):
|
|
|
152
173
|
p = wp.tid()
|
|
153
174
|
|
|
154
175
|
cell = cell_index[p]
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
176
|
+
if cell == NULL_ELEMENT_INDEX:
|
|
177
|
+
cell_fraction[p] = 0.0
|
|
178
|
+
else:
|
|
179
|
+
cell_particle_count = cell_particle_offsets[cell + 1] - cell_particle_offsets[cell]
|
|
180
|
+
cell_fraction[p] = 1.0 / float(cell_particle_count)
|
|
158
181
|
|
|
159
182
|
def _bin_particles(self, positions, measures, temporary_store: TemporaryStore):
|
|
160
183
|
if wp.types.is_array(positions):
|
|
@@ -174,13 +197,13 @@ class PicQuadrature(Quadrature):
|
|
|
174
197
|
|
|
175
198
|
device = positions.device
|
|
176
199
|
|
|
177
|
-
|
|
178
|
-
|
|
200
|
+
self._cell_index_temp = borrow_temporary(temporary_store, shape=positions.shape, dtype=int, device=device)
|
|
201
|
+
self.cell_indices = self._cell_index_temp.array
|
|
179
202
|
|
|
180
203
|
self._particle_coords_temp = borrow_temporary(
|
|
181
204
|
temporary_store, shape=positions.shape, dtype=Coords, device=device, requires_grad=self._requires_grad
|
|
182
205
|
)
|
|
183
|
-
self.
|
|
206
|
+
self.particle_coords = self._particle_coords_temp.array
|
|
184
207
|
|
|
185
208
|
wp.launch(
|
|
186
209
|
dim=positions.shape[0],
|
|
@@ -188,25 +211,28 @@ class PicQuadrature(Quadrature):
|
|
|
188
211
|
inputs=[
|
|
189
212
|
self.domain.element_arg_value(device),
|
|
190
213
|
positions,
|
|
191
|
-
|
|
192
|
-
self.
|
|
214
|
+
self.cell_indices,
|
|
215
|
+
self.particle_coords,
|
|
193
216
|
],
|
|
194
217
|
device=device,
|
|
195
218
|
)
|
|
196
219
|
|
|
197
220
|
else:
|
|
198
|
-
|
|
199
|
-
if
|
|
221
|
+
self.cell_indices, self.particle_coords = positions
|
|
222
|
+
if self.cell_indices.shape != self.particle_coords.shape:
|
|
200
223
|
raise ValueError("Cell index and coordinates arrays must have the same shape")
|
|
201
224
|
|
|
202
|
-
|
|
225
|
+
self._cell_index_temp = None
|
|
203
226
|
self._particle_coords_temp = None
|
|
204
227
|
|
|
205
228
|
self._cell_particle_offsets, self._cell_particle_indices, self._cell_count, _ = compress_node_indices(
|
|
206
|
-
self.domain.geometry_element_count(),
|
|
229
|
+
self.domain.geometry_element_count(),
|
|
230
|
+
self.cell_indices,
|
|
231
|
+
return_unique_nodes=True,
|
|
232
|
+
temporary_store=temporary_store,
|
|
207
233
|
)
|
|
208
234
|
|
|
209
|
-
self._compute_fraction(
|
|
235
|
+
self._compute_fraction(self.cell_indices, measures, temporary_store)
|
|
210
236
|
|
|
211
237
|
def _compute_fraction(self, cell_index, measures, temporary_store: TemporaryStore):
|
|
212
238
|
device = cell_index.device
|
|
@@ -245,9 +271,13 @@ class PicQuadrature(Quadrature):
|
|
|
245
271
|
cell_fraction: wp.array(dtype=float),
|
|
246
272
|
):
|
|
247
273
|
p = wp.tid()
|
|
248
|
-
sample = make_free_sample(cell_index[p], cell_coords[p])
|
|
249
274
|
|
|
250
|
-
|
|
275
|
+
cell = cell_index[p]
|
|
276
|
+
if cell == NULL_ELEMENT_INDEX:
|
|
277
|
+
cell_fraction[p] = 0.0
|
|
278
|
+
else:
|
|
279
|
+
sample = make_free_sample(cell_index[p], cell_coords[p])
|
|
280
|
+
cell_fraction[p] = measures[p] / self.domain.element_measure(cell_arg_value, sample)
|
|
251
281
|
|
|
252
282
|
wp.launch(
|
|
253
283
|
dim=measures.shape[0],
|
|
@@ -256,7 +286,7 @@ class PicQuadrature(Quadrature):
|
|
|
256
286
|
self.domain.element_arg_value(device),
|
|
257
287
|
measures,
|
|
258
288
|
cell_index,
|
|
259
|
-
self.
|
|
289
|
+
self.particle_coords,
|
|
260
290
|
self._particle_fraction,
|
|
261
291
|
],
|
|
262
292
|
device=device,
|
|
@@ -1,14 +1,36 @@
|
|
|
1
|
-
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from typing import Any, Optional
|
|
2
17
|
|
|
3
18
|
import warp as wp
|
|
4
|
-
from warp.fem import cache
|
|
19
|
+
from warp.fem import cache
|
|
20
|
+
from warp.fem.domain import GeometryDomain
|
|
5
21
|
from warp.fem.geometry import Element
|
|
6
22
|
from warp.fem.space import FunctionSpace
|
|
7
|
-
from warp.fem.types import Coords, ElementIndex
|
|
23
|
+
from warp.fem.types import NULL_ELEMENT_INDEX, Coords, ElementIndex, QuadraturePointIndex
|
|
8
24
|
|
|
9
25
|
from ..polynomial import Polynomial
|
|
10
26
|
|
|
11
27
|
|
|
28
|
+
@wp.struct
|
|
29
|
+
class QuadraturePointElementIndex:
|
|
30
|
+
domain_element_index: ElementIndex
|
|
31
|
+
qp_index_in_element: int
|
|
32
|
+
|
|
33
|
+
|
|
12
34
|
class Quadrature:
|
|
13
35
|
"""Interface class for quadrature rules"""
|
|
14
36
|
|
|
@@ -18,7 +40,7 @@ class Quadrature:
|
|
|
18
40
|
|
|
19
41
|
pass
|
|
20
42
|
|
|
21
|
-
def __init__(self, domain:
|
|
43
|
+
def __init__(self, domain: GeometryDomain):
|
|
22
44
|
self._domain = domain
|
|
23
45
|
|
|
24
46
|
@property
|
|
@@ -30,52 +52,197 @@ class Quadrature:
|
|
|
30
52
|
"""
|
|
31
53
|
Value of the argument to be passed to device
|
|
32
54
|
"""
|
|
33
|
-
arg =
|
|
55
|
+
arg = Quadrature.Arg()
|
|
34
56
|
return arg
|
|
35
57
|
|
|
36
58
|
def total_point_count(self):
|
|
37
|
-
"""
|
|
59
|
+
"""Number of unique quadrature points that can be indexed by this rule.
|
|
60
|
+
Returns a number such that `point_index()` is always smaller than this number.
|
|
61
|
+
"""
|
|
38
62
|
raise NotImplementedError()
|
|
39
63
|
|
|
64
|
+
def evaluation_point_count(self):
|
|
65
|
+
"""Number of quadrature points that needs to be evaluated, mostly for internal purposes.
|
|
66
|
+
If the indexing scheme is sparse, or if a quadrature point is shared among multiple elements
|
|
67
|
+
(e.g, nodal quadrature), `evaluation_point_count` may be different than `total_point_count()`.
|
|
68
|
+
Returns a number such that `evaluation_point_index()` is always smaller than this number.
|
|
69
|
+
"""
|
|
70
|
+
return self.total_point_count()
|
|
71
|
+
|
|
40
72
|
def max_points_per_element(self):
|
|
41
73
|
"""Maximum number of points per element if known, or ``None`` otherwise"""
|
|
42
74
|
return None
|
|
43
75
|
|
|
44
76
|
@staticmethod
|
|
45
|
-
def point_count(
|
|
77
|
+
def point_count(
|
|
78
|
+
elt_arg: "GeometryDomain.ElementArg",
|
|
79
|
+
qp_arg: Arg,
|
|
80
|
+
domain_element_index: ElementIndex,
|
|
81
|
+
geo_element_index: ElementIndex,
|
|
82
|
+
):
|
|
46
83
|
"""Number of quadrature points for a given element"""
|
|
47
84
|
raise NotImplementedError()
|
|
48
85
|
|
|
49
86
|
@staticmethod
|
|
50
87
|
def point_coords(
|
|
51
|
-
elt_arg: "
|
|
88
|
+
elt_arg: "GeometryDomain.ElementArg",
|
|
89
|
+
qp_arg: Arg,
|
|
90
|
+
domain_element_index: ElementIndex,
|
|
91
|
+
geo_element_index: ElementIndex,
|
|
92
|
+
element_qp_index: int,
|
|
52
93
|
):
|
|
53
94
|
"""Coordinates in element of the element's qp_index'th quadrature point"""
|
|
54
95
|
raise NotImplementedError()
|
|
55
96
|
|
|
56
97
|
@staticmethod
|
|
57
98
|
def point_weight(
|
|
58
|
-
elt_arg: "
|
|
99
|
+
elt_arg: "GeometryDomain.ElementArg",
|
|
100
|
+
qp_arg: Arg,
|
|
101
|
+
domain_element_index: ElementIndex,
|
|
102
|
+
geo_element_index: ElementIndex,
|
|
103
|
+
element_qp_index: int,
|
|
59
104
|
):
|
|
60
105
|
"""Weight of the element's qp_index'th quadrature point"""
|
|
61
106
|
raise NotImplementedError()
|
|
62
107
|
|
|
63
108
|
@staticmethod
|
|
64
109
|
def point_index(
|
|
65
|
-
elt_arg: "
|
|
110
|
+
elt_arg: "GeometryDomain.ElementArg",
|
|
111
|
+
qp_arg: Arg,
|
|
112
|
+
domain_element_index: ElementIndex,
|
|
113
|
+
geo_element_index: ElementIndex,
|
|
114
|
+
element_qp_index: int,
|
|
115
|
+
):
|
|
116
|
+
"""
|
|
117
|
+
Global index of the element's qp_index'th quadrature point.
|
|
118
|
+
May be shared among elements.
|
|
119
|
+
This is what determines `qp_index` in integrands' `Sample` arguments.
|
|
120
|
+
"""
|
|
121
|
+
raise NotImplementedError()
|
|
122
|
+
|
|
123
|
+
@staticmethod
|
|
124
|
+
def point_evaluation_index(
|
|
125
|
+
elt_arg: "GeometryDomain.ElementArg",
|
|
66
126
|
qp_arg: Arg,
|
|
67
127
|
domain_element_index: ElementIndex,
|
|
68
128
|
geo_element_index: ElementIndex,
|
|
69
129
|
element_qp_index: int,
|
|
70
130
|
):
|
|
71
|
-
"""
|
|
131
|
+
"""Quadrature point index according to evaluation order.
|
|
132
|
+
Quadrature points for distinct elements must have different evaluation indices.
|
|
133
|
+
Mostly for internal/parallelization purposes.
|
|
134
|
+
"""
|
|
72
135
|
raise NotImplementedError()
|
|
73
136
|
|
|
74
137
|
def __str__(self) -> str:
|
|
75
138
|
return self.name
|
|
76
139
|
|
|
140
|
+
# By default cache the mapping from evaluation point indices to domain elements
|
|
141
|
+
|
|
142
|
+
ElementIndexArg = wp.array(dtype=QuadraturePointElementIndex)
|
|
143
|
+
|
|
144
|
+
@cache.cached_arg_value
|
|
145
|
+
def element_index_arg_value(self, device):
|
|
146
|
+
"""Builds a map from quadrature point evaluation indices to their index in the element to which they belong"""
|
|
147
|
+
|
|
148
|
+
@cache.dynamic_kernel(f"{self.name}{self.domain.name}")
|
|
149
|
+
def quadrature_point_element_indices(
|
|
150
|
+
qp_arg: self.Arg,
|
|
151
|
+
domain_arg: self.domain.ElementArg,
|
|
152
|
+
domain_index_arg: self.domain.ElementIndexArg,
|
|
153
|
+
result: wp.array(dtype=QuadraturePointElementIndex),
|
|
154
|
+
):
|
|
155
|
+
domain_element_index = wp.tid()
|
|
156
|
+
element_index = self.domain.element_index(domain_index_arg, domain_element_index)
|
|
157
|
+
|
|
158
|
+
qp_point_count = self.point_count(domain_arg, qp_arg, domain_element_index, element_index)
|
|
159
|
+
for k in range(qp_point_count):
|
|
160
|
+
qp_eval_index = self.point_evaluation_index(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
161
|
+
result[qp_eval_index] = QuadraturePointElementIndex(domain_element_index, k)
|
|
162
|
+
|
|
163
|
+
null_qp_index = QuadraturePointElementIndex()
|
|
164
|
+
null_qp_index.domain_element_index = NULL_ELEMENT_INDEX
|
|
165
|
+
result = wp.full(
|
|
166
|
+
value=null_qp_index,
|
|
167
|
+
shape=(self.evaluation_point_count()),
|
|
168
|
+
dtype=QuadraturePointElementIndex,
|
|
169
|
+
device=device,
|
|
170
|
+
)
|
|
171
|
+
wp.launch(
|
|
172
|
+
quadrature_point_element_indices,
|
|
173
|
+
device=result.device,
|
|
174
|
+
dim=self.domain.element_count(),
|
|
175
|
+
inputs=[
|
|
176
|
+
self.arg_value(result.device),
|
|
177
|
+
self.domain.element_arg_value(result.device),
|
|
178
|
+
self.domain.element_index_arg_value(result.device),
|
|
179
|
+
result,
|
|
180
|
+
],
|
|
181
|
+
)
|
|
182
|
+
|
|
183
|
+
return result
|
|
184
|
+
|
|
185
|
+
@wp.func
|
|
186
|
+
def evaluation_point_element_index(
|
|
187
|
+
element_index_arg: wp.array(dtype=QuadraturePointElementIndex),
|
|
188
|
+
qp_eval_index: QuadraturePointIndex,
|
|
189
|
+
):
|
|
190
|
+
"""Maps from quadrature point evaluation indices to their index in the element to which they belong
|
|
191
|
+
If the quadrature point does not exist, should return NULL_ELEMENT_INDEX as the domain element index
|
|
192
|
+
"""
|
|
193
|
+
|
|
194
|
+
element_index = element_index_arg[qp_eval_index]
|
|
195
|
+
return element_index.domain_element_index, element_index.qp_index_in_element
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class _QuadratureWithRegularEvaluationPoints(Quadrature):
|
|
199
|
+
"""Helper subclass for quadrature formulas which use a uniform number of
|
|
200
|
+
evaluations points per element. Avoids building explicit mapping"""
|
|
201
|
+
|
|
202
|
+
def __init__(self, domain: GeometryDomain, N: int):
|
|
203
|
+
super().__init__(domain)
|
|
204
|
+
self._EVALUATION_POINTS_PER_ELEMENT = N
|
|
205
|
+
|
|
206
|
+
self.point_evaluation_index = self._make_regular_point_evaluation_index()
|
|
207
|
+
self.evaluation_point_element_index = self._make_regular_evaluation_point_element_index()
|
|
77
208
|
|
|
78
|
-
|
|
209
|
+
ElementIndexArg = Quadrature.Arg
|
|
210
|
+
element_index_arg_value = Quadrature.arg_value
|
|
211
|
+
|
|
212
|
+
def evaluation_point_count(self):
|
|
213
|
+
return self.domain.element_count() * self._EVALUATION_POINTS_PER_ELEMENT
|
|
214
|
+
|
|
215
|
+
def _make_regular_point_evaluation_index(self):
|
|
216
|
+
N = self._EVALUATION_POINTS_PER_ELEMENT
|
|
217
|
+
|
|
218
|
+
@cache.dynamic_func(suffix=f"{self.name}")
|
|
219
|
+
def evaluation_point_index(
|
|
220
|
+
elt_arg: self.domain.ElementArg,
|
|
221
|
+
qp_arg: self.Arg,
|
|
222
|
+
domain_element_index: ElementIndex,
|
|
223
|
+
element_index: ElementIndex,
|
|
224
|
+
qp_index: int,
|
|
225
|
+
):
|
|
226
|
+
return N * domain_element_index + qp_index
|
|
227
|
+
|
|
228
|
+
return evaluation_point_index
|
|
229
|
+
|
|
230
|
+
def _make_regular_evaluation_point_element_index(self):
|
|
231
|
+
N = self._EVALUATION_POINTS_PER_ELEMENT
|
|
232
|
+
|
|
233
|
+
@cache.dynamic_func(suffix=f"{N}")
|
|
234
|
+
def quadrature_evaluation_point_element_index(
|
|
235
|
+
qp_arg: Quadrature.Arg,
|
|
236
|
+
qp_index: QuadraturePointIndex,
|
|
237
|
+
):
|
|
238
|
+
domain_element_index = qp_index // N
|
|
239
|
+
index_in_element = qp_index - domain_element_index * N
|
|
240
|
+
return domain_element_index, index_in_element
|
|
241
|
+
|
|
242
|
+
return quadrature_evaluation_point_element_index
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
class RegularQuadrature(_QuadratureWithRegularEvaluationPoints):
|
|
79
246
|
"""Regular quadrature formula, using a constant set of quadrature points per element"""
|
|
80
247
|
|
|
81
248
|
@wp.struct
|
|
@@ -112,16 +279,15 @@ class RegularQuadrature(Quadrature):
|
|
|
112
279
|
|
|
113
280
|
def __init__(
|
|
114
281
|
self,
|
|
115
|
-
domain:
|
|
282
|
+
domain: GeometryDomain,
|
|
116
283
|
order: int,
|
|
117
284
|
family: Polynomial = None,
|
|
118
285
|
):
|
|
119
|
-
|
|
120
|
-
|
|
286
|
+
self._formula = RegularQuadrature.CachedFormula.get(domain.reference_element(), order, family)
|
|
121
287
|
self.family = family
|
|
122
288
|
self.order = order
|
|
123
289
|
|
|
124
|
-
|
|
290
|
+
super().__init__(domain, self._formula.count)
|
|
125
291
|
|
|
126
292
|
self.point_count = self._make_point_count()
|
|
127
293
|
self.point_index = self._make_point_index()
|
|
@@ -212,17 +378,18 @@ class NodalQuadrature(Quadrature):
|
|
|
212
378
|
any assumption about orthogonality of shape functions, and is thus safe to use for arbitrary integrands.
|
|
213
379
|
"""
|
|
214
380
|
|
|
215
|
-
def __init__(self, domain:
|
|
216
|
-
super().__init__(domain)
|
|
217
|
-
|
|
381
|
+
def __init__(self, domain: Optional[GeometryDomain], space: FunctionSpace):
|
|
218
382
|
self._space = space
|
|
219
383
|
|
|
384
|
+
super().__init__(domain)
|
|
385
|
+
|
|
220
386
|
self.Arg = self._make_arg()
|
|
221
387
|
|
|
222
388
|
self.point_count = self._make_point_count()
|
|
223
389
|
self.point_index = self._make_point_index()
|
|
224
390
|
self.point_coords = self._make_point_coords()
|
|
225
391
|
self.point_weight = self._make_point_weight()
|
|
392
|
+
self.point_evaluation_index = self._make_point_evaluation_index()
|
|
226
393
|
|
|
227
394
|
@property
|
|
228
395
|
def name(self):
|
|
@@ -300,8 +467,26 @@ class NodalQuadrature(Quadrature):
|
|
|
300
467
|
|
|
301
468
|
return point_index
|
|
302
469
|
|
|
470
|
+
def evaluation_point_count(self):
|
|
471
|
+
return self.domain.element_count() * self._space.topology.MAX_NODES_PER_ELEMENT
|
|
303
472
|
|
|
304
|
-
|
|
473
|
+
def _make_point_evaluation_index(self):
|
|
474
|
+
N = self._space.topology.MAX_NODES_PER_ELEMENT
|
|
475
|
+
|
|
476
|
+
@cache.dynamic_func(suffix=self.name)
|
|
477
|
+
def evaluation_point_index(
|
|
478
|
+
elt_arg: self.domain.ElementArg,
|
|
479
|
+
qp_arg: self.Arg,
|
|
480
|
+
domain_element_index: ElementIndex,
|
|
481
|
+
element_index: ElementIndex,
|
|
482
|
+
qp_index: int,
|
|
483
|
+
):
|
|
484
|
+
return N * domain_element_index + qp_index
|
|
485
|
+
|
|
486
|
+
return evaluation_point_index
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
class ExplicitQuadrature(_QuadratureWithRegularEvaluationPoints):
|
|
305
490
|
"""Quadrature using explicit per-cell points and weights.
|
|
306
491
|
|
|
307
492
|
The number of quadrature points per cell is assumed to be constant and deduced from the shape of the points and weights arrays.
|
|
@@ -321,11 +506,7 @@ class ExplicitQuadrature(Quadrature):
|
|
|
321
506
|
points: wp.array2d(dtype=Coords)
|
|
322
507
|
weights: wp.array2d(dtype=float)
|
|
323
508
|
|
|
324
|
-
def __init__(
|
|
325
|
-
self, domain: domain.GeometryDomain, points: "wp.array2d(dtype=Coords)", weights: "wp.array2d(dtype=float)"
|
|
326
|
-
):
|
|
327
|
-
super().__init__(domain)
|
|
328
|
-
|
|
509
|
+
def __init__(self, domain: GeometryDomain, points: "wp.array2d(dtype=Coords)", weights: "wp.array2d(dtype=float)"):
|
|
329
510
|
if points.shape != weights.shape:
|
|
330
511
|
raise ValueError("Points and weights arrays must have the same shape")
|
|
331
512
|
|
|
@@ -343,7 +524,10 @@ class ExplicitQuadrature(Quadrature):
|
|
|
343
524
|
)
|
|
344
525
|
|
|
345
526
|
self._points_per_cell = points.shape[1]
|
|
527
|
+
|
|
346
528
|
self._whole_geo = points.shape[0] == domain.geometry_element_count()
|
|
529
|
+
|
|
530
|
+
super().__init__(domain, self._points_per_cell)
|
|
347
531
|
self._points = points
|
|
348
532
|
self._weights = weights
|
|
349
533
|
|
warp/fem/space/__init__.py
CHANGED
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
# isort: skip_file
|
|
2
17
|
|
|
3
18
|
from enum import Enum
|
|
@@ -97,7 +112,7 @@ def make_polynomial_basis_space(
|
|
|
97
112
|
the constructed basis space
|
|
98
113
|
"""
|
|
99
114
|
|
|
100
|
-
base_geo = geo.base
|
|
115
|
+
base_geo = geo.base
|
|
101
116
|
|
|
102
117
|
if element_basis is None:
|
|
103
118
|
element_basis = ElementBasis.LAGRANGE
|