warp-lang 1.0.0b5__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/conf.py +3 -4
- examples/env/env_ant.py +1 -1
- examples/env/env_cartpole.py +1 -1
- examples/env/env_humanoid.py +1 -1
- examples/example_dem.py +28 -26
- examples/example_diffray.py +37 -30
- examples/example_fluid.py +7 -3
- examples/example_jacobian_ik.py +1 -1
- examples/example_mesh_intersect.py +10 -7
- examples/example_nvdb.py +3 -3
- examples/example_render_opengl.py +19 -10
- examples/example_sim_cartpole.py +9 -5
- examples/example_sim_cloth.py +29 -25
- examples/example_sim_fk_grad.py +2 -2
- examples/example_sim_fk_grad_torch.py +3 -3
- examples/example_sim_grad_bounce.py +11 -8
- examples/example_sim_grad_cloth.py +12 -9
- examples/example_sim_granular.py +2 -2
- examples/example_sim_granular_collision_sdf.py +13 -13
- examples/example_sim_neo_hookean.py +3 -3
- examples/example_sim_particle_chain.py +2 -2
- examples/example_sim_quadruped.py +8 -5
- examples/example_sim_rigid_chain.py +8 -5
- examples/example_sim_rigid_contact.py +13 -10
- examples/example_sim_rigid_fem.py +2 -2
- examples/example_sim_rigid_gyroscopic.py +2 -2
- examples/example_sim_rigid_kinematics.py +1 -1
- examples/example_sim_trajopt.py +3 -2
- examples/fem/example_apic_fluid.py +5 -7
- examples/fem/example_diffusion_mgpu.py +18 -16
- warp/__init__.py +3 -2
- warp/bin/warp.so +0 -0
- warp/build_dll.py +29 -9
- warp/builtins.py +206 -7
- warp/codegen.py +58 -38
- warp/config.py +3 -1
- warp/context.py +234 -128
- warp/fem/__init__.py +2 -2
- warp/fem/cache.py +2 -1
- warp/fem/field/nodal_field.py +18 -17
- warp/fem/geometry/hexmesh.py +11 -6
- warp/fem/geometry/quadmesh_2d.py +16 -12
- warp/fem/geometry/tetmesh.py +19 -8
- warp/fem/geometry/trimesh_2d.py +18 -7
- warp/fem/integrate.py +341 -196
- warp/fem/quadrature/__init__.py +1 -1
- warp/fem/quadrature/pic_quadrature.py +138 -53
- warp/fem/quadrature/quadrature.py +81 -9
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_space.py +169 -51
- warp/fem/space/grid_2d_function_space.py +2 -2
- warp/fem/space/grid_3d_function_space.py +2 -2
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +9 -6
- warp/fem/space/quadmesh_2d_function_space.py +2 -2
- warp/fem/space/shape/cube_shape_function.py +27 -15
- warp/fem/space/shape/square_shape_function.py +29 -18
- warp/fem/space/tetmesh_function_space.py +2 -2
- warp/fem/space/topology.py +10 -0
- warp/fem/space/trimesh_2d_function_space.py +2 -2
- warp/fem/utils.py +10 -5
- warp/native/array.h +49 -8
- warp/native/builtin.h +31 -14
- warp/native/cuda_util.cpp +8 -3
- warp/native/cuda_util.h +1 -0
- warp/native/exports.h +1177 -1108
- warp/native/intersect.h +4 -4
- warp/native/intersect_adj.h +8 -8
- warp/native/mat.h +65 -6
- warp/native/mesh.h +126 -5
- warp/native/quat.h +28 -4
- warp/native/vec.h +76 -14
- warp/native/warp.cu +1 -6
- warp/render/render_opengl.py +261 -109
- warp/sim/import_mjcf.py +13 -7
- warp/sim/import_urdf.py +14 -14
- warp/sim/inertia.py +17 -18
- warp/sim/model.py +67 -67
- warp/sim/render.py +1 -1
- warp/sparse.py +6 -6
- warp/stubs.py +19 -81
- warp/tape.py +1 -1
- warp/tests/__main__.py +3 -6
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +102 -106
- warp/tests/test_arithmetic.py +39 -40
- warp/tests/test_array.py +46 -48
- warp/tests/test_array_reduce.py +25 -19
- warp/tests/test_atomic.py +62 -26
- warp/tests/test_bool.py +16 -11
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +9 -12
- warp/tests/test_closest_point_edge_edge.py +53 -57
- warp/tests/test_codegen.py +164 -134
- warp/tests/test_compile_consts.py +13 -19
- warp/tests/test_conditional.py +30 -32
- warp/tests/test_copy.py +9 -12
- warp/tests/test_ctypes.py +90 -98
- warp/tests/test_dense.py +20 -14
- warp/tests/test_devices.py +34 -35
- warp/tests/test_dlpack.py +74 -75
- warp/tests/test_examples.py +215 -97
- warp/tests/test_fabricarray.py +15 -21
- warp/tests/test_fast_math.py +14 -11
- warp/tests/test_fem.py +280 -97
- warp/tests/test_fp16.py +19 -15
- warp/tests/test_func.py +177 -194
- warp/tests/test_generics.py +71 -77
- warp/tests/test_grad.py +83 -32
- warp/tests/test_grad_customs.py +7 -9
- warp/tests/test_hash_grid.py +6 -10
- warp/tests/test_import.py +9 -23
- warp/tests/test_indexedarray.py +19 -21
- warp/tests/test_intersect.py +15 -9
- warp/tests/test_large.py +17 -19
- warp/tests/test_launch.py +14 -17
- warp/tests/test_lerp.py +63 -63
- warp/tests/test_lvalue.py +84 -35
- warp/tests/test_marching_cubes.py +9 -13
- warp/tests/test_mat.py +388 -3004
- warp/tests/test_mat_lite.py +9 -12
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +10 -11
- warp/tests/test_matmul.py +104 -100
- warp/tests/test_matmul_lite.py +72 -98
- warp/tests/test_mesh.py +35 -32
- warp/tests/test_mesh_query_aabb.py +18 -25
- warp/tests/test_mesh_query_point.py +39 -23
- warp/tests/test_mesh_query_ray.py +9 -21
- warp/tests/test_mlp.py +8 -9
- warp/tests/test_model.py +89 -93
- warp/tests/test_modules_lite.py +15 -25
- warp/tests/test_multigpu.py +87 -114
- warp/tests/test_noise.py +10 -12
- warp/tests/test_operators.py +14 -21
- warp/tests/test_options.py +10 -11
- warp/tests/test_pinned.py +16 -18
- warp/tests/test_print.py +16 -20
- warp/tests/test_quat.py +121 -88
- warp/tests/test_rand.py +12 -13
- warp/tests/test_reload.py +27 -32
- warp/tests/test_rounding.py +7 -10
- warp/tests/test_runlength_encode.py +105 -106
- warp/tests/test_smoothstep.py +8 -9
- warp/tests/test_snippet.py +13 -22
- warp/tests/test_sparse.py +30 -29
- warp/tests/test_spatial.py +179 -174
- warp/tests/test_streams.py +100 -107
- warp/tests/test_struct.py +98 -67
- warp/tests/test_tape.py +11 -17
- warp/tests/test_torch.py +89 -86
- warp/tests/test_transient_module.py +9 -12
- warp/tests/test_types.py +328 -50
- warp/tests/test_utils.py +217 -218
- warp/tests/test_vec.py +133 -2133
- warp/tests/test_vec_lite.py +8 -11
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +391 -382
- warp/tests/test_volume_write.py +122 -135
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +291 -0
- warp/tests/{test_base.py → unittest_utils.py} +138 -25
- warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
- warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
- warp/thirdparty/unittest_parallel.py +257 -54
- warp/types.py +119 -98
- warp/utils.py +14 -0
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
- warp/tests/test_all.py +0 -239
- warp/tests/test_conditional_unequal_types_kernels.py +0 -14
- warp/tests/test_coverage.py +0 -38
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py
CHANGED
|
@@ -420,36 +420,6 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
|
|
|
420
420
|
return call
|
|
421
421
|
|
|
422
422
|
|
|
423
|
-
def get_integrate_null_kernel(
|
|
424
|
-
integrand_func: wp.Function,
|
|
425
|
-
domain: GeometryDomain,
|
|
426
|
-
quadrature: Quadrature,
|
|
427
|
-
FieldStruct: wp.codegen.Struct,
|
|
428
|
-
ValueStruct: wp.codegen.Struct,
|
|
429
|
-
):
|
|
430
|
-
def integrate_kernel_fn(
|
|
431
|
-
qp_arg: quadrature.Arg,
|
|
432
|
-
domain_arg: domain.ElementArg,
|
|
433
|
-
domain_index_arg: domain.ElementIndexArg,
|
|
434
|
-
fields: FieldStruct,
|
|
435
|
-
values: ValueStruct,
|
|
436
|
-
):
|
|
437
|
-
element_index = domain.element_index(domain_index_arg, wp.tid())
|
|
438
|
-
|
|
439
|
-
test_dof_index = NULL_DOF_INDEX
|
|
440
|
-
trial_dof_index = NULL_DOF_INDEX
|
|
441
|
-
|
|
442
|
-
qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
|
|
443
|
-
for k in range(qp_point_count):
|
|
444
|
-
qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
|
|
445
|
-
qp_coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
|
|
446
|
-
qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
|
|
447
|
-
sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
448
|
-
integrand_func(sample, fields, values)
|
|
449
|
-
|
|
450
|
-
return integrate_kernel_fn
|
|
451
|
-
|
|
452
|
-
|
|
453
423
|
def get_integrate_constant_kernel(
|
|
454
424
|
integrand_func: wp.Function,
|
|
455
425
|
domain: GeometryDomain,
|
|
@@ -477,7 +447,7 @@ def get_integrate_constant_kernel(
|
|
|
477
447
|
qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
|
|
478
448
|
coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
|
|
479
449
|
qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
|
|
480
|
-
|
|
450
|
+
|
|
481
451
|
sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
482
452
|
vol = domain.element_measure(domain_arg, sample)
|
|
483
453
|
|
|
@@ -497,6 +467,7 @@ def get_integrate_linear_kernel(
|
|
|
497
467
|
FieldStruct: wp.codegen.Struct,
|
|
498
468
|
ValueStruct: wp.codegen.Struct,
|
|
499
469
|
test: TestField,
|
|
470
|
+
output_dtype,
|
|
500
471
|
accumulate_dtype,
|
|
501
472
|
):
|
|
502
473
|
def integrate_kernel_fn(
|
|
@@ -506,32 +477,36 @@ def get_integrate_linear_kernel(
|
|
|
506
477
|
test_arg: test.space_restriction.NodeArg,
|
|
507
478
|
fields: FieldStruct,
|
|
508
479
|
values: ValueStruct,
|
|
509
|
-
result: wp.array2d(dtype=
|
|
480
|
+
result: wp.array2d(dtype=output_dtype),
|
|
510
481
|
):
|
|
511
|
-
local_node_index = wp.tid()
|
|
482
|
+
local_node_index, test_dof = wp.tid()
|
|
512
483
|
node_index = test.space_restriction.node_partition_index(test_arg, local_node_index)
|
|
513
484
|
element_count = test.space_restriction.node_element_count(test_arg, local_node_index)
|
|
514
485
|
|
|
515
486
|
trial_dof_index = NULL_DOF_INDEX
|
|
516
487
|
|
|
488
|
+
val_sum = accumulate_dtype(0.0)
|
|
489
|
+
|
|
517
490
|
for n in range(element_count):
|
|
518
491
|
node_element_index = test.space_restriction.node_element_index(test_arg, local_node_index, n)
|
|
519
492
|
element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
|
|
520
493
|
|
|
494
|
+
test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)
|
|
495
|
+
|
|
521
496
|
qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
|
|
522
497
|
for k in range(qp_point_count):
|
|
523
498
|
qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
|
|
524
|
-
|
|
525
|
-
|
|
499
|
+
qp_coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
|
|
526
500
|
qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
|
|
527
|
-
vol = domain.element_measure(domain_arg, make_free_sample(element_index, coords))
|
|
528
501
|
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
502
|
+
vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
|
|
503
|
+
|
|
504
|
+
sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
505
|
+
val = integrand_func(sample, fields, values)
|
|
533
506
|
|
|
534
|
-
|
|
507
|
+
val_sum += accumulate_dtype(qp_weight * vol * val)
|
|
508
|
+
|
|
509
|
+
result[node_index, test_dof] = output_dtype(val_sum)
|
|
535
510
|
|
|
536
511
|
return integrate_kernel_fn
|
|
537
512
|
|
|
@@ -542,6 +517,7 @@ def get_integrate_linear_nodal_kernel(
|
|
|
542
517
|
FieldStruct: wp.codegen.Struct,
|
|
543
518
|
ValueStruct: wp.codegen.Struct,
|
|
544
519
|
test: TestField,
|
|
520
|
+
output_dtype,
|
|
545
521
|
accumulate_dtype,
|
|
546
522
|
):
|
|
547
523
|
def integrate_kernel_fn(
|
|
@@ -550,7 +526,7 @@ def get_integrate_linear_nodal_kernel(
|
|
|
550
526
|
test_restriction_arg: test.space_restriction.NodeArg,
|
|
551
527
|
fields: FieldStruct,
|
|
552
528
|
values: ValueStruct,
|
|
553
|
-
result: wp.array2d(dtype=
|
|
529
|
+
result: wp.array2d(dtype=output_dtype),
|
|
554
530
|
):
|
|
555
531
|
local_node_index, dof = wp.tid()
|
|
556
532
|
|
|
@@ -595,7 +571,7 @@ def get_integrate_linear_nodal_kernel(
|
|
|
595
571
|
|
|
596
572
|
val_sum += accumulate_dtype(node_weight * vol * val)
|
|
597
573
|
|
|
598
|
-
result[node_index, dof] = val_sum
|
|
574
|
+
result[node_index, dof] = output_dtype(val_sum)
|
|
599
575
|
|
|
600
576
|
return integrate_kernel_fn
|
|
601
577
|
|
|
@@ -608,6 +584,7 @@ def get_integrate_bilinear_kernel(
|
|
|
608
584
|
ValueStruct: wp.codegen.Struct,
|
|
609
585
|
test: TestField,
|
|
610
586
|
trial: TrialField,
|
|
587
|
+
output_dtype,
|
|
611
588
|
accumulate_dtype,
|
|
612
589
|
):
|
|
613
590
|
NODES_PER_ELEMENT = trial.space.topology.NODES_PER_ELEMENT
|
|
@@ -624,19 +601,26 @@ def get_integrate_bilinear_kernel(
|
|
|
624
601
|
row_offsets: wp.array(dtype=int),
|
|
625
602
|
triplet_rows: wp.array(dtype=int),
|
|
626
603
|
triplet_cols: wp.array(dtype=int),
|
|
627
|
-
triplet_values: wp.array3d(dtype=
|
|
604
|
+
triplet_values: wp.array3d(dtype=output_dtype),
|
|
628
605
|
):
|
|
629
|
-
test_local_node_index = wp.tid()
|
|
606
|
+
test_local_node_index, trial_node, test_dof, trial_dof = wp.tid()
|
|
630
607
|
|
|
631
608
|
element_count = test.space_restriction.node_element_count(test_arg, test_local_node_index)
|
|
632
609
|
test_node_index = test.space_restriction.node_partition_index(test_arg, test_local_node_index)
|
|
633
610
|
|
|
611
|
+
trial_dof_index = DofIndex(trial_node, trial_dof)
|
|
612
|
+
|
|
634
613
|
for element in range(element_count):
|
|
635
614
|
test_element_index = test.space_restriction.node_element_index(test_arg, test_local_node_index, element)
|
|
636
615
|
element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
|
|
637
616
|
qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
|
|
638
617
|
|
|
639
|
-
|
|
618
|
+
test_dof_index = DofIndex(
|
|
619
|
+
test_element_index.node_index_in_element,
|
|
620
|
+
test_dof,
|
|
621
|
+
)
|
|
622
|
+
|
|
623
|
+
val_sum = accumulate_dtype(0.0)
|
|
640
624
|
|
|
641
625
|
for k in range(qp_point_count):
|
|
642
626
|
qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
|
|
@@ -645,42 +629,28 @@ def get_integrate_bilinear_kernel(
|
|
|
645
629
|
qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
|
|
646
630
|
vol = domain.element_measure(domain_arg, make_free_sample(element_index, coords))
|
|
647
631
|
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
trial_dof_index,
|
|
665
|
-
)
|
|
666
|
-
val = integrand_func(sample, fields, values)
|
|
667
|
-
triplet_values[offset_cur, i, j] = triplet_values[offset_cur, i, j] + accumulate_dtype(
|
|
668
|
-
qp_weight * vol * val
|
|
669
|
-
)
|
|
670
|
-
|
|
671
|
-
offset_cur += 1
|
|
672
|
-
|
|
673
|
-
# Set column indices
|
|
674
|
-
offset_cur = start_offset
|
|
675
|
-
for trial_n in range(NODES_PER_ELEMENT):
|
|
632
|
+
sample = Sample(
|
|
633
|
+
element_index,
|
|
634
|
+
coords,
|
|
635
|
+
qp_index,
|
|
636
|
+
qp_weight,
|
|
637
|
+
test_dof_index,
|
|
638
|
+
trial_dof_index,
|
|
639
|
+
)
|
|
640
|
+
val = integrand_func(sample, fields, values)
|
|
641
|
+
val_sum += accumulate_dtype(qp_weight * vol * val)
|
|
642
|
+
|
|
643
|
+
block_offset = (row_offsets[test_node_index] + element) * NODES_PER_ELEMENT + trial_node
|
|
644
|
+
triplet_values[block_offset, test_dof, trial_dof] = output_dtype(val_sum)
|
|
645
|
+
|
|
646
|
+
# Set row and column indices
|
|
647
|
+
if test_dof == 0 and trial_dof == 0:
|
|
676
648
|
trial_node_index = trial.space_partition.partition_node_index(
|
|
677
649
|
trial_partition_arg,
|
|
678
|
-
trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index,
|
|
650
|
+
trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
|
|
679
651
|
)
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
triplet_cols[offset_cur] = trial_node_index
|
|
683
|
-
offset_cur += 1
|
|
652
|
+
triplet_rows[block_offset] = test_node_index
|
|
653
|
+
triplet_cols[block_offset] = trial_node_index
|
|
684
654
|
|
|
685
655
|
return integrate_kernel_fn
|
|
686
656
|
|
|
@@ -691,6 +661,7 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
691
661
|
FieldStruct: wp.codegen.Struct,
|
|
692
662
|
ValueStruct: wp.codegen.Struct,
|
|
693
663
|
test: TestField,
|
|
664
|
+
output_dtype,
|
|
694
665
|
accumulate_dtype,
|
|
695
666
|
):
|
|
696
667
|
def integrate_kernel_fn(
|
|
@@ -701,7 +672,7 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
701
672
|
values: ValueStruct,
|
|
702
673
|
triplet_rows: wp.array(dtype=int),
|
|
703
674
|
triplet_cols: wp.array(dtype=int),
|
|
704
|
-
triplet_values: wp.array3d(dtype=
|
|
675
|
+
triplet_values: wp.array3d(dtype=output_dtype),
|
|
705
676
|
):
|
|
706
677
|
local_node_index, test_dof, trial_dof = wp.tid()
|
|
707
678
|
|
|
@@ -729,7 +700,6 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
729
700
|
node_element_index.node_index_in_element,
|
|
730
701
|
)
|
|
731
702
|
|
|
732
|
-
|
|
733
703
|
test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)
|
|
734
704
|
trial_dof_index = DofIndex(node_element_index.node_index_in_element, trial_dof)
|
|
735
705
|
|
|
@@ -746,7 +716,7 @@ def get_integrate_bilinear_nodal_kernel(
|
|
|
746
716
|
|
|
747
717
|
val_sum += accumulate_dtype(node_weight * vol * val)
|
|
748
718
|
|
|
749
|
-
triplet_values[local_node_index, test_dof, trial_dof] = val_sum
|
|
719
|
+
triplet_values[local_node_index, test_dof, trial_dof] = output_dtype(val_sum)
|
|
750
720
|
triplet_rows[local_node_index] = node_index
|
|
751
721
|
triplet_cols[local_node_index] = node_index
|
|
752
722
|
|
|
@@ -763,9 +733,12 @@ def _generate_integrate_kernel(
|
|
|
763
733
|
trial: Optional[TrialField],
|
|
764
734
|
trial_name: str,
|
|
765
735
|
fields: Dict[str, FieldLike],
|
|
736
|
+
output_dtype: type,
|
|
766
737
|
accumulate_dtype: type,
|
|
767
738
|
kernel_options: Dict[str, Any] = {},
|
|
768
739
|
) -> wp.Kernel:
|
|
740
|
+
output_dtype = wp.types.type_scalar_type(output_dtype)
|
|
741
|
+
|
|
769
742
|
# Extract field arguments from integrand
|
|
770
743
|
field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
|
|
771
744
|
integrand, fields=fields, domain=domain
|
|
@@ -775,7 +748,7 @@ def _generate_integrate_kernel(
|
|
|
775
748
|
ValueStruct = _gen_value_struct(value_args)
|
|
776
749
|
|
|
777
750
|
# Check if kernel exist in cache
|
|
778
|
-
kernel_suffix = f"_itg_{wp.types.type_typestr(accumulate_dtype)}_{domain.name}_{FieldStruct.key}"
|
|
751
|
+
kernel_suffix = f"_itg_{wp.types.type_typestr(output_dtype)}{wp.types.type_typestr(accumulate_dtype)}_{domain.name}_{FieldStruct.key}"
|
|
779
752
|
if nodal:
|
|
780
753
|
kernel_suffix += "_nodal"
|
|
781
754
|
else:
|
|
@@ -819,6 +792,7 @@ def _generate_integrate_kernel(
|
|
|
819
792
|
FieldStruct,
|
|
820
793
|
ValueStruct,
|
|
821
794
|
test=test,
|
|
795
|
+
output_dtype=output_dtype,
|
|
822
796
|
accumulate_dtype=accumulate_dtype,
|
|
823
797
|
)
|
|
824
798
|
else:
|
|
@@ -829,6 +803,7 @@ def _generate_integrate_kernel(
|
|
|
829
803
|
FieldStruct,
|
|
830
804
|
ValueStruct,
|
|
831
805
|
test=test,
|
|
806
|
+
output_dtype=output_dtype,
|
|
832
807
|
accumulate_dtype=accumulate_dtype,
|
|
833
808
|
)
|
|
834
809
|
else:
|
|
@@ -839,6 +814,7 @@ def _generate_integrate_kernel(
|
|
|
839
814
|
FieldStruct,
|
|
840
815
|
ValueStruct,
|
|
841
816
|
test=test,
|
|
817
|
+
output_dtype=output_dtype,
|
|
842
818
|
accumulate_dtype=accumulate_dtype,
|
|
843
819
|
)
|
|
844
820
|
else:
|
|
@@ -850,6 +826,7 @@ def _generate_integrate_kernel(
|
|
|
850
826
|
ValueStruct,
|
|
851
827
|
test=test,
|
|
852
828
|
trial=trial,
|
|
829
|
+
output_dtype=output_dtype,
|
|
853
830
|
accumulate_dtype=accumulate_dtype,
|
|
854
831
|
)
|
|
855
832
|
|
|
@@ -949,32 +926,46 @@ def _launch_integrate_kernel(
|
|
|
949
926
|
|
|
950
927
|
# Linear form
|
|
951
928
|
if trial is None:
|
|
952
|
-
if test.space.VALUE_DOF_COUNT == 1:
|
|
953
|
-
accumulate_array_dtype = accumulate_dtype
|
|
954
|
-
else:
|
|
955
|
-
accumulate_array_dtype = cache.cached_vec_type(length=test.space.VALUE_DOF_COUNT, dtype=accumulate_dtype)
|
|
956
|
-
|
|
957
|
-
if output is not None and output.size < test.space_partition.node_count():
|
|
958
|
-
raise RuntimeError(f"Output array must be of size at least {test.space_partition.node_count()}")
|
|
959
|
-
|
|
960
|
-
accumulate_in_place = wp.types.types_equal(accumulate_array_dtype, output_dtype)
|
|
961
|
-
|
|
962
929
|
# If an output array is provided with the correct type, accumulate directly into it
|
|
963
930
|
# Otherwise, grab a temporary array
|
|
964
|
-
if output is
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
931
|
+
if output is None:
|
|
932
|
+
if type_length(output_dtype) == test.space.VALUE_DOF_COUNT:
|
|
933
|
+
output_shape = (test.space_partition.node_count(),)
|
|
934
|
+
elif type_length(output_dtype) == 1:
|
|
935
|
+
output_shape = (test.space_partition.node_count(), test.space.VALUE_DOF_COUNT)
|
|
936
|
+
else:
|
|
937
|
+
raise RuntimeError(
|
|
938
|
+
f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
|
|
939
|
+
)
|
|
940
|
+
|
|
941
|
+
output_temporary = cache.borrow_temporary(
|
|
968
942
|
temporary_store=temporary_store,
|
|
969
|
-
shape=
|
|
970
|
-
dtype=
|
|
943
|
+
shape=output_shape,
|
|
944
|
+
dtype=output_dtype,
|
|
971
945
|
device=device,
|
|
972
|
-
requires_grad=output is not None and output.requires_grad,
|
|
973
946
|
)
|
|
974
|
-
|
|
947
|
+
|
|
948
|
+
output = output_temporary.array
|
|
949
|
+
|
|
950
|
+
else:
|
|
951
|
+
output_temporary = None
|
|
952
|
+
|
|
953
|
+
if output.shape[0] < test.space_partition.node_count():
|
|
954
|
+
raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")
|
|
955
|
+
|
|
956
|
+
output_dtype = output.dtype
|
|
957
|
+
if type_length(output_dtype) != test.space.VALUE_DOF_COUNT:
|
|
958
|
+
if type_length(output_dtype) != 1:
|
|
959
|
+
raise RuntimeError(
|
|
960
|
+
f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
|
|
961
|
+
)
|
|
962
|
+
if output.ndim != 2 and output.shape[1] != test.space.VALUE_DOF_COUNT:
|
|
963
|
+
raise RuntimeError(
|
|
964
|
+
f"Incompatible output array shape, last dimension must be of size {test.space.VALUE_DOF_COUNT}"
|
|
965
|
+
)
|
|
975
966
|
|
|
976
967
|
# Launch the integration on the kernel on a 2d scalar view of the actual array
|
|
977
|
-
|
|
968
|
+
output.zero_()
|
|
978
969
|
|
|
979
970
|
def as_2d_array(array):
|
|
980
971
|
return wp.array(
|
|
@@ -984,11 +975,11 @@ def _launch_integrate_kernel(
|
|
|
984
975
|
owner=False,
|
|
985
976
|
device=array.device,
|
|
986
977
|
shape=(test.space_partition.node_count(), test.space.VALUE_DOF_COUNT),
|
|
987
|
-
dtype=
|
|
978
|
+
dtype=wp.types.type_scalar_type(output_dtype),
|
|
988
979
|
grad=None if array.grad is None else as_2d_array(array.grad),
|
|
989
980
|
)
|
|
990
981
|
|
|
991
|
-
|
|
982
|
+
output_view = output if output.ndim == 2 else as_2d_array(output)
|
|
992
983
|
|
|
993
984
|
if nodal:
|
|
994
985
|
wp.launch(
|
|
@@ -1000,14 +991,14 @@ def _launch_integrate_kernel(
|
|
|
1000
991
|
test_arg,
|
|
1001
992
|
field_arg_values,
|
|
1002
993
|
value_struct_values,
|
|
1003
|
-
|
|
994
|
+
output_view,
|
|
1004
995
|
],
|
|
1005
996
|
device=device,
|
|
1006
997
|
)
|
|
1007
998
|
else:
|
|
1008
999
|
wp.launch(
|
|
1009
1000
|
kernel=kernel,
|
|
1010
|
-
dim=test.space_restriction.node_count(),
|
|
1001
|
+
dim=(test.space_restriction.node_count(), test.space.VALUE_DOF_COUNT),
|
|
1011
1002
|
inputs=[
|
|
1012
1003
|
qp_arg,
|
|
1013
1004
|
domain_elt_arg,
|
|
@@ -1015,47 +1006,23 @@ def _launch_integrate_kernel(
|
|
|
1015
1006
|
test_arg,
|
|
1016
1007
|
field_arg_values,
|
|
1017
1008
|
value_struct_values,
|
|
1018
|
-
|
|
1009
|
+
output_view,
|
|
1019
1010
|
],
|
|
1020
1011
|
device=device,
|
|
1021
1012
|
)
|
|
1022
1013
|
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
return output
|
|
1026
|
-
|
|
1027
|
-
if accumulate_in_place:
|
|
1028
|
-
return accumulate_temporary.detach()
|
|
1014
|
+
if output_temporary is not None:
|
|
1015
|
+
return output_temporary.detach()
|
|
1029
1016
|
|
|
1030
|
-
|
|
1031
|
-
if output is not None:
|
|
1032
|
-
cast_result = output
|
|
1033
|
-
elif type_length(output_dtype) == test.space.VALUE_DOF_COUNT:
|
|
1034
|
-
cast_result = wp.empty(
|
|
1035
|
-
dtype=output_dtype,
|
|
1036
|
-
shape=accumulate_array.shape,
|
|
1037
|
-
device=device,
|
|
1038
|
-
requires_grad=accumulate_array.requires_grad,
|
|
1039
|
-
)
|
|
1040
|
-
else:
|
|
1041
|
-
cast_result = wp.empty(
|
|
1042
|
-
dtype=output_dtype,
|
|
1043
|
-
shape=accumulate_2d_view.shape,
|
|
1044
|
-
device=device,
|
|
1045
|
-
requires_grad=accumulate_array.requires_grad,
|
|
1046
|
-
)
|
|
1047
|
-
|
|
1048
|
-
array_cast(in_array=accumulate_array, out_array=cast_result)
|
|
1049
|
-
accumulate_temporary.release() # Do not wait for garbage collection
|
|
1050
|
-
return cast_result
|
|
1017
|
+
return output
|
|
1051
1018
|
|
|
1052
1019
|
# Bilinear form
|
|
1053
1020
|
|
|
1054
1021
|
if test.space.VALUE_DOF_COUNT == 1 and trial.space.VALUE_DOF_COUNT == 1:
|
|
1055
|
-
block_type =
|
|
1022
|
+
block_type = output_dtype
|
|
1056
1023
|
else:
|
|
1057
1024
|
block_type = cache.cached_mat_type(
|
|
1058
|
-
shape=(test.space.VALUE_DOF_COUNT, trial.space.VALUE_DOF_COUNT), dtype=
|
|
1025
|
+
shape=(test.space.VALUE_DOF_COUNT, trial.space.VALUE_DOF_COUNT), dtype=output_dtype
|
|
1059
1026
|
)
|
|
1060
1027
|
|
|
1061
1028
|
if nodal:
|
|
@@ -1072,7 +1039,7 @@ def _launch_integrate_kernel(
|
|
|
1072
1039
|
test.space.VALUE_DOF_COUNT,
|
|
1073
1040
|
trial.space.VALUE_DOF_COUNT,
|
|
1074
1041
|
),
|
|
1075
|
-
dtype=
|
|
1042
|
+
dtype=output_dtype,
|
|
1076
1043
|
device=device,
|
|
1077
1044
|
)
|
|
1078
1045
|
triplet_cols = triplet_cols_temp.array
|
|
@@ -1105,7 +1072,12 @@ def _launch_integrate_kernel(
|
|
|
1105
1072
|
trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
|
|
1106
1073
|
wp.launch(
|
|
1107
1074
|
kernel=kernel,
|
|
1108
|
-
dim=
|
|
1075
|
+
dim=(
|
|
1076
|
+
test.space_restriction.node_count(),
|
|
1077
|
+
trial.space.topology.NODES_PER_ELEMENT,
|
|
1078
|
+
test.space.VALUE_DOF_COUNT,
|
|
1079
|
+
trial.space.VALUE_DOF_COUNT,
|
|
1080
|
+
),
|
|
1109
1081
|
inputs=[
|
|
1110
1082
|
qp_arg,
|
|
1111
1083
|
domain_elt_arg,
|
|
@@ -1123,38 +1095,27 @@ def _launch_integrate_kernel(
|
|
|
1123
1095
|
device=device,
|
|
1124
1096
|
)
|
|
1125
1097
|
|
|
1126
|
-
compress_in_place = accumulate_dtype == output_dtype
|
|
1127
|
-
|
|
1128
1098
|
if output is not None:
|
|
1129
1099
|
if output.nrow != test.space_partition.node_count() or output.ncol != trial.space_partition.node_count():
|
|
1130
1100
|
raise RuntimeError(
|
|
1131
1101
|
f"Output matrix must have {test.space_partition.node_count()} rows and {trial.space_partition.node_count()} columns of blocks"
|
|
1132
1102
|
)
|
|
1133
1103
|
|
|
1134
|
-
if output is not None and compress_in_place:
|
|
1135
|
-
bsr_matrix = output
|
|
1136
1104
|
else:
|
|
1137
|
-
|
|
1105
|
+
output = bsr_zeros(
|
|
1138
1106
|
rows_of_blocks=test.space_partition.node_count(),
|
|
1139
1107
|
cols_of_blocks=trial.space_partition.node_count(),
|
|
1140
1108
|
block_type=block_type,
|
|
1141
1109
|
device=device,
|
|
1142
1110
|
)
|
|
1143
1111
|
|
|
1144
|
-
bsr_set_from_triplets(
|
|
1112
|
+
bsr_set_from_triplets(output, triplet_rows, triplet_cols, triplet_values)
|
|
1145
1113
|
|
|
1146
1114
|
# Do not wait for garbage collection
|
|
1147
1115
|
triplet_values_temp.release()
|
|
1148
1116
|
triplet_rows_temp.release()
|
|
1149
1117
|
triplet_cols_temp.release()
|
|
1150
1118
|
|
|
1151
|
-
if compress_in_place:
|
|
1152
|
-
return bsr_matrix
|
|
1153
|
-
|
|
1154
|
-
if output is None:
|
|
1155
|
-
return bsr_copy(bsr_matrix, scalar_type=output_dtype)
|
|
1156
|
-
|
|
1157
|
-
bsr_assign(src=bsr_matrix, dest=output)
|
|
1158
1119
|
return output
|
|
1159
1120
|
|
|
1160
1121
|
|
|
@@ -1181,7 +1142,7 @@ def integrate(
|
|
|
1181
1142
|
quadrature: Quadrature formula. If None, deduced from domain and fields degree.
|
|
1182
1143
|
nodal: For linear or bilinear form only, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
|
|
1183
1144
|
fields: Discrete, test, and trial fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names.
|
|
1184
|
-
values: Additional variable values to be passed to the integrand, can
|
|
1145
|
+
values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launchs. Keys in the dictionary must match integrand parameter names.
|
|
1185
1146
|
temporary_store: shared pool from which to allocate temporary arrays
|
|
1186
1147
|
accumulate_dtype: Scalar type to be used for accumulating integration samples
|
|
1187
1148
|
output: Sparse matrix or warp array into which to store the result of the integration
|
|
@@ -1246,6 +1207,7 @@ def integrate(
|
|
|
1246
1207
|
trial_name=trial_name,
|
|
1247
1208
|
fields=fields,
|
|
1248
1209
|
accumulate_dtype=accumulate_dtype,
|
|
1210
|
+
output_dtype=output_dtype,
|
|
1249
1211
|
kernel_options=kernel_options,
|
|
1250
1212
|
)
|
|
1251
1213
|
|
|
@@ -1268,7 +1230,7 @@ def integrate(
|
|
|
1268
1230
|
)
|
|
1269
1231
|
|
|
1270
1232
|
|
|
1271
|
-
def
|
|
1233
|
+
def get_interpolate_to_field_function(
|
|
1272
1234
|
integrand_func: wp.Function,
|
|
1273
1235
|
domain: GeometryDomain,
|
|
1274
1236
|
FieldStruct: wp.codegen.Struct,
|
|
@@ -1277,7 +1239,8 @@ def get_interpolate_kernel(
|
|
|
1277
1239
|
):
|
|
1278
1240
|
value_type = dest.space.dtype
|
|
1279
1241
|
|
|
1280
|
-
def
|
|
1242
|
+
def interpolate_to_field_fn(
|
|
1243
|
+
local_node_index: int,
|
|
1281
1244
|
domain_arg: domain.ElementArg,
|
|
1282
1245
|
domain_index_arg: domain.ElementIndexArg,
|
|
1283
1246
|
dest_node_arg: dest.space_restriction.NodeArg,
|
|
@@ -1285,19 +1248,15 @@ def get_interpolate_kernel(
|
|
|
1285
1248
|
fields: FieldStruct,
|
|
1286
1249
|
values: ValueStruct,
|
|
1287
1250
|
):
|
|
1288
|
-
local_node_index = wp.tid()
|
|
1289
1251
|
node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
|
|
1290
|
-
|
|
1291
1252
|
element_count = dest.space_restriction.node_element_count(dest_node_arg, local_node_index)
|
|
1292
|
-
if element_count == 0:
|
|
1293
|
-
return
|
|
1294
1253
|
|
|
1295
1254
|
test_dof_index = NULL_DOF_INDEX
|
|
1296
1255
|
trial_dof_index = NULL_DOF_INDEX
|
|
1297
1256
|
node_weight = 1.0
|
|
1298
1257
|
|
|
1299
|
-
# Volume-weighted average
|
|
1300
|
-
# Superfluous if the function is continuous, but
|
|
1258
|
+
# Volume-weighted average across elements
|
|
1259
|
+
# Superfluous if the interpolated function is continuous, but helpful for visualizing discontinuous spaces
|
|
1301
1260
|
|
|
1302
1261
|
val_sum = value_type(0.0)
|
|
1303
1262
|
vol_sum = float(0.0)
|
|
@@ -1328,15 +1287,112 @@ def get_interpolate_kernel(
|
|
|
1328
1287
|
vol_sum += vol
|
|
1329
1288
|
val_sum += vol * val
|
|
1330
1289
|
|
|
1290
|
+
return val_sum, vol_sum
|
|
1291
|
+
|
|
1292
|
+
return interpolate_to_field_fn
|
|
1293
|
+
|
|
1294
|
+
|
|
1295
|
+
def get_interpolate_to_field_kernel(
|
|
1296
|
+
interpolate_to_field_fn: wp.Function,
|
|
1297
|
+
domain: GeometryDomain,
|
|
1298
|
+
FieldStruct: wp.codegen.Struct,
|
|
1299
|
+
ValueStruct: wp.codegen.Struct,
|
|
1300
|
+
dest: FieldRestriction,
|
|
1301
|
+
):
|
|
1302
|
+
def interpolate_to_field_kernel_fn(
|
|
1303
|
+
domain_arg: domain.ElementArg,
|
|
1304
|
+
domain_index_arg: domain.ElementIndexArg,
|
|
1305
|
+
dest_node_arg: dest.space_restriction.NodeArg,
|
|
1306
|
+
dest_eval_arg: dest.field.EvalArg,
|
|
1307
|
+
fields: FieldStruct,
|
|
1308
|
+
values: ValueStruct,
|
|
1309
|
+
):
|
|
1310
|
+
local_node_index = wp.tid()
|
|
1311
|
+
|
|
1312
|
+
val_sum, vol_sum = interpolate_to_field_fn(
|
|
1313
|
+
local_node_index, domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, fields, values
|
|
1314
|
+
)
|
|
1315
|
+
|
|
1331
1316
|
if vol_sum > 0.0:
|
|
1317
|
+
node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
|
|
1332
1318
|
dest.field.set_node_value(dest_eval_arg, node_index, val_sum / vol_sum)
|
|
1333
1319
|
|
|
1334
|
-
return
|
|
1320
|
+
return interpolate_to_field_kernel_fn
|
|
1321
|
+
|
|
1322
|
+
|
|
1323
|
+
def get_interpolate_to_array_kernel(
|
|
1324
|
+
integrand_func: wp.Function,
|
|
1325
|
+
domain: GeometryDomain,
|
|
1326
|
+
quadrature: Quadrature,
|
|
1327
|
+
FieldStruct: wp.codegen.Struct,
|
|
1328
|
+
ValueStruct: wp.codegen.Struct,
|
|
1329
|
+
value_type: type,
|
|
1330
|
+
):
|
|
1331
|
+
def interpolate_to_array_kernel_fn(
|
|
1332
|
+
qp_arg: quadrature.Arg,
|
|
1333
|
+
domain_arg: quadrature.domain.ElementArg,
|
|
1334
|
+
domain_index_arg: quadrature.domain.ElementIndexArg,
|
|
1335
|
+
fields: FieldStruct,
|
|
1336
|
+
values: ValueStruct,
|
|
1337
|
+
result: wp.array(dtype=value_type),
|
|
1338
|
+
):
|
|
1339
|
+
element_index = domain.element_index(domain_index_arg, wp.tid())
|
|
1340
|
+
|
|
1341
|
+
test_dof_index = NULL_DOF_INDEX
|
|
1342
|
+
trial_dof_index = NULL_DOF_INDEX
|
|
1343
|
+
|
|
1344
|
+
qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
|
|
1345
|
+
for k in range(qp_point_count):
|
|
1346
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
|
|
1347
|
+
coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
|
|
1348
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
|
|
1349
|
+
|
|
1350
|
+
sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
1351
|
+
|
|
1352
|
+
result[qp_index] = integrand_func(sample, fields, values)
|
|
1353
|
+
|
|
1354
|
+
return interpolate_to_array_kernel_fn
|
|
1355
|
+
|
|
1356
|
+
|
|
1357
|
+
def get_interpolate_nonvalued_kernel(
|
|
1358
|
+
integrand_func: wp.Function,
|
|
1359
|
+
domain: GeometryDomain,
|
|
1360
|
+
quadrature: Quadrature,
|
|
1361
|
+
FieldStruct: wp.codegen.Struct,
|
|
1362
|
+
ValueStruct: wp.codegen.Struct,
|
|
1363
|
+
):
|
|
1364
|
+
def interpolate_nonvalued_kernel_fn(
|
|
1365
|
+
qp_arg: quadrature.Arg,
|
|
1366
|
+
domain_arg: quadrature.domain.ElementArg,
|
|
1367
|
+
domain_index_arg: quadrature.domain.ElementIndexArg,
|
|
1368
|
+
fields: FieldStruct,
|
|
1369
|
+
values: ValueStruct,
|
|
1370
|
+
):
|
|
1371
|
+
element_index = domain.element_index(domain_index_arg, wp.tid())
|
|
1335
1372
|
|
|
1373
|
+
test_dof_index = NULL_DOF_INDEX
|
|
1374
|
+
trial_dof_index = NULL_DOF_INDEX
|
|
1336
1375
|
|
|
1337
|
-
|
|
1338
|
-
|
|
1376
|
+
qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
|
|
1377
|
+
for k in range(qp_point_count):
|
|
1378
|
+
qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
|
|
1379
|
+
coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
|
|
1380
|
+
qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
|
|
1339
1381
|
|
|
1382
|
+
sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
|
|
1383
|
+
integrand_func(sample, fields, values)
|
|
1384
|
+
|
|
1385
|
+
return interpolate_nonvalued_kernel_fn
|
|
1386
|
+
|
|
1387
|
+
|
|
1388
|
+
def _generate_interpolate_kernel(
|
|
1389
|
+
integrand: Integrand,
|
|
1390
|
+
domain: GeometryDomain,
|
|
1391
|
+
dest: Optional[Union[FieldLike, wp.array]],
|
|
1392
|
+
quadrature: Optional[Quadrature],
|
|
1393
|
+
fields: Dict[str, FieldLike],
|
|
1394
|
+
kernel_options: Dict[str, Any] = {},
|
|
1395
|
+
) -> wp.Kernel:
|
|
1340
1396
|
# Extract field arguments from integrand
|
|
1341
1397
|
field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
|
|
1342
1398
|
integrand, fields=fields, domain=domain
|
|
@@ -1354,9 +1410,14 @@ def _generate_interpolate_kernel(integrand: Integrand, dest: FieldLike, fields:
|
|
|
1354
1410
|
ValueStruct = _gen_value_struct(value_args)
|
|
1355
1411
|
|
|
1356
1412
|
# Check if kernel exist in cache
|
|
1357
|
-
|
|
1358
|
-
|
|
1359
|
-
|
|
1413
|
+
if isinstance(dest, FieldRestriction):
|
|
1414
|
+
kernel_suffix = (
|
|
1415
|
+
f"_itp_{FieldStruct.key}_{dest.domain.name}_{dest.space_restriction.space_partition.name}_{dest.space.name}"
|
|
1416
|
+
)
|
|
1417
|
+
elif wp.types.is_array(dest):
|
|
1418
|
+
kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}_{wp.types.type_repr(dest.dtype)}"
|
|
1419
|
+
else:
|
|
1420
|
+
kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}"
|
|
1360
1421
|
|
|
1361
1422
|
kernel = cache.get_integrand_kernel(
|
|
1362
1423
|
integrand=integrand,
|
|
@@ -1366,18 +1427,61 @@ def _generate_interpolate_kernel(integrand: Integrand, dest: FieldLike, fields:
|
|
|
1366
1427
|
return kernel, FieldStruct, ValueStruct
|
|
1367
1428
|
|
|
1368
1429
|
# Generate interpolation kernel
|
|
1369
|
-
|
|
1370
|
-
|
|
1371
|
-
|
|
1372
|
-
|
|
1373
|
-
|
|
1374
|
-
|
|
1375
|
-
|
|
1430
|
+
if isinstance(dest, FieldRestriction):
|
|
1431
|
+
# need to split into kernel + function for differentiability
|
|
1432
|
+
interpolate_fn = get_interpolate_to_field_function(
|
|
1433
|
+
integrand_func,
|
|
1434
|
+
domain,
|
|
1435
|
+
dest=dest,
|
|
1436
|
+
FieldStruct=FieldStruct,
|
|
1437
|
+
ValueStruct=ValueStruct,
|
|
1438
|
+
)
|
|
1439
|
+
|
|
1440
|
+
interpolate_fn = cache.get_integrand_function(
|
|
1441
|
+
integrand=integrand,
|
|
1442
|
+
func=interpolate_fn,
|
|
1443
|
+
suffix=kernel_suffix,
|
|
1444
|
+
code_transformers=[
|
|
1445
|
+
PassFieldArgsToIntegrand(
|
|
1446
|
+
arg_names=integrand.argspec.args,
|
|
1447
|
+
field_args=field_args.keys(),
|
|
1448
|
+
value_args=value_args.keys(),
|
|
1449
|
+
sample_name=sample_name,
|
|
1450
|
+
domain_name=domain_name,
|
|
1451
|
+
)
|
|
1452
|
+
],
|
|
1453
|
+
)
|
|
1454
|
+
|
|
1455
|
+
interpolate_kernel_fn = get_interpolate_to_field_kernel(
|
|
1456
|
+
interpolate_fn,
|
|
1457
|
+
domain,
|
|
1458
|
+
dest=dest,
|
|
1459
|
+
FieldStruct=FieldStruct,
|
|
1460
|
+
ValueStruct=ValueStruct,
|
|
1461
|
+
)
|
|
1462
|
+
elif wp.types.is_array(dest):
|
|
1463
|
+
interpolate_kernel_fn = get_interpolate_to_array_kernel(
|
|
1464
|
+
integrand_func,
|
|
1465
|
+
domain=domain,
|
|
1466
|
+
quadrature=quadrature,
|
|
1467
|
+
value_type=dest.dtype,
|
|
1468
|
+
FieldStruct=FieldStruct,
|
|
1469
|
+
ValueStruct=ValueStruct,
|
|
1470
|
+
)
|
|
1471
|
+
else:
|
|
1472
|
+
interpolate_kernel_fn = get_interpolate_nonvalued_kernel(
|
|
1473
|
+
integrand_func,
|
|
1474
|
+
domain=domain,
|
|
1475
|
+
quadrature=quadrature,
|
|
1476
|
+
FieldStruct=FieldStruct,
|
|
1477
|
+
ValueStruct=ValueStruct,
|
|
1478
|
+
)
|
|
1376
1479
|
|
|
1377
1480
|
kernel = cache.get_integrand_kernel(
|
|
1378
1481
|
integrand=integrand,
|
|
1379
1482
|
kernel_fn=interpolate_kernel_fn,
|
|
1380
1483
|
suffix=kernel_suffix,
|
|
1484
|
+
kernel_options=kernel_options,
|
|
1381
1485
|
code_transformers=[
|
|
1382
1486
|
PassFieldArgsToIntegrand(
|
|
1383
1487
|
arg_names=integrand.argspec.args,
|
|
@@ -1396,16 +1500,16 @@ def _launch_interpolate_kernel(
|
|
|
1396
1500
|
kernel: wp.kernel,
|
|
1397
1501
|
FieldStruct: wp.codegen.Struct,
|
|
1398
1502
|
ValueStruct: wp.codegen.Struct,
|
|
1399
|
-
|
|
1503
|
+
domain: GeometryDomain,
|
|
1504
|
+
dest: Optional[Union[FieldRestriction, wp.array]],
|
|
1505
|
+
quadrature: Optional[Quadrature],
|
|
1400
1506
|
fields: Dict[str, FieldLike],
|
|
1401
1507
|
values: Dict[str, Any],
|
|
1402
1508
|
device,
|
|
1403
1509
|
) -> wp.Kernel:
|
|
1404
1510
|
# Set-up launch arguments
|
|
1405
|
-
elt_arg =
|
|
1406
|
-
elt_index_arg =
|
|
1407
|
-
dest_node_arg = dest.space_restriction.node_arg(device=device)
|
|
1408
|
-
dest_eval_arg = dest.field.eval_arg_value(device=device)
|
|
1511
|
+
elt_arg = domain.element_arg_value(device=device)
|
|
1512
|
+
elt_index_arg = domain.element_index_arg_value(device=device)
|
|
1409
1513
|
|
|
1410
1514
|
field_arg_values = FieldStruct()
|
|
1411
1515
|
for k, v in fields.items():
|
|
@@ -1415,37 +1519,65 @@ def _launch_interpolate_kernel(
|
|
|
1415
1519
|
for k, v in values.items():
|
|
1416
1520
|
setattr(value_struct_values, k, v)
|
|
1417
1521
|
|
|
1418
|
-
|
|
1419
|
-
|
|
1420
|
-
|
|
1421
|
-
|
|
1422
|
-
|
|
1423
|
-
|
|
1424
|
-
|
|
1425
|
-
|
|
1426
|
-
|
|
1427
|
-
|
|
1428
|
-
|
|
1429
|
-
|
|
1430
|
-
|
|
1522
|
+
if isinstance(dest, FieldRestriction):
|
|
1523
|
+
dest_node_arg = dest.space_restriction.node_arg(device=device)
|
|
1524
|
+
dest_eval_arg = dest.field.eval_arg_value(device=device)
|
|
1525
|
+
|
|
1526
|
+
wp.launch(
|
|
1527
|
+
kernel=kernel,
|
|
1528
|
+
dim=dest.space_restriction.node_count(),
|
|
1529
|
+
inputs=[
|
|
1530
|
+
elt_arg,
|
|
1531
|
+
elt_index_arg,
|
|
1532
|
+
dest_node_arg,
|
|
1533
|
+
dest_eval_arg,
|
|
1534
|
+
field_arg_values,
|
|
1535
|
+
value_struct_values,
|
|
1536
|
+
],
|
|
1537
|
+
device=device,
|
|
1538
|
+
)
|
|
1539
|
+
elif wp.types.is_array(dest):
|
|
1540
|
+
qp_arg = quadrature.arg_value(device)
|
|
1541
|
+
wp.launch(
|
|
1542
|
+
kernel=kernel,
|
|
1543
|
+
dim=domain.element_count(),
|
|
1544
|
+
inputs=[qp_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
|
|
1545
|
+
device=device,
|
|
1546
|
+
)
|
|
1547
|
+
else:
|
|
1548
|
+
qp_arg = quadrature.arg_value(device)
|
|
1549
|
+
wp.launch(
|
|
1550
|
+
kernel=kernel,
|
|
1551
|
+
dim=domain.element_count(),
|
|
1552
|
+
inputs=[qp_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values],
|
|
1553
|
+
device=device,
|
|
1554
|
+
)
|
|
1431
1555
|
|
|
1432
1556
|
|
|
1433
1557
|
def interpolate(
|
|
1434
1558
|
integrand: Integrand,
|
|
1435
|
-
dest: Union[DiscreteField, FieldRestriction],
|
|
1559
|
+
dest: Optional[Union[DiscreteField, FieldRestriction, wp.array]] = None,
|
|
1560
|
+
quadrature: Optional[Quadrature] = None,
|
|
1436
1561
|
fields: Dict[str, FieldLike] = {},
|
|
1437
1562
|
values: Dict[str, Any] = {},
|
|
1438
1563
|
device=None,
|
|
1564
|
+
kernel_options: Dict[str, Any] = {},
|
|
1439
1565
|
):
|
|
1440
1566
|
"""
|
|
1441
|
-
Interpolates a function and assigns the result to a discrete field.
|
|
1567
|
+
Interpolates a function at a finite set of sample points and optionally assigns the result to a discrete field or a raw warp array.
|
|
1442
1568
|
|
|
1443
1569
|
Args:
|
|
1444
1570
|
integrand: Function to be interpolated, must have :func:`integrand` decorator
|
|
1445
|
-
dest:
|
|
1571
|
+
dest: Where to store the interpolation result. Can be either
|
|
1572
|
+
|
|
1573
|
+
- a :class:`DiscreteField`, or restriction of a discrete field to a domain (from :func:`make_restriction`). In this case, interpolation will be performed at each node.
|
|
1574
|
+
- a normal warp array. In this case, the `quadrature` argument defining the interpolation locations must be provided and the result of the `integrand` at each quadrature point will be assigned to the array.
|
|
1575
|
+
- ``None``. In this case, the `quadrature` argument must also be provided and the `integrand` function is responsible for dealing with the interpolation result.
|
|
1576
|
+
quadrature: Quadrature formula defining the interpolation samples if `dest` is not a discrete field or field restriction.
|
|
1446
1577
|
fields: Discrete fields to be passed to the integrand. Keys in the dictionary must match integrand parameters names.
|
|
1447
|
-
values: Additional variable values to be passed to the integrand, can
|
|
1578
|
+
values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
|
|
1448
1579
|
device: Device on which to perform the interpolation
|
|
1580
|
+
kernel_options: Overloaded options to be passed to the kernel builder (e.g., ``{"enable_backward": True}``)
|
|
1449
1581
|
"""
|
|
1450
1582
|
if not isinstance(integrand, Integrand):
|
|
1451
1583
|
raise ValueError("integrand must be tagged with @integrand decorator")
|
|
@@ -1454,20 +1586,33 @@ def interpolate(
|
|
|
1454
1586
|
if test is not None or trial is not None:
|
|
1455
1587
|
raise ValueError("Test or Trial fields should not be used for interpolation")
|
|
1456
1588
|
|
|
1457
|
-
if
|
|
1589
|
+
if isinstance(dest, DiscreteField):
|
|
1458
1590
|
dest = make_restriction(dest)
|
|
1459
1591
|
|
|
1592
|
+
if isinstance(dest, FieldRestriction):
|
|
1593
|
+
domain = dest.domain
|
|
1594
|
+
else:
|
|
1595
|
+
if quadrature is None:
|
|
1596
|
+
raise ValueError("When not interpolating to a field, a quadrature formula must be provided")
|
|
1597
|
+
|
|
1598
|
+
domain = quadrature.domain
|
|
1599
|
+
|
|
1460
1600
|
kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
|
|
1461
1601
|
integrand=integrand,
|
|
1602
|
+
domain=domain,
|
|
1462
1603
|
dest=dest,
|
|
1604
|
+
quadrature=quadrature,
|
|
1463
1605
|
fields=fields,
|
|
1606
|
+
kernel_options=kernel_options,
|
|
1464
1607
|
)
|
|
1465
1608
|
|
|
1466
1609
|
return _launch_interpolate_kernel(
|
|
1467
1610
|
kernel=kernel,
|
|
1468
1611
|
FieldStruct=FieldStruct,
|
|
1469
1612
|
ValueStruct=ValueStruct,
|
|
1613
|
+
domain=domain,
|
|
1470
1614
|
dest=dest,
|
|
1615
|
+
quadrature=quadrature,
|
|
1471
1616
|
fields=fields,
|
|
1472
1617
|
values=values,
|
|
1473
1618
|
device=device,
|