warp-lang 1.8.0__py3-none-win_amd64.whl → 1.9.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

Files changed (153):
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +48 -63
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +327 -209
  9. warp/config.py +1 -1
  10. warp/context.py +1363 -800
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_callable.py +34 -4
  18. warp/examples/interop/example_jax_kernel.py +27 -1
  19. warp/fabric.py +1 -1
  20. warp/fem/cache.py +27 -19
  21. warp/fem/domain.py +2 -2
  22. warp/fem/field/nodal_field.py +2 -2
  23. warp/fem/field/virtual.py +266 -166
  24. warp/fem/geometry/geometry.py +5 -5
  25. warp/fem/integrate.py +200 -91
  26. warp/fem/space/restriction.py +4 -0
  27. warp/fem/space/shape/tet_shape_function.py +3 -10
  28. warp/jax_experimental/custom_call.py +1 -1
  29. warp/jax_experimental/ffi.py +203 -54
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +103 -8
  32. warp/native/builtin.h +90 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +13 -3
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +42 -11
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +4 -4
  48. warp/native/mat.h +1913 -119
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +5 -3
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +337 -16
  59. warp/native/rand.h +7 -7
  60. warp/native/range.h +7 -1
  61. warp/native/reduce.cpp +10 -10
  62. warp/native/reduce.cu +13 -14
  63. warp/native/runlength_encode.cpp +2 -2
  64. warp/native/runlength_encode.cu +5 -5
  65. warp/native/scan.cpp +3 -3
  66. warp/native/scan.cu +4 -4
  67. warp/native/sort.cpp +10 -10
  68. warp/native/sort.cu +22 -22
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +14 -14
  71. warp/native/spatial.h +366 -17
  72. warp/native/svd.h +23 -8
  73. warp/native/temp_buffer.h +2 -2
  74. warp/native/tile.h +303 -70
  75. warp/native/tile_radix_sort.h +5 -1
  76. warp/native/tile_reduce.h +16 -25
  77. warp/native/tuple.h +2 -2
  78. warp/native/vec.h +385 -18
  79. warp/native/volume.cpp +54 -54
  80. warp/native/volume.cu +1 -1
  81. warp/native/volume.h +2 -1
  82. warp/native/volume_builder.cu +30 -37
  83. warp/native/warp.cpp +150 -149
  84. warp/native/warp.cu +337 -193
  85. warp/native/warp.h +227 -226
  86. warp/optim/linear.py +736 -271
  87. warp/render/imgui_manager.py +289 -0
  88. warp/render/render_opengl.py +137 -57
  89. warp/render/render_usd.py +0 -1
  90. warp/sim/collide.py +1 -2
  91. warp/sim/graph_coloring.py +2 -2
  92. warp/sim/integrator_vbd.py +10 -2
  93. warp/sparse.py +559 -176
  94. warp/tape.py +2 -0
  95. warp/tests/aux_test_module_aot.py +7 -0
  96. warp/tests/cuda/test_async.py +3 -3
  97. warp/tests/cuda/test_conditional_captures.py +101 -0
  98. warp/tests/geometry/test_marching_cubes.py +233 -12
  99. warp/tests/sim/test_cloth.py +89 -6
  100. warp/tests/sim/test_coloring.py +82 -7
  101. warp/tests/test_array.py +56 -5
  102. warp/tests/test_assert.py +53 -0
  103. warp/tests/test_atomic_cas.py +127 -114
  104. warp/tests/test_codegen.py +3 -2
  105. warp/tests/test_context.py +8 -15
  106. warp/tests/test_enum.py +136 -0
  107. warp/tests/test_examples.py +2 -2
  108. warp/tests/test_fem.py +45 -2
  109. warp/tests/test_fixedarray.py +229 -0
  110. warp/tests/test_func.py +18 -15
  111. warp/tests/test_future_annotations.py +7 -5
  112. warp/tests/test_linear_solvers.py +30 -0
  113. warp/tests/test_map.py +1 -1
  114. warp/tests/test_mat.py +1540 -378
  115. warp/tests/test_mat_assign_copy.py +178 -0
  116. warp/tests/test_mat_constructors.py +574 -0
  117. warp/tests/test_module_aot.py +287 -0
  118. warp/tests/test_print.py +69 -0
  119. warp/tests/test_quat.py +162 -34
  120. warp/tests/test_quat_assign_copy.py +145 -0
  121. warp/tests/test_reload.py +2 -1
  122. warp/tests/test_sparse.py +103 -0
  123. warp/tests/test_spatial.py +140 -34
  124. warp/tests/test_spatial_assign_copy.py +160 -0
  125. warp/tests/test_static.py +48 -0
  126. warp/tests/test_struct.py +43 -3
  127. warp/tests/test_tape.py +38 -0
  128. warp/tests/test_types.py +0 -20
  129. warp/tests/test_vec.py +216 -441
  130. warp/tests/test_vec_assign_copy.py +143 -0
  131. warp/tests/test_vec_constructors.py +325 -0
  132. warp/tests/tile/test_tile.py +206 -152
  133. warp/tests/tile/test_tile_cholesky.py +605 -0
  134. warp/tests/tile/test_tile_load.py +169 -0
  135. warp/tests/tile/test_tile_mathdx.py +2 -558
  136. warp/tests/tile/test_tile_matmul.py +179 -0
  137. warp/tests/tile/test_tile_mlp.py +1 -1
  138. warp/tests/tile/test_tile_reduce.py +100 -11
  139. warp/tests/tile/test_tile_shared_memory.py +16 -16
  140. warp/tests/tile/test_tile_sort.py +59 -55
  141. warp/tests/unittest_suites.py +16 -0
  142. warp/tests/walkthrough_debug.py +1 -1
  143. warp/thirdparty/unittest_parallel.py +108 -9
  144. warp/types.py +554 -264
  145. warp/utils.py +68 -86
  146. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  147. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
  148. warp/native/marching.cpp +0 -19
  149. warp/native/marching.cu +0 -514
  150. warp/native/marching.h +0 -19
  151. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  152. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  153. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/fem/field/virtual.py CHANGED
@@ -13,7 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from typing import Any, ClassVar, Dict, Set
16
+ from typing import Any, ClassVar, Dict, Optional, Set
17
17
 
18
18
  import warp as wp
19
19
  import warp.fem.operator as operator
@@ -22,7 +22,16 @@ from warp.fem.domain import GeometryDomain
22
22
  from warp.fem.linalg import basis_coefficient, generalized_inner, generalized_outer
23
23
  from warp.fem.quadrature import Quadrature
24
24
  from warp.fem.space import FunctionSpace, SpacePartition, SpaceRestriction
25
- from warp.fem.types import NULL_NODE_INDEX, DofIndex, Sample, get_node_coord, get_node_index_in_element
25
+ from warp.fem.types import (
26
+ NULL_ELEMENT_INDEX,
27
+ NULL_NODE_INDEX,
28
+ DofIndex,
29
+ ElementIndex,
30
+ NodeElementIndex,
31
+ Sample,
32
+ get_node_coord,
33
+ get_node_index_in_element,
34
+ )
26
35
  from warp.fem.utils import type_zero_element
27
36
 
28
37
  from .field import SpaceField
@@ -365,6 +374,8 @@ class LocalAdjointField(SpaceField):
365
374
  self._TAYLOR_DOF_COUNTS = LocalAdjointField.DofOffsets(0)
366
375
  self.TAYLOR_DOF_COUNT = 0
367
376
 
377
+ cache.setup_dynamic_attributes(self)
378
+
368
379
  def notify_operator_usage(self, ops: Set[operator.Operator]):
369
380
  # Rebuild degrees-of-freedom offsets based on used operators
370
381
 
@@ -565,7 +576,13 @@ class LocalTrialField(LocalAdjointField):
565
576
  return s.trial_dof
566
577
 
567
578
 
568
- def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, accumulate_dtype: type):
579
+ def make_linear_dispatch_kernel(
580
+ test: LocalTestField,
581
+ quadrature: Quadrature,
582
+ accumulate_dtype: type,
583
+ tile_size: int = 1,
584
+ kernel_options: Optional[Dict[str, Any]] = None,
585
+ ):
569
586
  global_test: TestField = test.global_field
570
587
  space_restriction = global_test.space_restriction
571
588
  domain = global_test.domain
@@ -581,8 +598,42 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
581
598
  TEST_OUTER_GRAD_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF]
582
599
 
583
600
  TEST_NODE_DOF_DIM = test.value_dof_count // test.node_dof_count
601
+ TEST_NODE_DOF_COUNT = test.node_dof_count
602
+
603
+ res_vec = cache.cached_vec_type(length=test.node_dof_count, dtype=accumulate_dtype)
604
+ qp_vec = cache.cached_vec_type(length=test.node_dof_count, dtype=float)
605
+
606
+ @cache.dynamic_func(f"{test.name}_{quadrature.name}")
607
+ def next_qp(
608
+ qp: int,
609
+ elem_offset: int,
610
+ qp_point_count: int,
611
+ element_index: ElementIndex,
612
+ test_element_index: NodeElementIndex,
613
+ element_end: int,
614
+ qp_arg: quadrature.Arg,
615
+ domain_arg: domain.ElementArg,
616
+ domain_index_arg: domain.ElementIndexArg,
617
+ test_arg: space_restriction.NodeArg,
618
+ ):
619
+ while qp >= qp_point_count and elem_offset < element_end:
620
+ # Next element
621
+ elem_offset += 1
622
+
623
+ if elem_offset < element_end:
624
+ qp -= qp_point_count
625
+ test_element_index = space_restriction.node_element_index(test_arg, elem_offset)
626
+ element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
627
+ qp_point_count = quadrature.point_count(
628
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index
629
+ )
584
630
 
585
- @cache.dynamic_kernel(f"{test.name}_{quadrature.name}_{wp.types.get_type_code(accumulate_dtype)}")
631
+ return qp, elem_offset, qp_point_count, element_index, test_element_index
632
+
633
+ @cache.dynamic_kernel(
634
+ f"{test.name}_{quadrature.name}_{wp.types.get_type_code(accumulate_dtype)}_{tile_size}",
635
+ kernel_options=kernel_options,
636
+ )
586
637
  def dispatch_linear_kernel_fn(
587
638
  qp_arg: quadrature.Arg,
588
639
  domain_arg: domain.ElementArg,
@@ -592,33 +643,47 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
592
643
  local_result: wp.array3d(dtype=Any),
593
644
  result: wp.array2d(dtype=Any),
594
645
  ):
595
- local_node_index, test_node_dof = wp.tid()
646
+ local_node_index, lane = wp.tid()
647
+
596
648
  node_index = space_restriction.node_partition_index(test_arg, local_node_index)
597
649
  element_beg, element_end = space_restriction.node_element_range(test_arg, node_index)
598
650
 
599
- val_sum = accumulate_dtype(0.0)
600
-
601
- for n in range(element_beg, element_end):
602
- test_element_index = space_restriction.node_element_index(test_arg, n)
603
- element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
604
-
605
- qp_point_count = quadrature.point_count(
606
- domain_arg, qp_arg, test_element_index.domain_element_index, element_index
651
+ val_sum = res_vec()
652
+
653
+ elem_offset = element_beg - 1
654
+ qp_point_count = int(0)
655
+ qp = lane
656
+ test_element_index = NodeElementIndex()
657
+ element_index = ElementIndex(NULL_ELEMENT_INDEX)
658
+
659
+ while elem_offset < element_end:
660
+ qp, elem_offset, qp_point_count, element_index, test_element_index = next_qp(
661
+ qp,
662
+ elem_offset,
663
+ qp_point_count,
664
+ element_index,
665
+ test_element_index,
666
+ element_end,
667
+ qp_arg,
668
+ domain_arg,
669
+ domain_index_arg,
670
+ test_arg,
607
671
  )
608
- for k in range(qp_point_count):
672
+
673
+ if qp < qp_point_count:
609
674
  qp_index = quadrature.point_index(
610
- domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
675
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, qp
611
676
  )
612
677
  qp_eval_index = quadrature.point_evaluation_index(
613
- domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
678
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, qp
614
679
  )
615
680
  coords = quadrature.point_coords(
616
- domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
681
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, qp
617
682
  )
618
683
 
619
684
  qp_result = local_result[qp_eval_index]
620
685
 
621
- qp_sum = float(0.0)
686
+ qp_sum = qp_vec()
622
687
 
623
688
  if wp.static(0 != TEST_INNER_COUNT):
624
689
  w = test.space.element_inner_weight(
@@ -629,9 +694,12 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
629
694
  test_element_index.node_index_in_element,
630
695
  qp_index,
631
696
  )
632
- for val_dof in range(TEST_NODE_DOF_DIM):
633
- test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
634
- qp_sum += basis_coefficient(w, val_dof) * qp_result[TEST_INNER_BEGIN, test_dof]
697
+ for test_node_dof in range(TEST_NODE_DOF_COUNT):
698
+ for val_dof in range(TEST_NODE_DOF_DIM):
699
+ test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
700
+ qp_sum[test_node_dof] += (
701
+ basis_coefficient(w, val_dof) * qp_result[TEST_INNER_BEGIN, test_dof]
702
+ )
635
703
 
636
704
  if wp.static(0 != TEST_OUTER_COUNT):
637
705
  w = test.space.element_outer_weight(
@@ -642,9 +710,12 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
642
710
  test_element_index.node_index_in_element,
643
711
  qp_index,
644
712
  )
645
- for val_dof in range(TEST_NODE_DOF_DIM):
646
- test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
647
- qp_sum += basis_coefficient(w, val_dof) * qp_result[TEST_OUTER_BEGIN, test_dof]
713
+ for test_node_dof in range(TEST_NODE_DOF_COUNT):
714
+ for val_dof in range(TEST_NODE_DOF_DIM):
715
+ test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
716
+ qp_sum[test_node_dof] += (
717
+ basis_coefficient(w, val_dof) * qp_result[TEST_OUTER_BEGIN, test_dof]
718
+ )
648
719
 
649
720
  if wp.static(0 != TEST_INNER_GRAD_COUNT):
650
721
  w_grad = test.space.element_inner_weight_gradient(
@@ -655,13 +726,14 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
655
726
  test_element_index.node_index_in_element,
656
727
  qp_index,
657
728
  )
658
- for val_dof in range(TEST_NODE_DOF_DIM):
659
- test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
660
- for grad_dof in range(TEST_INNER_GRAD_COUNT):
661
- qp_sum += (
662
- basis_coefficient(w_grad, val_dof, grad_dof)
663
- * qp_result[grad_dof + TEST_INNER_GRAD_BEGIN, test_dof]
664
- )
729
+ for test_node_dof in range(TEST_NODE_DOF_COUNT):
730
+ for val_dof in range(TEST_NODE_DOF_DIM):
731
+ test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
732
+ for grad_dof in range(TEST_INNER_GRAD_COUNT):
733
+ qp_sum[test_node_dof] += (
734
+ basis_coefficient(w_grad, val_dof, grad_dof)
735
+ * qp_result[grad_dof + TEST_INNER_GRAD_BEGIN, test_dof]
736
+ )
665
737
 
666
738
  if wp.static(0 != TEST_OUTER_GRAD_COUNT):
667
739
  w_grad = test.space.element_outer_weight_gradient(
@@ -672,23 +744,36 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
672
744
  test_element_index.node_index_in_element,
673
745
  qp_index,
674
746
  )
675
- for val_dof in range(TEST_NODE_DOF_DIM):
676
- test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
677
- for grad_dof in range(TEST_OUTER_GRAD_COUNT):
678
- qp_sum += (
679
- basis_coefficient(w_grad, val_dof, grad_dof)
680
- * qp_result[grad_dof + TEST_OUTER_GRAD_BEGIN, test_dof]
681
- )
747
+ for test_node_dof in range(TEST_NODE_DOF_COUNT):
748
+ for val_dof in range(TEST_NODE_DOF_DIM):
749
+ test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
750
+ for grad_dof in range(TEST_OUTER_GRAD_COUNT):
751
+ qp_sum[test_node_dof] += (
752
+ basis_coefficient(w_grad, val_dof, grad_dof)
753
+ * qp_result[grad_dof + TEST_OUTER_GRAD_BEGIN, test_dof]
754
+ )
682
755
 
683
- val_sum += accumulate_dtype(qp_sum)
756
+ val_sum += res_vec(qp_sum)
757
+ qp += wp.static(tile_size)
684
758
 
685
- result[node_index, test_node_dof] += result.dtype(val_sum)
759
+ if wp.static(tile_size == 1):
760
+ for test_node_dof in range(TEST_NODE_DOF_COUNT):
761
+ result[node_index, test_node_dof] += result.dtype(val_sum[test_node_dof])
762
+ else:
763
+ t_sum = wp.tile_sum(wp.tile(val_sum, preserve_type=True))[0]
764
+ for test_node_dof in range(lane, TEST_NODE_DOF_COUNT, wp.static(tile_size)):
765
+ result[node_index, test_node_dof] += result.dtype(t_sum[test_node_dof])
686
766
 
687
767
  return dispatch_linear_kernel_fn
688
768
 
689
769
 
690
770
  def make_bilinear_dispatch_kernel(
691
- test: LocalTestField, trial: LocalTrialField, quadrature: Quadrature, accumulate_dtype: type
771
+ test: LocalTestField,
772
+ trial: LocalTrialField,
773
+ quadrature: Quadrature,
774
+ accumulate_dtype: type,
775
+ tile_size: int = 1,
776
+ kernel_options: Optional[Dict[str, Any]] = None,
692
777
  ):
693
778
  global_test: TestField = test.global_field
694
779
  space_restriction = global_test.space_restriction
@@ -716,12 +801,24 @@ def make_bilinear_dispatch_kernel(
716
801
 
717
802
  TEST_NODE_DOF_DIM = test.value_dof_count // test.node_dof_count
718
803
  TRIAL_NODE_DOF_DIM = trial.value_dof_count // trial.node_dof_count
804
+ TEST_TRIAL_NODE_DOF_DIM = TEST_NODE_DOF_DIM * TRIAL_NODE_DOF_DIM
805
+
806
+ TEST_NODE_DOF_COUNT = test.node_dof_count
807
+ TRIAL_NODE_DOF_COUNT = trial.node_dof_count
808
+ TEST_TAYLOR_DOF_COUNT = test.TAYLOR_DOF_COUNT
809
+ TRIAL_TAYLOR_DOF_COUNT = trial.TAYLOR_DOF_COUNT
719
810
 
720
811
  MAX_NODES_PER_ELEMENT = trial.space.topology.MAX_NODES_PER_ELEMENT
721
812
 
722
813
  trial_dof_vec = cache.cached_vec_type(length=trial.TAYLOR_DOF_COUNT, dtype=float)
814
+ test_dof_vec = cache.cached_vec_type(length=test.TAYLOR_DOF_COUNT, dtype=float)
723
815
 
724
- @cache.dynamic_kernel(f"{trial.name}_{test.name}_{quadrature.name}{wp.types.get_type_code(accumulate_dtype)}")
816
+ val_t = cache.cached_mat_type(shape=(test.node_dof_count, trial.node_dof_count), dtype=accumulate_dtype)
817
+
818
+ @cache.dynamic_kernel(
819
+ f"{trial.name}_{test.name}_{quadrature.name}{wp.types.get_type_code(accumulate_dtype)}_{tile_size}",
820
+ kernel_options=kernel_options,
821
+ )
725
822
  def dispatch_bilinear_kernel_fn(
726
823
  qp_arg: quadrature.Arg,
727
824
  domain_arg: domain.ElementArg,
@@ -731,163 +828,166 @@ def make_bilinear_dispatch_kernel(
731
828
  trial_partition_arg: trial.space_partition.PartitionArg,
732
829
  trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
733
830
  trial_space_arg: trial.space.SpaceArg,
734
- local_result: wp.array4d(dtype=trial_dof_vec),
831
+ local_result: wp.array4d(dtype=float),
735
832
  triplet_rows: wp.array(dtype=int),
736
833
  triplet_cols: wp.array(dtype=int),
737
834
  triplet_values: wp.array3d(dtype=Any),
738
835
  ):
739
- test_local_node_index, test_node_dof, trial_node_dof, trial_node = wp.tid()
836
+ test_node_offset, trial_node, lane = wp.tid()
740
837
 
741
- test_node_index = space_restriction.node_partition_index(test_arg, test_local_node_index)
742
- element_beg, element_end = space_restriction.node_element_range(test_arg, test_node_index)
838
+ test_node_index = space_restriction.node_partition_index_from_element_offset(test_arg, test_node_offset)
743
839
 
744
- for element in range(element_beg, element_end):
745
- test_element_index = space_restriction.node_element_index(test_arg, element)
746
- element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
747
- test_node = test_element_index.node_index_in_element
840
+ test_element_index = space_restriction.node_element_index(test_arg, test_node_offset)
841
+ element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
842
+ test_node = test_element_index.node_index_in_element
748
843
 
749
- element_trial_node_count = trial.space.topology.element_node_count(
750
- domain_arg, trial_topology_arg, element_index
751
- )
844
+ element_trial_node_count = trial.space.topology.element_node_count(
845
+ domain_arg, trial_topology_arg, element_index
846
+ )
847
+
848
+ if trial_node >= element_trial_node_count:
849
+ block_offset = test_node_offset * MAX_NODES_PER_ELEMENT + trial_node
850
+ triplet_rows[block_offset] = NULL_NODE_INDEX
851
+ triplet_cols[block_offset] = NULL_NODE_INDEX
852
+ return
752
853
 
753
- qp_point_count = wp.where(
754
- trial_node < element_trial_node_count,
755
- quadrature.point_count(domain_arg, qp_arg, test_element_index.domain_element_index, element_index),
756
- 0,
854
+ qp_point_count = quadrature.point_count(
855
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index
856
+ )
857
+ qp_dof_count = qp_point_count * TEST_TRIAL_NODE_DOF_DIM
858
+
859
+ val_sum = val_t()
860
+
861
+ for dof in range(lane, qp_dof_count, wp.static(tile_size)):
862
+ k = dof // TEST_TRIAL_NODE_DOF_DIM
863
+ test_trial_val_dof = dof - k * TEST_TRIAL_NODE_DOF_DIM
864
+ test_val_dof = test_trial_val_dof // TRIAL_NODE_DOF_DIM
865
+ trial_val_dof = test_trial_val_dof - test_val_dof * TRIAL_NODE_DOF_DIM
866
+
867
+ qp_index = quadrature.point_index(
868
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
869
+ )
870
+ qp_eval_index = quadrature.point_evaluation_index(
871
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
872
+ )
873
+ coords = quadrature.point_coords(
874
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
757
875
  )
758
876
 
759
- val_sum = accumulate_dtype(0.0)
877
+ # test shape functions
878
+ w_test = test_dof_vec()
760
879
 
761
- for k in range(qp_point_count):
762
- qp_index = quadrature.point_index(
763
- domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
764
- )
765
- qp_eval_index = quadrature.point_evaluation_index(
766
- domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
767
- )
768
- coords = quadrature.point_coords(
769
- domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
880
+ if wp.static(0 != TEST_INNER_COUNT):
881
+ w_test_inner = test.space.element_inner_weight(
882
+ domain_arg, test_space_arg, element_index, coords, test_node, qp_index
770
883
  )
884
+ w_test[TEST_INNER_BEGIN] = basis_coefficient(w_test_inner, test_val_dof)
771
885
 
772
- qp_result = local_result[qp_eval_index]
773
- trial_result = float(0.0)
774
-
775
- if wp.static(0 != TEST_INNER_COUNT):
776
- w_test_inner = test.space.element_inner_weight(
777
- domain_arg, test_space_arg, element_index, coords, test_node, qp_index
778
- )
886
+ if wp.static(0 != TEST_OUTER_COUNT):
887
+ w_test_outer = test.space.element_outer_weight(
888
+ domain_arg, test_space_arg, element_index, coords, test_node, qp_index
889
+ )
890
+ w_test[TEST_OUTER_BEGIN] = basis_coefficient(w_test_outer, test_val_dof)
779
891
 
780
- if wp.static(0 != TEST_OUTER_COUNT):
781
- w_test_outer = test.space.element_outer_weight(
782
- domain_arg, test_space_arg, element_index, coords, test_node, qp_index
892
+ if wp.static(0 != TEST_INNER_GRAD_COUNT):
893
+ w_test_grad_inner = test.space.element_inner_weight_gradient(
894
+ domain_arg, test_space_arg, element_index, coords, test_node, qp_index
895
+ )
896
+ for grad_dof in range(TEST_INNER_GRAD_COUNT):
897
+ w_test[TEST_INNER_GRAD_BEGIN + grad_dof] = basis_coefficient(
898
+ w_test_grad_inner, test_val_dof, grad_dof
783
899
  )
784
900
 
785
- if wp.static(0 != TEST_INNER_GRAD_COUNT):
786
- w_test_grad_inner = test.space.element_inner_weight_gradient(
787
- domain_arg, test_space_arg, element_index, coords, test_node, qp_index
901
+ if wp.static(0 != TEST_OUTER_GRAD_COUNT):
902
+ w_test_grad_outer = test.space.element_outer_weight_gradient(
903
+ domain_arg, test_space_arg, element_index, coords, test_node, qp_index
904
+ )
905
+ for grad_dof in range(TEST_OUTER_GRAD_COUNT):
906
+ w_test[TEST_OUTER_GRAD_BEGIN + grad_dof] = basis_coefficient(
907
+ w_test_grad_outer, test_val_dof, grad_dof
788
908
  )
789
909
 
790
- if wp.static(0 != TEST_OUTER_GRAD_COUNT):
791
- w_test_grad_outer = test.space.element_outer_weight_gradient(
792
- domain_arg, test_space_arg, element_index, coords, test_node, qp_index
793
- )
910
+ # trial shape functions
911
+ w_trial = trial_dof_vec()
794
912
 
795
- if wp.static(0 != TRIAL_INNER_COUNT):
796
- w_trial_inner = trial.space.element_inner_weight(
797
- domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
798
- )
913
+ if wp.static(0 != TRIAL_INNER_COUNT):
914
+ w_trial_inner = trial.space.element_inner_weight(
915
+ domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
916
+ )
917
+ w_trial[TRIAL_INNER_BEGIN] = basis_coefficient(w_trial_inner, trial_val_dof)
799
918
 
800
- if wp.static(0 != TRIAL_OUTER_COUNT):
801
- w_trial_outer = trial.space.element_outer_weight(
802
- domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
803
- )
919
+ if wp.static(0 != TRIAL_OUTER_COUNT):
920
+ w_trial_outer = trial.space.element_outer_weight(
921
+ domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
922
+ )
923
+ w_trial[TRIAL_OUTER_BEGIN] = basis_coefficient(w_trial_outer, trial_val_dof)
804
924
 
805
- if wp.static(0 != TRIAL_INNER_GRAD_COUNT):
806
- w_trial_grad_inner = trial.space.element_inner_weight_gradient(
807
- domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
925
+ if wp.static(0 != TRIAL_INNER_GRAD_COUNT):
926
+ w_trial_grad_inner = trial.space.element_inner_weight_gradient(
927
+ domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
928
+ )
929
+ for grad_dof in range(TRIAL_INNER_GRAD_COUNT):
930
+ w_trial[TRIAL_INNER_GRAD_BEGIN + grad_dof] = basis_coefficient(
931
+ w_trial_grad_inner, trial_val_dof, grad_dof
808
932
  )
809
933
 
810
- if wp.static(0 != TRIAL_OUTER_GRAD_COUNT):
811
- w_trial_grad_outer = trial.space.element_outer_weight_gradient(
812
- domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
934
+ if wp.static(0 != TRIAL_OUTER_GRAD_COUNT):
935
+ w_trial_grad_outer = trial.space.element_outer_weight_gradient(
936
+ domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
937
+ )
938
+ for grad_dof in range(TRIAL_OUTER_GRAD_COUNT):
939
+ w_trial[TRIAL_OUTER_GRAD_BEGIN + grad_dof] = basis_coefficient(
940
+ w_trial_grad_outer, trial_val_dof, grad_dof
813
941
  )
814
942
 
815
- for trial_val_dof in range(TRIAL_NODE_DOF_DIM):
943
+ # triple product test @ qp @ trial
944
+ for test_node_dof in range(TEST_NODE_DOF_COUNT):
945
+ test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
946
+ for trial_node_dof in range(TRIAL_NODE_DOF_COUNT):
947
+ dof_res = float(0.0)
816
948
  trial_dof = trial_node_dof * TRIAL_NODE_DOF_DIM + trial_val_dof
817
- test_result = trial_dof_vec(0.0)
818
-
819
- if wp.static(0 != TEST_INNER_COUNT):
820
- for test_val_dof in range(TEST_NODE_DOF_DIM):
821
- test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
822
- test_result += (
823
- basis_coefficient(w_test_inner, test_val_dof)
824
- * qp_result[test_dof, trial_dof, TEST_INNER_BEGIN]
825
- )
826
949
 
827
- if wp.static(0 != TEST_OUTER_COUNT):
828
- for test_val_dof in range(TEST_NODE_DOF_DIM):
829
- test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
830
- test_result += (
831
- basis_coefficient(w_test_outer, test_val_dof)
832
- * qp_result[test_dof, trial_dof, TEST_OUTER_BEGIN]
950
+ for test_taylor_dof in range(TEST_TAYLOR_DOF_COUNT):
951
+ test_res = float(0.0)
952
+ for trial_taylor_dof in range(TRIAL_TAYLOR_DOF_COUNT):
953
+ taylor_dof = test_taylor_dof * TRIAL_TAYLOR_DOF_COUNT + trial_taylor_dof
954
+ test_res += (
955
+ local_result[test_dof, trial_dof, qp_eval_index, taylor_dof] * w_trial[trial_taylor_dof]
833
956
  )
957
+ dof_res += w_test[test_taylor_dof] * test_res
834
958
 
835
- if wp.static(0 != TEST_INNER_GRAD_COUNT):
836
- for test_val_dof in range(TEST_NODE_DOF_DIM):
837
- test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
838
- for grad_dof in range(TEST_INNER_GRAD_COUNT):
839
- test_result += (
840
- basis_coefficient(w_test_grad_inner, test_val_dof, grad_dof)
841
- * qp_result[test_dof, trial_dof, grad_dof + TEST_INNER_GRAD_BEGIN]
842
- )
959
+ val_sum[test_node_dof, trial_node_dof] += accumulate_dtype(dof_res)
843
960
 
844
- if wp.static(0 != TEST_OUTER_GRAD_COUNT):
845
- for test_val_dof in range(TEST_NODE_DOF_DIM):
846
- test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
847
- for grad_dof in range(TEST_OUTER_GRAD_COUNT):
848
- test_result += (
849
- basis_coefficient(w_test_grad_outer, test_val_dof, grad_dof)
850
- * qp_result[test_dof, trial_dof, grad_dof + TEST_OUTER_GRAD_BEGIN]
851
- )
961
+ # write block value
962
+ block_offset = test_node_offset * MAX_NODES_PER_ELEMENT + trial_node
963
+ if wp.static(tile_size) > 1:
964
+ val_sum = wp.tile_sum(wp.tile(val_sum, preserve_type=True))[0]
852
965
 
853
- if wp.static(0 != TRIAL_INNER_COUNT):
854
- trial_result += basis_coefficient(w_trial_inner, trial_val_dof) * test_result[TRIAL_INNER_BEGIN]
966
+ for dof in range(lane, wp.static(TEST_NODE_DOF_COUNT * TRIAL_NODE_DOF_COUNT), wp.static(tile_size)):
967
+ test_node_dof = dof // TRIAL_NODE_DOF_COUNT
968
+ trial_node_dof = dof - TRIAL_NODE_DOF_COUNT * test_node_dof
855
969
 
856
- if wp.static(0 != TRIAL_OUTER_COUNT):
857
- trial_result += basis_coefficient(w_trial_outer, trial_val_dof) * test_result[TRIAL_OUTER_BEGIN]
858
-
859
- if wp.static(0 != TRIAL_INNER_GRAD_COUNT):
860
- for grad_dof in range(TRIAL_INNER_GRAD_COUNT):
861
- trial_result += (
862
- basis_coefficient(w_trial_grad_inner, trial_val_dof, grad_dof)
863
- * test_result[grad_dof + TRIAL_INNER_GRAD_BEGIN]
864
- )
865
-
866
- if wp.static(0 != TRIAL_OUTER_GRAD_COUNT):
867
- for grad_dof in range(TRIAL_OUTER_GRAD_COUNT):
868
- trial_result += (
869
- basis_coefficient(w_trial_grad_outer, trial_val_dof, grad_dof)
870
- * test_result[grad_dof + TRIAL_OUTER_GRAD_BEGIN]
871
- )
872
-
873
- val_sum += accumulate_dtype(trial_result)
874
-
875
- block_offset = element * MAX_NODES_PER_ELEMENT + trial_node
876
- triplet_values[block_offset, test_node_dof, trial_node_dof] = triplet_values.dtype(val_sum)
877
-
878
- # Set row and column indices
879
- if test_node_dof == 0 and trial_node_dof == 0:
880
- if trial_node < element_trial_node_count:
881
- trial_node_index = trial.space_partition.partition_node_index(
882
- trial_partition_arg,
883
- trial.space.topology.element_node_index(
884
- domain_arg, trial_topology_arg, element_index, trial_node
885
- ),
970
+ triplet_values[block_offset, test_node_dof, trial_node_dof] = triplet_values.dtype(
971
+ val_sum[test_node_dof, trial_node_dof]
972
+ )
973
+ else:
974
+ for test_node_dof in range(TEST_NODE_DOF_COUNT):
975
+ for trial_node_dof in range(TRIAL_NODE_DOF_COUNT):
976
+ triplet_values[block_offset, test_node_dof, trial_node_dof] = triplet_values.dtype(
977
+ val_sum[test_node_dof, trial_node_dof]
886
978
  )
887
- else:
888
- trial_node_index = NULL_NODE_INDEX # will get ignored when converting to bsr
889
979
 
890
- triplet_rows[block_offset] = test_node_index
891
- triplet_cols[block_offset] = trial_node_index
980
+ # Set row and column indices
981
+ if lane == 0:
982
+ if trial_node < element_trial_node_count:
983
+ trial_node_index = trial.space_partition.partition_node_index(
984
+ trial_partition_arg,
985
+ trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
986
+ )
987
+ else:
988
+ trial_node_index = NULL_NODE_INDEX # will get ignored when converting to bsr
989
+
990
+ triplet_rows[block_offset] = test_node_index
991
+ triplet_cols[block_offset] = trial_node_index
892
992
 
893
993
  return dispatch_bilinear_kernel_fn
@@ -542,17 +542,17 @@ class Geometry:
542
542
 
543
543
  pos_type = cache.cached_vec_type(self.dimension, dtype=float)
544
544
 
545
- @cache.dynamic_func(suffix=self.name)
545
+ @cache.dynamic_func(suffix=self.name, allow_overloads=True)
546
546
  def cell_lookup(args: self.CellArg, pos: pos_type, max_dist: float):
547
547
  return unfiltered_cell_lookup(args, pos, max_dist, null_filter_data, null_filter_target)
548
548
 
549
- @cache.dynamic_func(suffix=self.name)
549
+ @cache.dynamic_func(suffix=self.name, allow_overloads=True)
550
550
  def cell_lookup(args: self.CellArg, pos: pos_type, guess: Sample):
551
551
  guess_pos = self.cell_position(args, guess)
552
552
  max_dist = wp.length(guess_pos - pos)
553
553
  return unfiltered_cell_lookup(args, pos, max_dist, null_filter_data, null_filter_target)
554
554
 
555
- @cache.dynamic_func(suffix=self.name)
555
+ @cache.dynamic_func(suffix=self.name, allow_overloads=True)
556
556
  def cell_lookup(args: self.CellArg, pos: pos_type):
557
557
  max_dist = 0.0
558
558
  return unfiltered_cell_lookup(args, pos, max_dist, null_filter_data, null_filter_target)
@@ -561,13 +561,13 @@ class Geometry:
561
561
  filtered_cell_lookup = self.make_filtered_cell_lookup(filter_func=_array_load)
562
562
  pos_type = cache.cached_vec_type(self.dimension, dtype=float)
563
563
 
564
- @cache.dynamic_func(suffix=self.name)
564
+ @cache.dynamic_func(suffix=self.name, allow_overloads=True)
565
565
  def cell_lookup(
566
566
  args: self.CellArg, pos: pos_type, max_dist: float, filter_array: wp.array(dtype=Any), filter_target: Any
567
567
  ):
568
568
  return filtered_cell_lookup(args, pos, max_dist, filter_array, filter_target)
569
569
 
570
- @cache.dynamic_func(suffix=self.name)
570
+ @cache.dynamic_func(suffix=self.name, allow_overloads=True)
571
571
  def cell_lookup(args: self.CellArg, pos: pos_type, filter_array: wp.array(dtype=Any), filter_target: Any):
572
572
  max_dist = 0.0
573
573
  return filtered_cell_lookup(args, pos, max_dist, filter_array, filter_target)