warp-lang 1.8.1__py3-none-manylinux_2_34_aarch64.whl → 1.9.1__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's release page for more details.

Files changed (141)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +1904 -114
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +331 -101
  7. warp/builtins.py +1244 -160
  8. warp/codegen.py +317 -206
  9. warp/config.py +1 -1
  10. warp/context.py +1465 -789
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_kernel.py +2 -1
  18. warp/fabric.py +1 -1
  19. warp/fem/cache.py +27 -19
  20. warp/fem/domain.py +2 -2
  21. warp/fem/field/nodal_field.py +2 -2
  22. warp/fem/field/virtual.py +264 -166
  23. warp/fem/geometry/geometry.py +5 -5
  24. warp/fem/integrate.py +129 -51
  25. warp/fem/space/restriction.py +4 -0
  26. warp/fem/space/shape/tet_shape_function.py +3 -10
  27. warp/jax_experimental/custom_call.py +25 -2
  28. warp/jax_experimental/ffi.py +22 -1
  29. warp/jax_experimental/xla_ffi.py +16 -7
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +99 -4
  32. warp/native/builtin.h +86 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +8 -2
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +41 -10
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +2 -2
  48. warp/native/mat.h +1910 -116
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +4 -2
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +331 -14
  59. warp/native/range.h +7 -1
  60. warp/native/reduce.cpp +10 -10
  61. warp/native/reduce.cu +13 -14
  62. warp/native/runlength_encode.cpp +2 -2
  63. warp/native/runlength_encode.cu +5 -5
  64. warp/native/scan.cpp +3 -3
  65. warp/native/scan.cu +4 -4
  66. warp/native/sort.cpp +10 -10
  67. warp/native/sort.cu +40 -31
  68. warp/native/sort.h +2 -0
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +13 -13
  71. warp/native/spatial.h +366 -17
  72. warp/native/temp_buffer.h +2 -2
  73. warp/native/tile.h +471 -82
  74. warp/native/vec.h +328 -14
  75. warp/native/volume.cpp +54 -54
  76. warp/native/volume.cu +1 -1
  77. warp/native/volume.h +2 -1
  78. warp/native/volume_builder.cu +30 -37
  79. warp/native/warp.cpp +150 -149
  80. warp/native/warp.cu +377 -216
  81. warp/native/warp.h +227 -226
  82. warp/optim/linear.py +736 -271
  83. warp/render/imgui_manager.py +289 -0
  84. warp/render/render_opengl.py +99 -18
  85. warp/render/render_usd.py +1 -0
  86. warp/sim/graph_coloring.py +2 -2
  87. warp/sparse.py +558 -175
  88. warp/tests/aux_test_module_aot.py +7 -0
  89. warp/tests/cuda/test_async.py +3 -3
  90. warp/tests/cuda/test_conditional_captures.py +101 -0
  91. warp/tests/geometry/test_hash_grid.py +38 -0
  92. warp/tests/geometry/test_marching_cubes.py +233 -12
  93. warp/tests/interop/test_jax.py +608 -28
  94. warp/tests/sim/test_coloring.py +6 -6
  95. warp/tests/test_array.py +58 -5
  96. warp/tests/test_codegen.py +4 -3
  97. warp/tests/test_context.py +8 -15
  98. warp/tests/test_enum.py +136 -0
  99. warp/tests/test_examples.py +2 -2
  100. warp/tests/test_fem.py +49 -6
  101. warp/tests/test_fixedarray.py +229 -0
  102. warp/tests/test_func.py +18 -15
  103. warp/tests/test_future_annotations.py +7 -5
  104. warp/tests/test_linear_solvers.py +30 -0
  105. warp/tests/test_map.py +15 -1
  106. warp/tests/test_mat.py +1518 -378
  107. warp/tests/test_mat_assign_copy.py +178 -0
  108. warp/tests/test_mat_constructors.py +574 -0
  109. warp/tests/test_module_aot.py +287 -0
  110. warp/tests/test_print.py +69 -0
  111. warp/tests/test_quat.py +140 -34
  112. warp/tests/test_quat_assign_copy.py +145 -0
  113. warp/tests/test_reload.py +2 -1
  114. warp/tests/test_sparse.py +71 -0
  115. warp/tests/test_spatial.py +140 -34
  116. warp/tests/test_spatial_assign_copy.py +160 -0
  117. warp/tests/test_struct.py +43 -3
  118. warp/tests/test_tuple.py +96 -0
  119. warp/tests/test_types.py +61 -20
  120. warp/tests/test_vec.py +179 -34
  121. warp/tests/test_vec_assign_copy.py +143 -0
  122. warp/tests/tile/test_tile.py +245 -18
  123. warp/tests/tile/test_tile_cholesky.py +605 -0
  124. warp/tests/tile/test_tile_load.py +169 -0
  125. warp/tests/tile/test_tile_mathdx.py +2 -558
  126. warp/tests/tile/test_tile_matmul.py +1 -1
  127. warp/tests/tile/test_tile_mlp.py +1 -1
  128. warp/tests/tile/test_tile_shared_memory.py +5 -5
  129. warp/tests/unittest_suites.py +6 -0
  130. warp/tests/walkthrough_debug.py +1 -1
  131. warp/thirdparty/unittest_parallel.py +108 -9
  132. warp/types.py +571 -267
  133. warp/utils.py +68 -86
  134. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/METADATA +29 -69
  135. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/RECORD +138 -128
  136. warp/native/marching.cpp +0 -19
  137. warp/native/marching.cu +0 -514
  138. warp/native/marching.h +0 -19
  139. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/WHEEL +0 -0
  140. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/licenses/LICENSE.md +0 -0
  141. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/top_level.txt +0 -0
warp/fem/field/virtual.py CHANGED
@@ -13,7 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from typing import Any, ClassVar, Dict, Set
16
+ from typing import Any, ClassVar, Dict, Optional, Set
17
17
 
18
18
  import warp as wp
19
19
  import warp.fem.operator as operator
@@ -22,7 +22,16 @@ from warp.fem.domain import GeometryDomain
22
22
  from warp.fem.linalg import basis_coefficient, generalized_inner, generalized_outer
23
23
  from warp.fem.quadrature import Quadrature
24
24
  from warp.fem.space import FunctionSpace, SpacePartition, SpaceRestriction
25
- from warp.fem.types import NULL_NODE_INDEX, DofIndex, Sample, get_node_coord, get_node_index_in_element
25
+ from warp.fem.types import (
26
+ NULL_ELEMENT_INDEX,
27
+ NULL_NODE_INDEX,
28
+ DofIndex,
29
+ ElementIndex,
30
+ NodeElementIndex,
31
+ Sample,
32
+ get_node_coord,
33
+ get_node_index_in_element,
34
+ )
26
35
  from warp.fem.utils import type_zero_element
27
36
 
28
37
  from .field import SpaceField
@@ -567,7 +576,13 @@ class LocalTrialField(LocalAdjointField):
567
576
  return s.trial_dof
568
577
 
569
578
 
570
- def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, accumulate_dtype: type):
579
+ def make_linear_dispatch_kernel(
580
+ test: LocalTestField,
581
+ quadrature: Quadrature,
582
+ accumulate_dtype: type,
583
+ tile_size: int = 1,
584
+ kernel_options: Optional[Dict[str, Any]] = None,
585
+ ):
571
586
  global_test: TestField = test.global_field
572
587
  space_restriction = global_test.space_restriction
573
588
  domain = global_test.domain
@@ -583,8 +598,42 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
583
598
  TEST_OUTER_GRAD_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF]
584
599
 
585
600
  TEST_NODE_DOF_DIM = test.value_dof_count // test.node_dof_count
601
+ TEST_NODE_DOF_COUNT = test.node_dof_count
602
+
603
+ res_vec = cache.cached_vec_type(length=test.node_dof_count, dtype=accumulate_dtype)
604
+ qp_vec = cache.cached_vec_type(length=test.node_dof_count, dtype=float)
605
+
606
+ @cache.dynamic_func(f"{test.name}_{quadrature.name}")
607
+ def next_qp(
608
+ qp: int,
609
+ elem_offset: int,
610
+ qp_point_count: int,
611
+ element_index: ElementIndex,
612
+ test_element_index: NodeElementIndex,
613
+ element_end: int,
614
+ qp_arg: quadrature.Arg,
615
+ domain_arg: domain.ElementArg,
616
+ domain_index_arg: domain.ElementIndexArg,
617
+ test_arg: space_restriction.NodeArg,
618
+ ):
619
+ while qp >= qp_point_count and elem_offset < element_end:
620
+ # Next element
621
+ elem_offset += 1
622
+
623
+ if elem_offset < element_end:
624
+ qp -= qp_point_count
625
+ test_element_index = space_restriction.node_element_index(test_arg, elem_offset)
626
+ element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
627
+ qp_point_count = quadrature.point_count(
628
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index
629
+ )
586
630
 
587
- @cache.dynamic_kernel(f"{test.name}_{quadrature.name}_{wp.types.get_type_code(accumulate_dtype)}")
631
+ return qp, elem_offset, qp_point_count, element_index, test_element_index
632
+
633
+ @cache.dynamic_kernel(
634
+ f"{test.name}_{quadrature.name}_{wp.types.get_type_code(accumulate_dtype)}_{tile_size}",
635
+ kernel_options=kernel_options,
636
+ )
588
637
  def dispatch_linear_kernel_fn(
589
638
  qp_arg: quadrature.Arg,
590
639
  domain_arg: domain.ElementArg,
@@ -594,33 +643,47 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
594
643
  local_result: wp.array3d(dtype=Any),
595
644
  result: wp.array2d(dtype=Any),
596
645
  ):
597
- local_node_index, test_node_dof = wp.tid()
646
+ local_node_index, lane = wp.tid()
647
+
598
648
  node_index = space_restriction.node_partition_index(test_arg, local_node_index)
599
649
  element_beg, element_end = space_restriction.node_element_range(test_arg, node_index)
600
650
 
601
- val_sum = accumulate_dtype(0.0)
602
-
603
- for n in range(element_beg, element_end):
604
- test_element_index = space_restriction.node_element_index(test_arg, n)
605
- element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
606
-
607
- qp_point_count = quadrature.point_count(
608
- domain_arg, qp_arg, test_element_index.domain_element_index, element_index
651
+ val_sum = res_vec()
652
+
653
+ elem_offset = element_beg - 1
654
+ qp_point_count = int(0)
655
+ qp = lane
656
+ test_element_index = NodeElementIndex()
657
+ element_index = ElementIndex(NULL_ELEMENT_INDEX)
658
+
659
+ while elem_offset < element_end:
660
+ qp, elem_offset, qp_point_count, element_index, test_element_index = next_qp(
661
+ qp,
662
+ elem_offset,
663
+ qp_point_count,
664
+ element_index,
665
+ test_element_index,
666
+ element_end,
667
+ qp_arg,
668
+ domain_arg,
669
+ domain_index_arg,
670
+ test_arg,
609
671
  )
610
- for k in range(qp_point_count):
672
+
673
+ if qp < qp_point_count:
611
674
  qp_index = quadrature.point_index(
612
- domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
675
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, qp
613
676
  )
614
677
  qp_eval_index = quadrature.point_evaluation_index(
615
- domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
678
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, qp
616
679
  )
617
680
  coords = quadrature.point_coords(
618
- domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
681
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, qp
619
682
  )
620
683
 
621
684
  qp_result = local_result[qp_eval_index]
622
685
 
623
- qp_sum = float(0.0)
686
+ qp_sum = qp_vec()
624
687
 
625
688
  if wp.static(0 != TEST_INNER_COUNT):
626
689
  w = test.space.element_inner_weight(
@@ -631,9 +694,12 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
631
694
  test_element_index.node_index_in_element,
632
695
  qp_index,
633
696
  )
634
- for val_dof in range(TEST_NODE_DOF_DIM):
635
- test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
636
- qp_sum += basis_coefficient(w, val_dof) * qp_result[TEST_INNER_BEGIN, test_dof]
697
+ for test_node_dof in range(TEST_NODE_DOF_COUNT):
698
+ for val_dof in range(TEST_NODE_DOF_DIM):
699
+ test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
700
+ qp_sum[test_node_dof] += (
701
+ basis_coefficient(w, val_dof) * qp_result[TEST_INNER_BEGIN, test_dof]
702
+ )
637
703
 
638
704
  if wp.static(0 != TEST_OUTER_COUNT):
639
705
  w = test.space.element_outer_weight(
@@ -644,9 +710,12 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
644
710
  test_element_index.node_index_in_element,
645
711
  qp_index,
646
712
  )
647
- for val_dof in range(TEST_NODE_DOF_DIM):
648
- test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
649
- qp_sum += basis_coefficient(w, val_dof) * qp_result[TEST_OUTER_BEGIN, test_dof]
713
+ for test_node_dof in range(TEST_NODE_DOF_COUNT):
714
+ for val_dof in range(TEST_NODE_DOF_DIM):
715
+ test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
716
+ qp_sum[test_node_dof] += (
717
+ basis_coefficient(w, val_dof) * qp_result[TEST_OUTER_BEGIN, test_dof]
718
+ )
650
719
 
651
720
  if wp.static(0 != TEST_INNER_GRAD_COUNT):
652
721
  w_grad = test.space.element_inner_weight_gradient(
@@ -657,13 +726,14 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
657
726
  test_element_index.node_index_in_element,
658
727
  qp_index,
659
728
  )
660
- for val_dof in range(TEST_NODE_DOF_DIM):
661
- test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
662
- for grad_dof in range(TEST_INNER_GRAD_COUNT):
663
- qp_sum += (
664
- basis_coefficient(w_grad, val_dof, grad_dof)
665
- * qp_result[grad_dof + TEST_INNER_GRAD_BEGIN, test_dof]
666
- )
729
+ for test_node_dof in range(TEST_NODE_DOF_COUNT):
730
+ for val_dof in range(TEST_NODE_DOF_DIM):
731
+ test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
732
+ for grad_dof in range(TEST_INNER_GRAD_COUNT):
733
+ qp_sum[test_node_dof] += (
734
+ basis_coefficient(w_grad, val_dof, grad_dof)
735
+ * qp_result[grad_dof + TEST_INNER_GRAD_BEGIN, test_dof]
736
+ )
667
737
 
668
738
  if wp.static(0 != TEST_OUTER_GRAD_COUNT):
669
739
  w_grad = test.space.element_outer_weight_gradient(
@@ -674,23 +744,36 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
674
744
  test_element_index.node_index_in_element,
675
745
  qp_index,
676
746
  )
677
- for val_dof in range(TEST_NODE_DOF_DIM):
678
- test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
679
- for grad_dof in range(TEST_OUTER_GRAD_COUNT):
680
- qp_sum += (
681
- basis_coefficient(w_grad, val_dof, grad_dof)
682
- * qp_result[grad_dof + TEST_OUTER_GRAD_BEGIN, test_dof]
683
- )
747
+ for test_node_dof in range(TEST_NODE_DOF_COUNT):
748
+ for val_dof in range(TEST_NODE_DOF_DIM):
749
+ test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
750
+ for grad_dof in range(TEST_OUTER_GRAD_COUNT):
751
+ qp_sum[test_node_dof] += (
752
+ basis_coefficient(w_grad, val_dof, grad_dof)
753
+ * qp_result[grad_dof + TEST_OUTER_GRAD_BEGIN, test_dof]
754
+ )
684
755
 
685
- val_sum += accumulate_dtype(qp_sum)
756
+ val_sum += res_vec(qp_sum)
757
+ qp += wp.static(tile_size)
686
758
 
687
- result[node_index, test_node_dof] += result.dtype(val_sum)
759
+ if wp.static(tile_size == 1):
760
+ for test_node_dof in range(TEST_NODE_DOF_COUNT):
761
+ result[node_index, test_node_dof] += result.dtype(val_sum[test_node_dof])
762
+ else:
763
+ t_sum = wp.tile_sum(wp.tile(val_sum, preserve_type=True))[0]
764
+ for test_node_dof in range(lane, TEST_NODE_DOF_COUNT, wp.static(tile_size)):
765
+ result[node_index, test_node_dof] += result.dtype(t_sum[test_node_dof])
688
766
 
689
767
  return dispatch_linear_kernel_fn
690
768
 
691
769
 
692
770
  def make_bilinear_dispatch_kernel(
693
- test: LocalTestField, trial: LocalTrialField, quadrature: Quadrature, accumulate_dtype: type
771
+ test: LocalTestField,
772
+ trial: LocalTrialField,
773
+ quadrature: Quadrature,
774
+ accumulate_dtype: type,
775
+ tile_size: int = 1,
776
+ kernel_options: Optional[Dict[str, Any]] = None,
694
777
  ):
695
778
  global_test: TestField = test.global_field
696
779
  space_restriction = global_test.space_restriction
@@ -718,12 +801,24 @@ def make_bilinear_dispatch_kernel(
718
801
 
719
802
  TEST_NODE_DOF_DIM = test.value_dof_count // test.node_dof_count
720
803
  TRIAL_NODE_DOF_DIM = trial.value_dof_count // trial.node_dof_count
804
+ TEST_TRIAL_NODE_DOF_DIM = TEST_NODE_DOF_DIM * TRIAL_NODE_DOF_DIM
805
+
806
+ TEST_NODE_DOF_COUNT = test.node_dof_count
807
+ TRIAL_NODE_DOF_COUNT = trial.node_dof_count
808
+ TEST_TAYLOR_DOF_COUNT = test.TAYLOR_DOF_COUNT
809
+ TRIAL_TAYLOR_DOF_COUNT = trial.TAYLOR_DOF_COUNT
721
810
 
722
811
  MAX_NODES_PER_ELEMENT = trial.space.topology.MAX_NODES_PER_ELEMENT
723
812
 
724
813
  trial_dof_vec = cache.cached_vec_type(length=trial.TAYLOR_DOF_COUNT, dtype=float)
814
+ test_dof_vec = cache.cached_vec_type(length=test.TAYLOR_DOF_COUNT, dtype=float)
815
+
816
+ val_t = cache.cached_mat_type(shape=(test.node_dof_count, trial.node_dof_count), dtype=accumulate_dtype)
725
817
 
726
- @cache.dynamic_kernel(f"{trial.name}_{test.name}_{quadrature.name}{wp.types.get_type_code(accumulate_dtype)}")
818
+ @cache.dynamic_kernel(
819
+ f"{trial.name}_{test.name}_{quadrature.name}{wp.types.get_type_code(accumulate_dtype)}_{tile_size}",
820
+ kernel_options=kernel_options,
821
+ )
727
822
  def dispatch_bilinear_kernel_fn(
728
823
  qp_arg: quadrature.Arg,
729
824
  domain_arg: domain.ElementArg,
@@ -733,163 +828,166 @@ def make_bilinear_dispatch_kernel(
733
828
  trial_partition_arg: trial.space_partition.PartitionArg,
734
829
  trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
735
830
  trial_space_arg: trial.space.SpaceArg,
736
- local_result: wp.array4d(dtype=trial_dof_vec),
831
+ local_result: wp.array4d(dtype=float),
737
832
  triplet_rows: wp.array(dtype=int),
738
833
  triplet_cols: wp.array(dtype=int),
739
834
  triplet_values: wp.array3d(dtype=Any),
740
835
  ):
741
- test_local_node_index, test_node_dof, trial_node_dof, trial_node = wp.tid()
836
+ test_node_offset, trial_node, lane = wp.tid()
742
837
 
743
- test_node_index = space_restriction.node_partition_index(test_arg, test_local_node_index)
744
- element_beg, element_end = space_restriction.node_element_range(test_arg, test_node_index)
838
+ test_node_index = space_restriction.node_partition_index_from_element_offset(test_arg, test_node_offset)
745
839
 
746
- for element in range(element_beg, element_end):
747
- test_element_index = space_restriction.node_element_index(test_arg, element)
748
- element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
749
- test_node = test_element_index.node_index_in_element
840
+ test_element_index = space_restriction.node_element_index(test_arg, test_node_offset)
841
+ element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
842
+ test_node = test_element_index.node_index_in_element
750
843
 
751
- element_trial_node_count = trial.space.topology.element_node_count(
752
- domain_arg, trial_topology_arg, element_index
753
- )
844
+ element_trial_node_count = trial.space.topology.element_node_count(
845
+ domain_arg, trial_topology_arg, element_index
846
+ )
754
847
 
755
- qp_point_count = wp.where(
756
- trial_node < element_trial_node_count,
757
- quadrature.point_count(domain_arg, qp_arg, test_element_index.domain_element_index, element_index),
758
- 0,
848
+ if trial_node >= element_trial_node_count:
849
+ block_offset = test_node_offset * MAX_NODES_PER_ELEMENT + trial_node
850
+ triplet_rows[block_offset] = NULL_NODE_INDEX
851
+ triplet_cols[block_offset] = NULL_NODE_INDEX
852
+ return
853
+
854
+ qp_point_count = quadrature.point_count(
855
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index
856
+ )
857
+ qp_dof_count = qp_point_count * TEST_TRIAL_NODE_DOF_DIM
858
+
859
+ val_sum = val_t()
860
+
861
+ for dof in range(lane, qp_dof_count, wp.static(tile_size)):
862
+ k = dof // TEST_TRIAL_NODE_DOF_DIM
863
+ test_trial_val_dof = dof - k * TEST_TRIAL_NODE_DOF_DIM
864
+ test_val_dof = test_trial_val_dof // TRIAL_NODE_DOF_DIM
865
+ trial_val_dof = test_trial_val_dof - test_val_dof * TRIAL_NODE_DOF_DIM
866
+
867
+ qp_index = quadrature.point_index(
868
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
869
+ )
870
+ qp_eval_index = quadrature.point_evaluation_index(
871
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
872
+ )
873
+ coords = quadrature.point_coords(
874
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
759
875
  )
760
876
 
761
- val_sum = accumulate_dtype(0.0)
877
+ # test shape functions
878
+ w_test = test_dof_vec()
762
879
 
763
- for k in range(qp_point_count):
764
- qp_index = quadrature.point_index(
765
- domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
766
- )
767
- qp_eval_index = quadrature.point_evaluation_index(
768
- domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
769
- )
770
- coords = quadrature.point_coords(
771
- domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
880
+ if wp.static(0 != TEST_INNER_COUNT):
881
+ w_test_inner = test.space.element_inner_weight(
882
+ domain_arg, test_space_arg, element_index, coords, test_node, qp_index
772
883
  )
884
+ w_test[TEST_INNER_BEGIN] = basis_coefficient(w_test_inner, test_val_dof)
773
885
 
774
- qp_result = local_result[qp_eval_index]
775
- trial_result = float(0.0)
776
-
777
- if wp.static(0 != TEST_INNER_COUNT):
778
- w_test_inner = test.space.element_inner_weight(
779
- domain_arg, test_space_arg, element_index, coords, test_node, qp_index
780
- )
886
+ if wp.static(0 != TEST_OUTER_COUNT):
887
+ w_test_outer = test.space.element_outer_weight(
888
+ domain_arg, test_space_arg, element_index, coords, test_node, qp_index
889
+ )
890
+ w_test[TEST_OUTER_BEGIN] = basis_coefficient(w_test_outer, test_val_dof)
781
891
 
782
- if wp.static(0 != TEST_OUTER_COUNT):
783
- w_test_outer = test.space.element_outer_weight(
784
- domain_arg, test_space_arg, element_index, coords, test_node, qp_index
892
+ if wp.static(0 != TEST_INNER_GRAD_COUNT):
893
+ w_test_grad_inner = test.space.element_inner_weight_gradient(
894
+ domain_arg, test_space_arg, element_index, coords, test_node, qp_index
895
+ )
896
+ for grad_dof in range(TEST_INNER_GRAD_COUNT):
897
+ w_test[TEST_INNER_GRAD_BEGIN + grad_dof] = basis_coefficient(
898
+ w_test_grad_inner, test_val_dof, grad_dof
785
899
  )
786
900
 
787
- if wp.static(0 != TEST_INNER_GRAD_COUNT):
788
- w_test_grad_inner = test.space.element_inner_weight_gradient(
789
- domain_arg, test_space_arg, element_index, coords, test_node, qp_index
901
+ if wp.static(0 != TEST_OUTER_GRAD_COUNT):
902
+ w_test_grad_outer = test.space.element_outer_weight_gradient(
903
+ domain_arg, test_space_arg, element_index, coords, test_node, qp_index
904
+ )
905
+ for grad_dof in range(TEST_OUTER_GRAD_COUNT):
906
+ w_test[TEST_OUTER_GRAD_BEGIN + grad_dof] = basis_coefficient(
907
+ w_test_grad_outer, test_val_dof, grad_dof
790
908
  )
791
909
 
792
- if wp.static(0 != TEST_OUTER_GRAD_COUNT):
793
- w_test_grad_outer = test.space.element_outer_weight_gradient(
794
- domain_arg, test_space_arg, element_index, coords, test_node, qp_index
795
- )
910
+ # trial shape functions
911
+ w_trial = trial_dof_vec()
796
912
 
797
- if wp.static(0 != TRIAL_INNER_COUNT):
798
- w_trial_inner = trial.space.element_inner_weight(
799
- domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
800
- )
913
+ if wp.static(0 != TRIAL_INNER_COUNT):
914
+ w_trial_inner = trial.space.element_inner_weight(
915
+ domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
916
+ )
917
+ w_trial[TRIAL_INNER_BEGIN] = basis_coefficient(w_trial_inner, trial_val_dof)
801
918
 
802
- if wp.static(0 != TRIAL_OUTER_COUNT):
803
- w_trial_outer = trial.space.element_outer_weight(
804
- domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
805
- )
919
+ if wp.static(0 != TRIAL_OUTER_COUNT):
920
+ w_trial_outer = trial.space.element_outer_weight(
921
+ domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
922
+ )
923
+ w_trial[TRIAL_OUTER_BEGIN] = basis_coefficient(w_trial_outer, trial_val_dof)
806
924
 
807
- if wp.static(0 != TRIAL_INNER_GRAD_COUNT):
808
- w_trial_grad_inner = trial.space.element_inner_weight_gradient(
809
- domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
925
+ if wp.static(0 != TRIAL_INNER_GRAD_COUNT):
926
+ w_trial_grad_inner = trial.space.element_inner_weight_gradient(
927
+ domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
928
+ )
929
+ for grad_dof in range(TRIAL_INNER_GRAD_COUNT):
930
+ w_trial[TRIAL_INNER_GRAD_BEGIN + grad_dof] = basis_coefficient(
931
+ w_trial_grad_inner, trial_val_dof, grad_dof
810
932
  )
811
933
 
812
- if wp.static(0 != TRIAL_OUTER_GRAD_COUNT):
813
- w_trial_grad_outer = trial.space.element_outer_weight_gradient(
814
- domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
934
+ if wp.static(0 != TRIAL_OUTER_GRAD_COUNT):
935
+ w_trial_grad_outer = trial.space.element_outer_weight_gradient(
936
+ domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
937
+ )
938
+ for grad_dof in range(TRIAL_OUTER_GRAD_COUNT):
939
+ w_trial[TRIAL_OUTER_GRAD_BEGIN + grad_dof] = basis_coefficient(
940
+ w_trial_grad_outer, trial_val_dof, grad_dof
815
941
  )
816
942
 
817
- for trial_val_dof in range(TRIAL_NODE_DOF_DIM):
943
+ # triple product test @ qp @ trial
944
+ for test_node_dof in range(TEST_NODE_DOF_COUNT):
945
+ test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
946
+ for trial_node_dof in range(TRIAL_NODE_DOF_COUNT):
947
+ dof_res = float(0.0)
818
948
  trial_dof = trial_node_dof * TRIAL_NODE_DOF_DIM + trial_val_dof
819
- test_result = trial_dof_vec(0.0)
820
-
821
- if wp.static(0 != TEST_INNER_COUNT):
822
- for test_val_dof in range(TEST_NODE_DOF_DIM):
823
- test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
824
- test_result += (
825
- basis_coefficient(w_test_inner, test_val_dof)
826
- * qp_result[test_dof, trial_dof, TEST_INNER_BEGIN]
827
- )
828
949
 
829
- if wp.static(0 != TEST_OUTER_COUNT):
830
- for test_val_dof in range(TEST_NODE_DOF_DIM):
831
- test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
832
- test_result += (
833
- basis_coefficient(w_test_outer, test_val_dof)
834
- * qp_result[test_dof, trial_dof, TEST_OUTER_BEGIN]
950
+ for test_taylor_dof in range(TEST_TAYLOR_DOF_COUNT):
951
+ test_res = float(0.0)
952
+ for trial_taylor_dof in range(TRIAL_TAYLOR_DOF_COUNT):
953
+ taylor_dof = test_taylor_dof * TRIAL_TAYLOR_DOF_COUNT + trial_taylor_dof
954
+ test_res += (
955
+ local_result[test_dof, trial_dof, qp_eval_index, taylor_dof] * w_trial[trial_taylor_dof]
835
956
  )
957
+ dof_res += w_test[test_taylor_dof] * test_res
836
958
 
837
- if wp.static(0 != TEST_INNER_GRAD_COUNT):
838
- for test_val_dof in range(TEST_NODE_DOF_DIM):
839
- test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
840
- for grad_dof in range(TEST_INNER_GRAD_COUNT):
841
- test_result += (
842
- basis_coefficient(w_test_grad_inner, test_val_dof, grad_dof)
843
- * qp_result[test_dof, trial_dof, grad_dof + TEST_INNER_GRAD_BEGIN]
844
- )
959
+ val_sum[test_node_dof, trial_node_dof] += accumulate_dtype(dof_res)
845
960
 
846
- if wp.static(0 != TEST_OUTER_GRAD_COUNT):
847
- for test_val_dof in range(TEST_NODE_DOF_DIM):
848
- test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
849
- for grad_dof in range(TEST_OUTER_GRAD_COUNT):
850
- test_result += (
851
- basis_coefficient(w_test_grad_outer, test_val_dof, grad_dof)
852
- * qp_result[test_dof, trial_dof, grad_dof + TEST_OUTER_GRAD_BEGIN]
853
- )
961
+ # write block value
962
+ block_offset = test_node_offset * MAX_NODES_PER_ELEMENT + trial_node
963
+ if wp.static(tile_size) > 1:
964
+ val_sum = wp.tile_sum(wp.tile(val_sum, preserve_type=True))[0]
854
965
 
855
- if wp.static(0 != TRIAL_INNER_COUNT):
856
- trial_result += basis_coefficient(w_trial_inner, trial_val_dof) * test_result[TRIAL_INNER_BEGIN]
966
+ for dof in range(lane, wp.static(TEST_NODE_DOF_COUNT * TRIAL_NODE_DOF_COUNT), wp.static(tile_size)):
967
+ test_node_dof = dof // TRIAL_NODE_DOF_COUNT
968
+ trial_node_dof = dof - TRIAL_NODE_DOF_COUNT * test_node_dof
857
969
 
858
- if wp.static(0 != TRIAL_OUTER_COUNT):
859
- trial_result += basis_coefficient(w_trial_outer, trial_val_dof) * test_result[TRIAL_OUTER_BEGIN]
860
-
861
- if wp.static(0 != TRIAL_INNER_GRAD_COUNT):
862
- for grad_dof in range(TRIAL_INNER_GRAD_COUNT):
863
- trial_result += (
864
- basis_coefficient(w_trial_grad_inner, trial_val_dof, grad_dof)
865
- * test_result[grad_dof + TRIAL_INNER_GRAD_BEGIN]
866
- )
867
-
868
- if wp.static(0 != TRIAL_OUTER_GRAD_COUNT):
869
- for grad_dof in range(TRIAL_OUTER_GRAD_COUNT):
870
- trial_result += (
871
- basis_coefficient(w_trial_grad_outer, trial_val_dof, grad_dof)
872
- * test_result[grad_dof + TRIAL_OUTER_GRAD_BEGIN]
873
- )
874
-
875
- val_sum += accumulate_dtype(trial_result)
876
-
877
- block_offset = element * MAX_NODES_PER_ELEMENT + trial_node
878
- triplet_values[block_offset, test_node_dof, trial_node_dof] = triplet_values.dtype(val_sum)
879
-
880
- # Set row and column indices
881
- if test_node_dof == 0 and trial_node_dof == 0:
882
- if trial_node < element_trial_node_count:
883
- trial_node_index = trial.space_partition.partition_node_index(
884
- trial_partition_arg,
885
- trial.space.topology.element_node_index(
886
- domain_arg, trial_topology_arg, element_index, trial_node
887
- ),
970
+ triplet_values[block_offset, test_node_dof, trial_node_dof] = triplet_values.dtype(
971
+ val_sum[test_node_dof, trial_node_dof]
972
+ )
973
+ else:
974
+ for test_node_dof in range(TEST_NODE_DOF_COUNT):
975
+ for trial_node_dof in range(TRIAL_NODE_DOF_COUNT):
976
+ triplet_values[block_offset, test_node_dof, trial_node_dof] = triplet_values.dtype(
977
+ val_sum[test_node_dof, trial_node_dof]
888
978
  )
889
- else:
890
- trial_node_index = NULL_NODE_INDEX # will get ignored when converting to bsr
891
979
 
892
- triplet_rows[block_offset] = test_node_index
893
- triplet_cols[block_offset] = trial_node_index
980
+ # Set row and column indices
981
+ if lane == 0:
982
+ if trial_node < element_trial_node_count:
983
+ trial_node_index = trial.space_partition.partition_node_index(
984
+ trial_partition_arg,
985
+ trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
986
+ )
987
+ else:
988
+ trial_node_index = NULL_NODE_INDEX # will get ignored when converting to bsr
989
+
990
+ triplet_rows[block_offset] = test_node_index
991
+ triplet_cols[block_offset] = trial_node_index
894
992
 
895
993
  return dispatch_bilinear_kernel_fn
warp/fem/geometry/geometry.py CHANGED
@@ -542,17 +542,17 @@ class Geometry:
542
542
 
543
543
  pos_type = cache.cached_vec_type(self.dimension, dtype=float)
544
544
 
545
- @cache.dynamic_func(suffix=self.name)
545
+ @cache.dynamic_func(suffix=self.name, allow_overloads=True)
546
546
  def cell_lookup(args: self.CellArg, pos: pos_type, max_dist: float):
547
547
  return unfiltered_cell_lookup(args, pos, max_dist, null_filter_data, null_filter_target)
548
548
 
549
- @cache.dynamic_func(suffix=self.name)
549
+ @cache.dynamic_func(suffix=self.name, allow_overloads=True)
550
550
  def cell_lookup(args: self.CellArg, pos: pos_type, guess: Sample):
551
551
  guess_pos = self.cell_position(args, guess)
552
552
  max_dist = wp.length(guess_pos - pos)
553
553
  return unfiltered_cell_lookup(args, pos, max_dist, null_filter_data, null_filter_target)
554
554
 
555
- @cache.dynamic_func(suffix=self.name)
555
+ @cache.dynamic_func(suffix=self.name, allow_overloads=True)
556
556
  def cell_lookup(args: self.CellArg, pos: pos_type):
557
557
  max_dist = 0.0
558
558
  return unfiltered_cell_lookup(args, pos, max_dist, null_filter_data, null_filter_target)
@@ -561,13 +561,13 @@ class Geometry:
561
561
  filtered_cell_lookup = self.make_filtered_cell_lookup(filter_func=_array_load)
562
562
  pos_type = cache.cached_vec_type(self.dimension, dtype=float)
563
563
 
564
- @cache.dynamic_func(suffix=self.name)
564
+ @cache.dynamic_func(suffix=self.name, allow_overloads=True)
565
565
  def cell_lookup(
566
566
  args: self.CellArg, pos: pos_type, max_dist: float, filter_array: wp.array(dtype=Any), filter_target: Any
567
567
  ):
568
568
  return filtered_cell_lookup(args, pos, max_dist, filter_array, filter_target)
569
569
 
570
- @cache.dynamic_func(suffix=self.name)
570
+ @cache.dynamic_func(suffix=self.name, allow_overloads=True)
571
571
  def cell_lookup(args: self.CellArg, pos: pos_type, filter_array: wp.array(dtype=Any), filter_target: Any):
572
572
  max_dist = 0.0
573
573
  return filtered_cell_lookup(args, pos, max_dist, filter_array, filter_target)