warp-lang 1.6.2__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +7 -1
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +410 -0
- warp/build_dll.py +6 -14
- warp/builtins.py +452 -362
- warp/codegen.py +179 -119
- warp/config.py +42 -6
- warp/context.py +490 -271
- warp/dlpack.py +8 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +2 -2
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_magnetostatics.py +6 -6
- warp/examples/fem/utils.py +9 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_matmul.py +2 -4
- warp/fem/__init__.py +11 -1
- warp/fem/adaptivity.py +4 -4
- warp/fem/field/nodal_field.py +22 -68
- warp/fem/field/virtual.py +62 -23
- warp/fem/geometry/adaptive_nanogrid.py +9 -10
- warp/fem/geometry/closest_point.py +1 -1
- warp/fem/geometry/deformed_geometry.py +5 -2
- warp/fem/geometry/geometry.py +5 -0
- warp/fem/geometry/grid_2d.py +12 -12
- warp/fem/geometry/grid_3d.py +12 -15
- warp/fem/geometry/hexmesh.py +5 -7
- warp/fem/geometry/nanogrid.py +9 -11
- warp/fem/geometry/quadmesh.py +13 -13
- warp/fem/geometry/tetmesh.py +3 -4
- warp/fem/geometry/trimesh.py +3 -8
- warp/fem/integrate.py +262 -93
- warp/fem/linalg.py +5 -5
- warp/fem/quadrature/pic_quadrature.py +37 -22
- warp/fem/quadrature/quadrature.py +194 -25
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +4 -2
- warp/fem/space/basis_space.py +25 -18
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +6 -2
- warp/fem/space/quadmesh_function_space.py +8 -8
- warp/fem/space/shape/cube_shape_function.py +23 -23
- warp/fem/space/shape/square_shape_function.py +12 -12
- warp/fem/space/shape/triangle_shape_function.py +1 -1
- warp/fem/space/tetmesh_function_space.py +3 -3
- warp/fem/space/trimesh_function_space.py +2 -2
- warp/fem/utils.py +12 -6
- warp/jax.py +14 -1
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +89 -0
- warp/native/array.h +13 -0
- warp/native/builtin.h +29 -3
- warp/native/bvh.cpp +3 -1
- warp/native/bvh.cu +42 -14
- warp/native/bvh.h +2 -1
- warp/native/clang/clang.cpp +30 -3
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/exports.h +68 -63
- warp/native/intersect.h +26 -26
- warp/native/intersect_adj.h +33 -33
- warp/native/marching.cu +1 -1
- warp/native/mat.h +513 -9
- warp/native/mesh.h +10 -10
- warp/native/quat.h +99 -11
- warp/native/rand.h +6 -0
- warp/native/sort.cpp +122 -59
- warp/native/sort.cu +152 -15
- warp/native/sort.h +8 -1
- warp/native/sparse.cpp +43 -22
- warp/native/sparse.cu +52 -17
- warp/native/svd.h +116 -0
- warp/native/tile.h +301 -105
- warp/native/tile_reduce.h +46 -3
- warp/native/vec.h +68 -7
- warp/native/volume.cpp +85 -113
- warp/native/volume_builder.cu +25 -10
- warp/native/volume_builder.h +6 -0
- warp/native/warp.cpp +5 -6
- warp/native/warp.cu +99 -10
- warp/native/warp.h +19 -10
- warp/optim/linear.py +10 -10
- warp/sim/articulation.py +4 -4
- warp/sim/collide.py +21 -10
- warp/sim/import_mjcf.py +449 -155
- warp/sim/import_urdf.py +32 -12
- warp/sim/integrator_euler.py +5 -5
- warp/sim/integrator_featherstone.py +3 -10
- warp/sim/integrator_vbd.py +207 -2
- warp/sim/integrator_xpbd.py +5 -5
- warp/sim/model.py +42 -13
- warp/sim/utils.py +2 -2
- warp/sparse.py +642 -555
- warp/stubs.py +216 -19
- warp/tests/__main__.py +0 -15
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
- warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
- warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
- warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
- warp/tests/{test_model.py → sim/test_model.py} +40 -0
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_bool.py +1 -1
- warp/tests/test_examples.py +28 -36
- warp/tests/test_fem.py +23 -4
- warp/tests/test_linear_solvers.py +0 -11
- warp/tests/test_mat.py +233 -79
- warp/tests/test_mat_scalar_ops.py +4 -4
- warp/tests/test_overwrite.py +0 -60
- warp/tests/test_quat.py +67 -46
- warp/tests/test_rand.py +44 -37
- warp/tests/test_sparse.py +47 -6
- warp/tests/test_spatial.py +75 -0
- warp/tests/test_static.py +1 -1
- warp/tests/test_utils.py +84 -4
- warp/tests/test_vec.py +46 -34
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
- warp/tests/unittest_serial.py +1 -0
- warp/tests/unittest_suites.py +45 -59
- warp/tests/unittest_utils.py +2 -1
- warp/thirdparty/unittest_parallel.py +3 -1
- warp/types.py +110 -658
- warp/utils.py +137 -72
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp/examples/optim/example_walker.py +0 -317
- warp/native/cutlass_gemm.cpp +0 -43
- warp/native/cutlass_gemm.cu +0 -382
- warp/tests/test_matmul.py +0 -511
- warp/tests/test_matmul_lite.py +0 -411
- warp/tests/test_vbd.py +0 -386
- warp/tests/unused_test_misc.py +0 -77
- /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
- /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
- /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
- /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
- /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
- /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
- /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
- /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
- /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
- /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
- /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
- /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
- /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
- /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
- /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
- /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
- /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
- /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py
CHANGED
|
@@ -34,7 +34,7 @@ from warp.fem.field import (
|
|
|
34
34
|
make_restriction,
|
|
35
35
|
)
|
|
36
36
|
from warp.fem.field.virtual import make_bilinear_dispatch_kernel, make_linear_dispatch_kernel
|
|
37
|
-
from warp.fem.linalg import array_axpy
|
|
37
|
+
from warp.fem.linalg import array_axpy, basis_coefficient
|
|
38
38
|
from warp.fem.operator import Integrand, Operator, at_node, integrand
|
|
39
39
|
from warp.fem.quadrature import Quadrature, RegularQuadrature
|
|
40
40
|
from warp.fem.types import (
|
|
@@ -493,7 +493,7 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
|
493
493
|
callee = getattr(call.func, "id", None)
|
|
494
494
|
|
|
495
495
|
if callee == self._func_name:
|
|
496
|
-
# Replace function arguments with
|
|
496
|
+
# Replace function arguments with our generated structs
|
|
497
497
|
call.args.clear()
|
|
498
498
|
for arg in self._arg_names:
|
|
499
499
|
if arg == self._domain_name:
|
|
@@ -576,33 +576,33 @@ def get_integrate_constant_kernel(
|
|
|
576
576
|
):
|
|
577
577
|
def integrate_kernel_fn(
|
|
578
578
|
qp_arg: quadrature.Arg,
|
|
579
|
+
qp_element_index_arg: quadrature.ElementIndexArg,
|
|
579
580
|
domain_arg: domain.ElementArg,
|
|
580
581
|
domain_index_arg: domain.ElementIndexArg,
|
|
581
582
|
fields: FieldStruct,
|
|
582
583
|
values: ValueStruct,
|
|
583
584
|
result: wp.array(dtype=accumulate_dtype),
|
|
584
585
|
):
|
|
585
|
-
|
|
586
|
+
qp_eval_index = wp.tid()
|
|
587
|
+
domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
|
|
588
|
+
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
589
|
+
return
|
|
590
|
+
|
|
586
591
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
587
|
-
|
|
592
|
+
|
|
593
|
+
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
594
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
595
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
588
596
|
|
|
589
597
|
test_dof_index = NULL_DOF_INDEX
|
|
590
598
|
trial_dof_index = NULL_DOF_INDEX
|
|
591
599
|
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
595
|
-
coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
596
|
-
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
597
|
-
|
|
598
|
-
sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
599
|
-
vol = domain.element_measure(domain_arg, sample)
|
|
600
|
-
|
|
601
|
-
val = integrand_func(sample, fields, values)
|
|
600
|
+
sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
601
|
+
vol = domain.element_measure(domain_arg, sample)
|
|
602
602
|
|
|
603
|
-
|
|
603
|
+
val = integrand_func(sample, fields, values)
|
|
604
604
|
|
|
605
|
-
wp.atomic_add(result, 0,
|
|
605
|
+
wp.atomic_add(result, 0, accumulate_dtype(qp_weight * vol * val))
|
|
606
606
|
|
|
607
607
|
return integrate_kernel_fn
|
|
608
608
|
|
|
@@ -745,35 +745,35 @@ def get_integrate_linear_local_kernel(
|
|
|
745
745
|
ValueStruct: wp.codegen.Struct,
|
|
746
746
|
test: LocalTestField,
|
|
747
747
|
):
|
|
748
|
-
TAYLOR_DOF_COUNT = test.TAYLOR_DOF_COUNT
|
|
749
|
-
|
|
750
748
|
def integrate_kernel_fn(
|
|
751
749
|
qp_arg: quadrature.Arg,
|
|
750
|
+
qp_element_index_arg: quadrature.ElementIndexArg,
|
|
752
751
|
domain_arg: domain.ElementArg,
|
|
753
752
|
domain_index_arg: domain.ElementIndexArg,
|
|
754
753
|
fields: FieldStruct,
|
|
755
754
|
values: ValueStruct,
|
|
756
755
|
result: wp.array3d(dtype=float),
|
|
757
756
|
):
|
|
758
|
-
|
|
759
|
-
|
|
757
|
+
qp_eval_index, taylor_dof, test_dof = wp.tid()
|
|
758
|
+
domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
|
|
760
759
|
|
|
761
|
-
|
|
762
|
-
|
|
760
|
+
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
761
|
+
return
|
|
763
762
|
|
|
764
|
-
|
|
765
|
-
for qp in range(qp_point_count):
|
|
766
|
-
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
767
|
-
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
768
|
-
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
763
|
+
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
769
764
|
|
|
770
|
-
|
|
765
|
+
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
766
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
767
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
771
768
|
|
|
772
|
-
|
|
769
|
+
vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
|
|
773
770
|
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
771
|
+
trial_dof_index = NULL_DOF_INDEX
|
|
772
|
+
test_dof_index = DofIndex(taylor_dof, test_dof)
|
|
773
|
+
|
|
774
|
+
sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
775
|
+
val = integrand_func(sample, fields, values)
|
|
776
|
+
result[qp_eval_index, taylor_dof, test_dof] = qp_weight * vol * val
|
|
777
777
|
|
|
778
778
|
return integrate_kernel_fn
|
|
779
779
|
|
|
@@ -818,10 +818,10 @@ def get_integrate_bilinear_kernel(
|
|
|
818
818
|
element_trial_node_count = trial.space.topology.element_node_count(
|
|
819
819
|
domain_arg, trial_topology_arg, element_index
|
|
820
820
|
)
|
|
821
|
-
qp_point_count = wp.
|
|
821
|
+
qp_point_count = wp.where(
|
|
822
822
|
trial_node < element_trial_node_count,
|
|
823
|
-
0,
|
|
824
823
|
quadrature.point_count(domain_arg, qp_arg, test_element_index.domain_element_index, element_index),
|
|
824
|
+
0,
|
|
825
825
|
)
|
|
826
826
|
|
|
827
827
|
test_dof_index = DofIndex(
|
|
@@ -963,36 +963,38 @@ def get_integrate_bilinear_local_kernel(
|
|
|
963
963
|
|
|
964
964
|
def integrate_kernel_fn(
|
|
965
965
|
qp_arg: quadrature.Arg,
|
|
966
|
+
qp_element_index_arg: quadrature.ElementIndexArg,
|
|
966
967
|
domain_arg: domain.ElementArg,
|
|
967
968
|
domain_index_arg: domain.ElementIndexArg,
|
|
968
969
|
fields: FieldStruct,
|
|
969
970
|
values: ValueStruct,
|
|
970
971
|
result: wp.array4d(dtype=float),
|
|
971
972
|
):
|
|
972
|
-
|
|
973
|
+
qp_eval_index, test_dof, trial_dof, trial_taylor_dof = wp.tid()
|
|
974
|
+
|
|
975
|
+
domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
|
|
976
|
+
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
977
|
+
return
|
|
978
|
+
|
|
973
979
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
974
980
|
|
|
975
|
-
|
|
976
|
-
|
|
981
|
+
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
982
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
983
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
977
984
|
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
981
|
-
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
982
|
-
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
985
|
+
vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
|
|
986
|
+
qp_vol = vol * qp_weight
|
|
983
987
|
|
|
984
|
-
|
|
985
|
-
qp_vol = vol * qp_weight
|
|
988
|
+
trial_dof_index = DofIndex(trial_taylor_dof, trial_dof)
|
|
986
989
|
|
|
987
|
-
|
|
988
|
-
|
|
990
|
+
for test_taylor_dof in range(TEST_TAYLOR_DOF_COUNT):
|
|
991
|
+
taylor_dof = test_taylor_dof * TRIAL_TAYLOR_DOF_COUNT + trial_taylor_dof
|
|
989
992
|
|
|
990
|
-
|
|
991
|
-
trial_dof_index = DofIndex(qp_index, trial_dof_offset + trial_taylor_dof)
|
|
993
|
+
test_dof_index = DofIndex(test_taylor_dof, test_dof)
|
|
992
994
|
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
995
|
+
sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
996
|
+
val = integrand_func(sample, fields, values)
|
|
997
|
+
result[qp_eval_index, test_dof, trial_dof, taylor_dof] = qp_vol * val
|
|
996
998
|
|
|
997
999
|
return integrate_kernel_fn
|
|
998
1000
|
|
|
@@ -1138,6 +1140,7 @@ def _launch_integrate_kernel(
|
|
|
1138
1140
|
output_dtype: type,
|
|
1139
1141
|
output: Optional[Union[wp.array, BsrMatrix]],
|
|
1140
1142
|
add_to_output: bool,
|
|
1143
|
+
bsr_options: Optional[Dict[str, Any]],
|
|
1141
1144
|
device,
|
|
1142
1145
|
):
|
|
1143
1146
|
# Set-up launch arguments
|
|
@@ -1175,9 +1178,10 @@ def _launch_integrate_kernel(
|
|
|
1175
1178
|
|
|
1176
1179
|
wp.launch(
|
|
1177
1180
|
kernel=kernel,
|
|
1178
|
-
dim=
|
|
1181
|
+
dim=quadrature.evaluation_point_count(),
|
|
1179
1182
|
inputs=[
|
|
1180
1183
|
qp_arg,
|
|
1184
|
+
quadrature.element_index_arg_value(device),
|
|
1181
1185
|
domain_elt_arg,
|
|
1182
1186
|
domain_elt_index_arg,
|
|
1183
1187
|
field_arg_values,
|
|
@@ -1279,15 +1283,16 @@ def _launch_integrate_kernel(
|
|
|
1279
1283
|
temporary_store=temporary_store,
|
|
1280
1284
|
device=device,
|
|
1281
1285
|
requires_grad=output.requires_grad,
|
|
1282
|
-
shape=(quadrature.
|
|
1286
|
+
shape=(quadrature.evaluation_point_count(), test.TAYLOR_DOF_COUNT, test.value_dof_count),
|
|
1283
1287
|
dtype=float,
|
|
1284
1288
|
)
|
|
1285
1289
|
|
|
1286
1290
|
wp.launch(
|
|
1287
1291
|
kernel=kernel,
|
|
1288
|
-
dim=
|
|
1292
|
+
dim=local_result.array.shape,
|
|
1289
1293
|
inputs=[
|
|
1290
1294
|
qp_arg,
|
|
1295
|
+
quadrature.element_index_arg_value(device),
|
|
1291
1296
|
domain_elt_arg,
|
|
1292
1297
|
domain_elt_index_arg,
|
|
1293
1298
|
field_arg_values,
|
|
@@ -1389,7 +1394,7 @@ def _launch_integrate_kernel(
|
|
|
1389
1394
|
device=device,
|
|
1390
1395
|
requires_grad=False,
|
|
1391
1396
|
shape=(
|
|
1392
|
-
quadrature.
|
|
1397
|
+
quadrature.evaluation_point_count(),
|
|
1393
1398
|
test.value_dof_count,
|
|
1394
1399
|
trial.value_dof_count,
|
|
1395
1400
|
test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT,
|
|
@@ -1399,9 +1404,15 @@ def _launch_integrate_kernel(
|
|
|
1399
1404
|
|
|
1400
1405
|
wp.launch(
|
|
1401
1406
|
kernel=kernel,
|
|
1402
|
-
dim=(
|
|
1407
|
+
dim=(
|
|
1408
|
+
quadrature.evaluation_point_count(),
|
|
1409
|
+
test.value_dof_count,
|
|
1410
|
+
trial.value_dof_count,
|
|
1411
|
+
trial.TAYLOR_DOF_COUNT,
|
|
1412
|
+
),
|
|
1403
1413
|
inputs=[
|
|
1404
1414
|
qp_arg,
|
|
1415
|
+
quadrature.element_index_arg_value(device),
|
|
1405
1416
|
domain_elt_arg,
|
|
1406
1417
|
domain_elt_index_arg,
|
|
1407
1418
|
field_arg_values,
|
|
@@ -1496,7 +1507,7 @@ def _launch_integrate_kernel(
|
|
|
1496
1507
|
else:
|
|
1497
1508
|
bsr_result = output
|
|
1498
1509
|
|
|
1499
|
-
bsr_set_from_triplets(bsr_result, triplet_rows, triplet_cols, triplet_values)
|
|
1510
|
+
bsr_set_from_triplets(bsr_result, triplet_rows, triplet_cols, triplet_values, **(bsr_options or {}))
|
|
1500
1511
|
|
|
1501
1512
|
# Do not wait for garbage collection
|
|
1502
1513
|
triplet_values_temp.release()
|
|
@@ -1541,8 +1552,9 @@ def integrate(
|
|
|
1541
1552
|
device=None,
|
|
1542
1553
|
temporary_store: Optional[cache.TemporaryStore] = None,
|
|
1543
1554
|
kernel_options: Optional[Dict[str, Any]] = None,
|
|
1544
|
-
assembly: str = None,
|
|
1555
|
+
assembly: Optional[str] = None,
|
|
1545
1556
|
add: bool = False,
|
|
1557
|
+
bsr_options: Optional[Dict[str, Any]] = None,
|
|
1546
1558
|
):
|
|
1547
1559
|
"""
|
|
1548
1560
|
Integrates a constant, linear or bilinear form, and returns a scalar, array, or sparse matrix, respectively.
|
|
@@ -1566,6 +1578,7 @@ def integrate(
|
|
|
1566
1578
|
- "dispatch": For linear or bilinear forms, first evaluate the form at quadrature points then dispatch to nodes in a second pass. More efficient for integrands that are expensive to evaluate. Incompatible with `at_node` operator on test or trial functions.
|
|
1567
1579
|
- `None` (default): Automatically picks a suitable assembly strategy (either "generic" or "dispatch")
|
|
1568
1580
|
add: If True and `output` is provided, add the integration result to `output` instead of replacing its content
|
|
1581
|
+
bsr_options: Additional options to be passed to the sparse matrix construction algorithm. See :func:`warp.sparse.bsr_set_from_triplets()`
|
|
1569
1582
|
"""
|
|
1570
1583
|
if fields is None:
|
|
1571
1584
|
fields = {}
|
|
@@ -1678,6 +1691,7 @@ def integrate(
|
|
|
1678
1691
|
output_dtype=output_dtype,
|
|
1679
1692
|
output=output,
|
|
1680
1693
|
add_to_output=add,
|
|
1694
|
+
bsr_options=bsr_options,
|
|
1681
1695
|
device=device,
|
|
1682
1696
|
)
|
|
1683
1697
|
|
|
@@ -1823,53 +1837,128 @@ def get_interpolate_at_quadrature_kernel(
|
|
|
1823
1837
|
):
|
|
1824
1838
|
def interpolate_at_quadrature_nonvalued_kernel_fn(
|
|
1825
1839
|
qp_arg: quadrature.Arg,
|
|
1840
|
+
qp_element_index_arg: quadrature.ElementIndexArg,
|
|
1826
1841
|
domain_arg: quadrature.domain.ElementArg,
|
|
1827
1842
|
domain_index_arg: quadrature.domain.ElementIndexArg,
|
|
1828
1843
|
fields: FieldStruct,
|
|
1829
1844
|
values: ValueStruct,
|
|
1830
1845
|
result: wp.array(dtype=float),
|
|
1831
1846
|
):
|
|
1832
|
-
|
|
1847
|
+
qp_eval_index = wp.tid()
|
|
1848
|
+
domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
|
|
1849
|
+
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
1850
|
+
return
|
|
1851
|
+
|
|
1833
1852
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
1834
1853
|
|
|
1835
1854
|
test_dof_index = NULL_DOF_INDEX
|
|
1836
1855
|
trial_dof_index = NULL_DOF_INDEX
|
|
1837
1856
|
|
|
1838
|
-
|
|
1839
|
-
|
|
1840
|
-
|
|
1841
|
-
coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
1842
|
-
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
1857
|
+
coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1858
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1859
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1843
1860
|
|
|
1844
|
-
|
|
1845
|
-
|
|
1861
|
+
sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
1862
|
+
integrand_func(sample, fields, values)
|
|
1846
1863
|
|
|
1847
1864
|
def interpolate_at_quadrature_kernel_fn(
|
|
1848
1865
|
qp_arg: quadrature.Arg,
|
|
1866
|
+
qp_element_index_arg: quadrature.ElementIndexArg,
|
|
1849
1867
|
domain_arg: quadrature.domain.ElementArg,
|
|
1850
1868
|
domain_index_arg: quadrature.domain.ElementIndexArg,
|
|
1851
1869
|
fields: FieldStruct,
|
|
1852
1870
|
values: ValueStruct,
|
|
1853
1871
|
result: wp.array(dtype=value_type),
|
|
1854
1872
|
):
|
|
1855
|
-
|
|
1873
|
+
qp_eval_index = wp.tid()
|
|
1874
|
+
domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
|
|
1875
|
+
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
1876
|
+
return
|
|
1877
|
+
|
|
1856
1878
|
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
1857
1879
|
|
|
1858
1880
|
test_dof_index = NULL_DOF_INDEX
|
|
1859
1881
|
trial_dof_index = NULL_DOF_INDEX
|
|
1860
1882
|
|
|
1861
|
-
|
|
1862
|
-
|
|
1863
|
-
|
|
1864
|
-
coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
1865
|
-
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
|
|
1883
|
+
coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1884
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1885
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1866
1886
|
|
|
1867
|
-
|
|
1868
|
-
|
|
1887
|
+
sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
1888
|
+
result[qp_index] = integrand_func(sample, fields, values)
|
|
1869
1889
|
|
|
1870
1890
|
return interpolate_at_quadrature_nonvalued_kernel_fn if value_type is None else interpolate_at_quadrature_kernel_fn
|
|
1871
1891
|
|
|
1872
1892
|
|
|
1893
|
+
def get_interpolate_jacobian_at_quadrature_kernel(
|
|
1894
|
+
integrand_func: wp.Function,
|
|
1895
|
+
domain: GeometryDomain,
|
|
1896
|
+
quadrature: Quadrature,
|
|
1897
|
+
FieldStruct: wp.codegen.Struct,
|
|
1898
|
+
ValueStruct: wp.codegen.Struct,
|
|
1899
|
+
trial: TrialField,
|
|
1900
|
+
value_size: int,
|
|
1901
|
+
value_type: type,
|
|
1902
|
+
):
|
|
1903
|
+
MAX_NODES_PER_ELEMENT = trial.space.topology.MAX_NODES_PER_ELEMENT
|
|
1904
|
+
VALUE_SIZE = wp.constant(value_size)
|
|
1905
|
+
|
|
1906
|
+
def interpolate_jacobian_kernel_fn(
|
|
1907
|
+
qp_arg: quadrature.Arg,
|
|
1908
|
+
qp_element_index_arg: quadrature.ElementIndexArg,
|
|
1909
|
+
domain_arg: domain.ElementArg,
|
|
1910
|
+
domain_index_arg: domain.ElementIndexArg,
|
|
1911
|
+
trial_partition_arg: trial.space_partition.PartitionArg,
|
|
1912
|
+
trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
|
|
1913
|
+
fields: FieldStruct,
|
|
1914
|
+
values: ValueStruct,
|
|
1915
|
+
triplet_rows: wp.array(dtype=int),
|
|
1916
|
+
triplet_cols: wp.array(dtype=int),
|
|
1917
|
+
triplet_values: wp.array3d(dtype=value_type),
|
|
1918
|
+
):
|
|
1919
|
+
qp_eval_index, trial_node, trial_dof = wp.tid()
|
|
1920
|
+
domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
|
|
1921
|
+
|
|
1922
|
+
if domain_element_index == NULL_ELEMENT_INDEX:
|
|
1923
|
+
return
|
|
1924
|
+
|
|
1925
|
+
element_index = domain.element_index(domain_index_arg, domain_element_index)
|
|
1926
|
+
if qp >= quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index):
|
|
1927
|
+
return
|
|
1928
|
+
|
|
1929
|
+
element_trial_node_count = trial.space.topology.element_node_count(
|
|
1930
|
+
domain_arg, trial_topology_arg, element_index
|
|
1931
|
+
)
|
|
1932
|
+
|
|
1933
|
+
qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1934
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1935
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
|
|
1936
|
+
|
|
1937
|
+
block_offset = qp_index * MAX_NODES_PER_ELEMENT + trial_node
|
|
1938
|
+
|
|
1939
|
+
test_dof_index = NULL_DOF_INDEX
|
|
1940
|
+
trial_dof_index = DofIndex(trial_node, trial_dof)
|
|
1941
|
+
|
|
1942
|
+
sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
1943
|
+
val = integrand_func(sample, fields, values)
|
|
1944
|
+
|
|
1945
|
+
for k in range(VALUE_SIZE):
|
|
1946
|
+
triplet_values[block_offset, k, trial_dof] = basis_coefficient(val, k)
|
|
1947
|
+
|
|
1948
|
+
if trial_dof == 0:
|
|
1949
|
+
if trial_node < element_trial_node_count:
|
|
1950
|
+
trial_node_index = trial.space_partition.partition_node_index(
|
|
1951
|
+
trial_partition_arg,
|
|
1952
|
+
trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
|
|
1953
|
+
)
|
|
1954
|
+
else:
|
|
1955
|
+
trial_node_index = NULL_NODE_INDEX # will get ignored when converting to bsr
|
|
1956
|
+
triplet_rows[block_offset] = qp_index
|
|
1957
|
+
triplet_cols[block_offset] = trial_node_index
|
|
1958
|
+
|
|
1959
|
+
return interpolate_jacobian_kernel_fn
|
|
1960
|
+
|
|
1961
|
+
|
|
1873
1962
|
def get_interpolate_free_kernel(
|
|
1874
1963
|
integrand_func: wp.Function,
|
|
1875
1964
|
domain: GeometryDomain,
|
|
@@ -1939,9 +2028,9 @@ def _generate_interpolate_kernel(
|
|
|
1939
2028
|
dest_dtype = dest.dtype if dest else None
|
|
1940
2029
|
type_str = wp.types.get_type_code(dest_dtype) if dest_dtype else ""
|
|
1941
2030
|
if quadrature is None:
|
|
1942
|
-
kernel_suffix = f"_itp_{field_names}_{type_str}"
|
|
2031
|
+
kernel_suffix = f"_itp_{field_names}_{domain.name}_{type_str}"
|
|
1943
2032
|
else:
|
|
1944
|
-
kernel_suffix = f"_itp_{field_names}_{quadrature.name}_{type_str}"
|
|
2033
|
+
kernel_suffix = f"_itp_{field_names}_{domain.name}_{quadrature.name}_{type_str}"
|
|
1945
2034
|
|
|
1946
2035
|
kernel = cache.get_integrand_kernel(
|
|
1947
2036
|
integrand=integrand,
|
|
@@ -1986,14 +2075,27 @@ def _generate_interpolate_kernel(
|
|
|
1986
2075
|
ValueStruct=ValueStruct,
|
|
1987
2076
|
)
|
|
1988
2077
|
elif quadrature is not None:
|
|
1989
|
-
|
|
1990
|
-
|
|
1991
|
-
|
|
1992
|
-
|
|
1993
|
-
|
|
1994
|
-
|
|
1995
|
-
|
|
1996
|
-
|
|
2078
|
+
if arguments.trial_name:
|
|
2079
|
+
trial = arguments.field_args[arguments.trial_name]
|
|
2080
|
+
interpolate_kernel_fn = get_interpolate_jacobian_at_quadrature_kernel(
|
|
2081
|
+
integrand_func,
|
|
2082
|
+
domain=domain,
|
|
2083
|
+
quadrature=quadrature,
|
|
2084
|
+
FieldStruct=FieldStruct,
|
|
2085
|
+
ValueStruct=ValueStruct,
|
|
2086
|
+
trial=trial,
|
|
2087
|
+
value_size=dest.block_shape[0],
|
|
2088
|
+
value_type=dest.scalar_type,
|
|
2089
|
+
)
|
|
2090
|
+
else:
|
|
2091
|
+
interpolate_kernel_fn = get_interpolate_at_quadrature_kernel(
|
|
2092
|
+
integrand_func,
|
|
2093
|
+
domain=domain,
|
|
2094
|
+
quadrature=quadrature,
|
|
2095
|
+
value_type=dest_dtype,
|
|
2096
|
+
FieldStruct=FieldStruct,
|
|
2097
|
+
ValueStruct=ValueStruct,
|
|
2098
|
+
)
|
|
1997
2099
|
else:
|
|
1998
2100
|
interpolate_kernel_fn = get_interpolate_free_kernel(
|
|
1999
2101
|
integrand_func,
|
|
@@ -2027,8 +2129,11 @@ def _launch_interpolate_kernel(
|
|
|
2027
2129
|
dest: Optional[Union[FieldRestriction, wp.array]],
|
|
2028
2130
|
quadrature: Optional[Quadrature],
|
|
2029
2131
|
dim: int,
|
|
2132
|
+
trial: Optional[TrialField],
|
|
2030
2133
|
fields: Dict[str, FieldLike],
|
|
2031
2134
|
values: Dict[str, Any],
|
|
2135
|
+
temporary_store: Optional[cache.TemporaryStore],
|
|
2136
|
+
bsr_options: Optional[Dict[str, Any]],
|
|
2032
2137
|
device,
|
|
2033
2138
|
) -> wp.Kernel:
|
|
2034
2139
|
# Set-up launch arguments
|
|
@@ -2059,21 +2164,74 @@ def _launch_interpolate_kernel(
|
|
|
2059
2164
|
],
|
|
2060
2165
|
device=device,
|
|
2061
2166
|
)
|
|
2062
|
-
|
|
2063
|
-
|
|
2167
|
+
return
|
|
2168
|
+
|
|
2169
|
+
if quadrature is None:
|
|
2064
2170
|
wp.launch(
|
|
2065
2171
|
kernel=kernel,
|
|
2066
|
-
dim=
|
|
2067
|
-
inputs=[
|
|
2172
|
+
dim=dim,
|
|
2173
|
+
inputs=[dim, elt_arg, field_arg_values, value_struct_values, dest],
|
|
2068
2174
|
device=device,
|
|
2069
2175
|
)
|
|
2070
|
-
|
|
2176
|
+
return
|
|
2177
|
+
|
|
2178
|
+
qp_arg = quadrature.arg_value(device)
|
|
2179
|
+
qp_element_index_arg = quadrature.element_index_arg_value(device)
|
|
2180
|
+
if trial is None:
|
|
2071
2181
|
wp.launch(
|
|
2072
2182
|
kernel=kernel,
|
|
2073
|
-
dim=
|
|
2074
|
-
inputs=[
|
|
2183
|
+
dim=quadrature.evaluation_point_count(),
|
|
2184
|
+
inputs=[qp_arg, qp_element_index_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
|
|
2075
2185
|
device=device,
|
|
2076
2186
|
)
|
|
2187
|
+
return
|
|
2188
|
+
|
|
2189
|
+
nnz = quadrature.total_point_count() * trial.space.topology.MAX_NODES_PER_ELEMENT
|
|
2190
|
+
|
|
2191
|
+
if dest.nrow != quadrature.total_point_count() or dest.ncol != trial.space_partition.node_count():
|
|
2192
|
+
raise RuntimeError(
|
|
2193
|
+
f"'dest' matrix must have {quadrature.total_point_count()} rows and {trial.space_partition.node_count()} columns of blocks"
|
|
2194
|
+
)
|
|
2195
|
+
if dest.block_shape[1] != trial.node_dof_count:
|
|
2196
|
+
raise f"'dest' matrix blocks must have {trial.node_dof_count} columns"
|
|
2197
|
+
|
|
2198
|
+
triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
|
|
2199
|
+
triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
|
|
2200
|
+
triplet_values_temp = cache.borrow_temporary(
|
|
2201
|
+
temporary_store,
|
|
2202
|
+
dtype=dest.scalar_type,
|
|
2203
|
+
shape=(nnz, *dest.block_shape),
|
|
2204
|
+
device=device,
|
|
2205
|
+
)
|
|
2206
|
+
triplet_cols = triplet_cols_temp.array
|
|
2207
|
+
triplet_rows = triplet_rows_temp.array
|
|
2208
|
+
triplet_values = triplet_values_temp.array
|
|
2209
|
+
triplet_rows.fill_(-1)
|
|
2210
|
+
triplet_values.zero_()
|
|
2211
|
+
|
|
2212
|
+
trial_partition_arg = trial.space_partition.partition_arg_value(device)
|
|
2213
|
+
trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
|
|
2214
|
+
|
|
2215
|
+
wp.launch(
|
|
2216
|
+
kernel=kernel,
|
|
2217
|
+
dim=(quadrature.evaluation_point_count(), trial.space.topology.MAX_NODES_PER_ELEMENT, trial.node_dof_count),
|
|
2218
|
+
inputs=[
|
|
2219
|
+
qp_arg,
|
|
2220
|
+
qp_element_index_arg,
|
|
2221
|
+
elt_arg,
|
|
2222
|
+
elt_index_arg,
|
|
2223
|
+
trial_partition_arg,
|
|
2224
|
+
trial_topology_arg,
|
|
2225
|
+
field_arg_values,
|
|
2226
|
+
value_struct_values,
|
|
2227
|
+
triplet_rows,
|
|
2228
|
+
triplet_cols,
|
|
2229
|
+
triplet_values,
|
|
2230
|
+
],
|
|
2231
|
+
device=device,
|
|
2232
|
+
)
|
|
2233
|
+
|
|
2234
|
+
bsr_set_from_triplets(dest, triplet_rows, triplet_cols, triplet_values, **(bsr_options or {}))
|
|
2077
2235
|
|
|
2078
2236
|
|
|
2079
2237
|
@integrand
|
|
@@ -2091,6 +2249,8 @@ def interpolate(
|
|
|
2091
2249
|
values: Optional[Dict[str, Any]] = None,
|
|
2092
2250
|
device=None,
|
|
2093
2251
|
kernel_options: Optional[Dict[str, Any]] = None,
|
|
2252
|
+
temporary_store: Optional[cache.TemporaryStore] = None,
|
|
2253
|
+
bsr_options: Optional[Dict[str, Any]] = None,
|
|
2094
2254
|
):
|
|
2095
2255
|
"""
|
|
2096
2256
|
Interpolates a function at a finite set of sample points and optionally assigns the result to a discrete field or a raw warp array.
|
|
@@ -2109,6 +2269,8 @@ def interpolate(
|
|
|
2109
2269
|
values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
|
|
2110
2270
|
device: Device on which to perform the interpolation
|
|
2111
2271
|
kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)
|
|
2272
|
+
temporary_store: shared pool from which to allocate temporary arrays
|
|
2273
|
+
bsr_options: Additional options to be passed to the sparse matrix construction algorithm. See :func:`warp.sparse.bsr_set_from_triplets()`
|
|
2112
2274
|
"""
|
|
2113
2275
|
|
|
2114
2276
|
if isinstance(integrand, FieldLike):
|
|
@@ -2126,8 +2288,12 @@ def interpolate(
|
|
|
2126
2288
|
raise ValueError("integrand must be tagged with @integrand decorator")
|
|
2127
2289
|
|
|
2128
2290
|
arguments = _parse_integrand_arguments(integrand, fields)
|
|
2129
|
-
if arguments.test_name
|
|
2130
|
-
raise ValueError("Test
|
|
2291
|
+
if arguments.test_name:
|
|
2292
|
+
raise ValueError(f"Test field '{arguments.test_name}' maybe not be used for interpolation")
|
|
2293
|
+
if arguments.trial_name and (quadrature is None or not isinstance(dest, BsrMatrix)):
|
|
2294
|
+
raise ValueError(
|
|
2295
|
+
f"Interpolation using trial field '{arguments.trial_name}' requires 'quadrature' to be provided and 'dest' to be a `warp.sparse.BsrMatrix`"
|
|
2296
|
+
)
|
|
2131
2297
|
|
|
2132
2298
|
if isinstance(dest, DiscreteField):
|
|
2133
2299
|
dest = make_restriction(dest, domain=domain)
|
|
@@ -2160,7 +2326,10 @@ def interpolate(
|
|
|
2160
2326
|
dest=dest,
|
|
2161
2327
|
quadrature=quadrature,
|
|
2162
2328
|
dim=dim,
|
|
2329
|
+
trial=fields.get(arguments.trial_name),
|
|
2163
2330
|
fields=arguments.field_args,
|
|
2164
2331
|
values=values,
|
|
2332
|
+
temporary_store=temporary_store,
|
|
2333
|
+
bsr_options=bsr_options,
|
|
2165
2334
|
device=device,
|
|
2166
2335
|
)
|
warp/fem/linalg.py
CHANGED
|
@@ -172,11 +172,11 @@ def householder_qr_decomposition(A: Any):
|
|
|
172
172
|
|
|
173
173
|
for i in range(type(x).length):
|
|
174
174
|
for k in range(type(x).length):
|
|
175
|
-
x[k] = wp.
|
|
175
|
+
x[k] = wp.where(k < i, zero, A[k, i])
|
|
176
176
|
|
|
177
177
|
alpha = wp.length(x) * wp.sign(x[i])
|
|
178
178
|
x[i] += alpha
|
|
179
|
-
two_over_x_sq = wp.
|
|
179
|
+
two_over_x_sq = wp.where(alpha == zero, zero, two / wp.length_sq(x))
|
|
180
180
|
|
|
181
181
|
A -= wp.outer(two_over_x_sq * x, x * A)
|
|
182
182
|
Q -= wp.outer(Q * x, two_over_x_sq * x)
|
|
@@ -201,11 +201,11 @@ def householder_make_hessenberg(A: Any):
|
|
|
201
201
|
|
|
202
202
|
for i in range(1, type(x).length):
|
|
203
203
|
for k in range(type(x).length):
|
|
204
|
-
x[k] = wp.
|
|
204
|
+
x[k] = wp.where(k < i, zero, A[k, i - 1])
|
|
205
205
|
|
|
206
206
|
alpha = wp.length(x) * wp.sign(x[i])
|
|
207
207
|
x[i] += alpha
|
|
208
|
-
two_over_x_sq = wp.
|
|
208
|
+
two_over_x_sq = wp.where(alpha == zero, zero, two / wp.length_sq(x))
|
|
209
209
|
|
|
210
210
|
# apply on both sides
|
|
211
211
|
A -= wp.outer(two_over_x_sq * x, x * A)
|
|
@@ -226,7 +226,7 @@ def solve_triangular(R: Any, b: Any):
|
|
|
226
226
|
for i in range(b.length, 0, -1):
|
|
227
227
|
j = i - 1
|
|
228
228
|
r = b[j] - wp.dot(R[j], x)
|
|
229
|
-
x[j] = wp.
|
|
229
|
+
x[j] = wp.where(R[j, j] == zero, zero, r / R[j, j])
|
|
230
230
|
|
|
231
231
|
return x
|
|
232
232
|
|