warp-lang 1.6.2__py3-none-win_amd64.whl → 1.7.1__py3-none-win_amd64.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (191)
  1. warp/__init__.py +7 -1
  2. warp/autograd.py +12 -2
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +410 -0
  6. warp/build_dll.py +6 -14
  7. warp/builtins.py +463 -372
  8. warp/codegen.py +196 -124
  9. warp/config.py +42 -6
  10. warp/context.py +496 -271
  11. warp/dlpack.py +8 -6
  12. warp/examples/assets/nonuniform.usd +0 -0
  13. warp/examples/assets/nvidia_logo.png +0 -0
  14. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  15. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  16. warp/examples/core/example_sample_mesh.py +300 -0
  17. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  18. warp/examples/fem/example_apic_fluid.py +1 -1
  19. warp/examples/fem/example_burgers.py +2 -2
  20. warp/examples/fem/example_deformed_geometry.py +1 -1
  21. warp/examples/fem/example_distortion_energy.py +1 -1
  22. warp/examples/fem/example_magnetostatics.py +6 -6
  23. warp/examples/fem/utils.py +9 -3
  24. warp/examples/interop/example_jax_callable.py +116 -0
  25. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  26. warp/examples/interop/example_jax_kernel.py +205 -0
  27. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  28. warp/examples/tile/example_tile_matmul.py +2 -4
  29. warp/fem/__init__.py +11 -1
  30. warp/fem/adaptivity.py +4 -4
  31. warp/fem/field/field.py +11 -1
  32. warp/fem/field/nodal_field.py +56 -88
  33. warp/fem/field/virtual.py +62 -23
  34. warp/fem/geometry/adaptive_nanogrid.py +16 -13
  35. warp/fem/geometry/closest_point.py +1 -1
  36. warp/fem/geometry/deformed_geometry.py +5 -2
  37. warp/fem/geometry/geometry.py +5 -0
  38. warp/fem/geometry/grid_2d.py +12 -12
  39. warp/fem/geometry/grid_3d.py +12 -15
  40. warp/fem/geometry/hexmesh.py +5 -7
  41. warp/fem/geometry/nanogrid.py +9 -11
  42. warp/fem/geometry/quadmesh.py +13 -13
  43. warp/fem/geometry/tetmesh.py +3 -4
  44. warp/fem/geometry/trimesh.py +7 -20
  45. warp/fem/integrate.py +262 -93
  46. warp/fem/linalg.py +5 -5
  47. warp/fem/quadrature/pic_quadrature.py +37 -22
  48. warp/fem/quadrature/quadrature.py +194 -25
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_function_space.py +4 -2
  51. warp/fem/space/basis_space.py +25 -18
  52. warp/fem/space/hexmesh_function_space.py +2 -2
  53. warp/fem/space/partition.py +6 -2
  54. warp/fem/space/quadmesh_function_space.py +8 -8
  55. warp/fem/space/shape/cube_shape_function.py +23 -23
  56. warp/fem/space/shape/square_shape_function.py +12 -12
  57. warp/fem/space/shape/triangle_shape_function.py +1 -1
  58. warp/fem/space/tetmesh_function_space.py +3 -3
  59. warp/fem/space/trimesh_function_space.py +2 -2
  60. warp/fem/utils.py +12 -6
  61. warp/jax.py +14 -1
  62. warp/jax_experimental/__init__.py +16 -0
  63. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
  64. warp/jax_experimental/ffi.py +702 -0
  65. warp/jax_experimental/xla_ffi.py +602 -0
  66. warp/math.py +89 -0
  67. warp/native/array.h +13 -0
  68. warp/native/builtin.h +29 -3
  69. warp/native/bvh.cpp +3 -1
  70. warp/native/bvh.cu +42 -14
  71. warp/native/bvh.h +2 -1
  72. warp/native/clang/clang.cpp +30 -3
  73. warp/native/cuda_util.cpp +14 -0
  74. warp/native/cuda_util.h +2 -0
  75. warp/native/exports.h +68 -63
  76. warp/native/intersect.h +26 -26
  77. warp/native/intersect_adj.h +33 -33
  78. warp/native/marching.cu +1 -1
  79. warp/native/mat.h +513 -9
  80. warp/native/mesh.h +10 -10
  81. warp/native/quat.h +99 -11
  82. warp/native/rand.h +6 -0
  83. warp/native/sort.cpp +122 -59
  84. warp/native/sort.cu +152 -15
  85. warp/native/sort.h +8 -1
  86. warp/native/sparse.cpp +43 -22
  87. warp/native/sparse.cu +52 -17
  88. warp/native/svd.h +116 -0
  89. warp/native/tile.h +312 -116
  90. warp/native/tile_reduce.h +46 -3
  91. warp/native/vec.h +68 -7
  92. warp/native/volume.cpp +85 -113
  93. warp/native/volume_builder.cu +25 -10
  94. warp/native/volume_builder.h +6 -0
  95. warp/native/warp.cpp +5 -6
  96. warp/native/warp.cu +100 -11
  97. warp/native/warp.h +19 -10
  98. warp/optim/linear.py +10 -10
  99. warp/render/render_opengl.py +19 -17
  100. warp/render/render_usd.py +93 -3
  101. warp/sim/articulation.py +4 -4
  102. warp/sim/collide.py +32 -19
  103. warp/sim/import_mjcf.py +449 -155
  104. warp/sim/import_urdf.py +32 -12
  105. warp/sim/inertia.py +189 -156
  106. warp/sim/integrator_euler.py +8 -5
  107. warp/sim/integrator_featherstone.py +3 -10
  108. warp/sim/integrator_vbd.py +207 -2
  109. warp/sim/integrator_xpbd.py +8 -5
  110. warp/sim/model.py +71 -25
  111. warp/sim/render.py +4 -0
  112. warp/sim/utils.py +2 -2
  113. warp/sparse.py +642 -555
  114. warp/stubs.py +217 -20
  115. warp/tests/__main__.py +0 -15
  116. warp/tests/assets/torus.usda +1 -1
  117. warp/tests/cuda/__init__.py +0 -0
  118. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  119. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  120. warp/tests/geometry/__init__.py +0 -0
  121. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  122. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  123. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  124. warp/tests/interop/__init__.py +0 -0
  125. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  126. warp/tests/sim/__init__.py +0 -0
  127. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  128. warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
  129. warp/tests/sim/test_inertia.py +161 -0
  130. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  131. warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
  132. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  133. warp/tests/sim/test_vbd.py +597 -0
  134. warp/tests/sim/test_xpbd.py +399 -0
  135. warp/tests/test_bool.py +1 -1
  136. warp/tests/test_codegen.py +24 -3
  137. warp/tests/test_examples.py +40 -38
  138. warp/tests/test_fem.py +98 -14
  139. warp/tests/test_linear_solvers.py +0 -11
  140. warp/tests/test_mat.py +577 -156
  141. warp/tests/test_mat_scalar_ops.py +4 -4
  142. warp/tests/test_overwrite.py +0 -60
  143. warp/tests/test_quat.py +356 -151
  144. warp/tests/test_rand.py +44 -37
  145. warp/tests/test_sparse.py +47 -6
  146. warp/tests/test_spatial.py +75 -0
  147. warp/tests/test_static.py +1 -1
  148. warp/tests/test_utils.py +84 -4
  149. warp/tests/test_vec.py +336 -178
  150. warp/tests/tile/__init__.py +0 -0
  151. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  152. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
  153. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  154. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  155. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  156. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  157. warp/tests/unittest_serial.py +1 -0
  158. warp/tests/unittest_suites.py +45 -62
  159. warp/tests/unittest_utils.py +2 -1
  160. warp/thirdparty/unittest_parallel.py +3 -1
  161. warp/types.py +175 -666
  162. warp/utils.py +137 -72
  163. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
  164. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
  165. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  166. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
  167. warp/examples/optim/example_walker.py +0 -317
  168. warp/native/cutlass_gemm.cpp +0 -43
  169. warp/native/cutlass_gemm.cu +0 -382
  170. warp/tests/test_matmul.py +0 -511
  171. warp/tests/test_matmul_lite.py +0 -411
  172. warp/tests/test_vbd.py +0 -386
  173. warp/tests/unused_test_misc.py +0 -77
  174. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  175. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  176. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  177. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  178. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  179. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  180. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  181. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  182. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  183. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  184. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  185. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  186. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  187. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  188. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  189. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  190. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  191. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py CHANGED
@@ -34,7 +34,7 @@ from warp.fem.field import (
34
34
  make_restriction,
35
35
  )
36
36
  from warp.fem.field.virtual import make_bilinear_dispatch_kernel, make_linear_dispatch_kernel
37
- from warp.fem.linalg import array_axpy
37
+ from warp.fem.linalg import array_axpy, basis_coefficient
38
38
  from warp.fem.operator import Integrand, Operator, at_node, integrand
39
39
  from warp.fem.quadrature import Quadrature, RegularQuadrature
40
40
  from warp.fem.types import (
@@ -493,7 +493,7 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
493
493
  callee = getattr(call.func, "id", None)
494
494
 
495
495
  if callee == self._func_name:
496
- # Replace function arguments with ours generated structs
496
+ # Replace function arguments with our generated structs
497
497
  call.args.clear()
498
498
  for arg in self._arg_names:
499
499
  if arg == self._domain_name:
@@ -576,33 +576,33 @@ def get_integrate_constant_kernel(
576
576
  ):
577
577
  def integrate_kernel_fn(
578
578
  qp_arg: quadrature.Arg,
579
+ qp_element_index_arg: quadrature.ElementIndexArg,
579
580
  domain_arg: domain.ElementArg,
580
581
  domain_index_arg: domain.ElementIndexArg,
581
582
  fields: FieldStruct,
582
583
  values: ValueStruct,
583
584
  result: wp.array(dtype=accumulate_dtype),
584
585
  ):
585
- domain_element_index = wp.tid()
586
+ qp_eval_index = wp.tid()
587
+ domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
588
+ if domain_element_index == NULL_ELEMENT_INDEX:
589
+ return
590
+
586
591
  element_index = domain.element_index(domain_index_arg, domain_element_index)
587
- elem_sum = accumulate_dtype(0.0)
592
+
593
+ qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
594
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
595
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
588
596
 
589
597
  test_dof_index = NULL_DOF_INDEX
590
598
  trial_dof_index = NULL_DOF_INDEX
591
599
 
592
- qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
593
- for k in range(qp_point_count):
594
- qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
595
- coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
596
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
597
-
598
- sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
599
- vol = domain.element_measure(domain_arg, sample)
600
-
601
- val = integrand_func(sample, fields, values)
600
+ sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
601
+ vol = domain.element_measure(domain_arg, sample)
602
602
 
603
- elem_sum += accumulate_dtype(qp_weight * vol * val)
603
+ val = integrand_func(sample, fields, values)
604
604
 
605
- wp.atomic_add(result, 0, elem_sum)
605
+ wp.atomic_add(result, 0, accumulate_dtype(qp_weight * vol * val))
606
606
 
607
607
  return integrate_kernel_fn
608
608
 
@@ -745,35 +745,35 @@ def get_integrate_linear_local_kernel(
745
745
  ValueStruct: wp.codegen.Struct,
746
746
  test: LocalTestField,
747
747
  ):
748
- TAYLOR_DOF_COUNT = test.TAYLOR_DOF_COUNT
749
-
750
748
  def integrate_kernel_fn(
751
749
  qp_arg: quadrature.Arg,
750
+ qp_element_index_arg: quadrature.ElementIndexArg,
752
751
  domain_arg: domain.ElementArg,
753
752
  domain_index_arg: domain.ElementIndexArg,
754
753
  fields: FieldStruct,
755
754
  values: ValueStruct,
756
755
  result: wp.array3d(dtype=float),
757
756
  ):
758
- domain_element_index, taylor_dof, test_dof = wp.tid()
759
- element_index = domain.element_index(domain_index_arg, domain_element_index)
757
+ qp_eval_index, taylor_dof, test_dof = wp.tid()
758
+ domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
760
759
 
761
- trial_dof_index = NULL_DOF_INDEX
762
- test_dof_offset = test_dof * TAYLOR_DOF_COUNT
760
+ if domain_element_index == NULL_ELEMENT_INDEX:
761
+ return
763
762
 
764
- qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
765
- for qp in range(qp_point_count):
766
- qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
767
- qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
768
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
763
+ element_index = domain.element_index(domain_index_arg, domain_element_index)
769
764
 
770
- vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
765
+ qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
766
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
767
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
771
768
 
772
- test_dof_index = DofIndex(qp_index, test_dof_offset + taylor_dof)
769
+ vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
773
770
 
774
- sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
775
- val = integrand_func(sample, fields, values)
776
- result[qp_index, taylor_dof, test_dof] = qp_weight * vol * val
771
+ trial_dof_index = NULL_DOF_INDEX
772
+ test_dof_index = DofIndex(taylor_dof, test_dof)
773
+
774
+ sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
775
+ val = integrand_func(sample, fields, values)
776
+ result[qp_eval_index, taylor_dof, test_dof] = qp_weight * vol * val
777
777
 
778
778
  return integrate_kernel_fn
779
779
 
@@ -818,10 +818,10 @@ def get_integrate_bilinear_kernel(
818
818
  element_trial_node_count = trial.space.topology.element_node_count(
819
819
  domain_arg, trial_topology_arg, element_index
820
820
  )
821
- qp_point_count = wp.select(
821
+ qp_point_count = wp.where(
822
822
  trial_node < element_trial_node_count,
823
- 0,
824
823
  quadrature.point_count(domain_arg, qp_arg, test_element_index.domain_element_index, element_index),
824
+ 0,
825
825
  )
826
826
 
827
827
  test_dof_index = DofIndex(
@@ -963,36 +963,38 @@ def get_integrate_bilinear_local_kernel(
963
963
 
964
964
  def integrate_kernel_fn(
965
965
  qp_arg: quadrature.Arg,
966
+ qp_element_index_arg: quadrature.ElementIndexArg,
966
967
  domain_arg: domain.ElementArg,
967
968
  domain_index_arg: domain.ElementIndexArg,
968
969
  fields: FieldStruct,
969
970
  values: ValueStruct,
970
971
  result: wp.array4d(dtype=float),
971
972
  ):
972
- domain_element_index, test_dof, trial_dof, trial_taylor_dof = wp.tid()
973
+ qp_eval_index, test_dof, trial_dof, trial_taylor_dof = wp.tid()
974
+
975
+ domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
976
+ if domain_element_index == NULL_ELEMENT_INDEX:
977
+ return
978
+
973
979
  element_index = domain.element_index(domain_index_arg, domain_element_index)
974
980
 
975
- test_dof_offset = TEST_TAYLOR_DOF_COUNT * test_dof
976
- trial_dof_offset = TRIAL_TAYLOR_DOF_COUNT * trial_dof
981
+ qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
982
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
983
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
977
984
 
978
- qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
979
- for k in range(qp_point_count):
980
- qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
981
- qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
982
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
985
+ vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
986
+ qp_vol = vol * qp_weight
983
987
 
984
- vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
985
- qp_vol = vol * qp_weight
988
+ trial_dof_index = DofIndex(trial_taylor_dof, trial_dof)
986
989
 
987
- for test_taylor_dof in range(TEST_TAYLOR_DOF_COUNT):
988
- taylor_dof = test_taylor_dof * TRIAL_TAYLOR_DOF_COUNT + trial_taylor_dof
990
+ for test_taylor_dof in range(TEST_TAYLOR_DOF_COUNT):
991
+ taylor_dof = test_taylor_dof * TRIAL_TAYLOR_DOF_COUNT + trial_taylor_dof
989
992
 
990
- test_dof_index = DofIndex(qp_index, test_dof_offset + test_taylor_dof)
991
- trial_dof_index = DofIndex(qp_index, trial_dof_offset + trial_taylor_dof)
993
+ test_dof_index = DofIndex(test_taylor_dof, test_dof)
992
994
 
993
- sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
994
- val = integrand_func(sample, fields, values)
995
- result[qp_index, test_dof, trial_dof, taylor_dof] = qp_vol * val
995
+ sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
996
+ val = integrand_func(sample, fields, values)
997
+ result[qp_eval_index, test_dof, trial_dof, taylor_dof] = qp_vol * val
996
998
 
997
999
  return integrate_kernel_fn
998
1000
 
@@ -1138,6 +1140,7 @@ def _launch_integrate_kernel(
1138
1140
  output_dtype: type,
1139
1141
  output: Optional[Union[wp.array, BsrMatrix]],
1140
1142
  add_to_output: bool,
1143
+ bsr_options: Optional[Dict[str, Any]],
1141
1144
  device,
1142
1145
  ):
1143
1146
  # Set-up launch arguments
@@ -1175,9 +1178,10 @@ def _launch_integrate_kernel(
1175
1178
 
1176
1179
  wp.launch(
1177
1180
  kernel=kernel,
1178
- dim=domain.element_count(),
1181
+ dim=quadrature.evaluation_point_count(),
1179
1182
  inputs=[
1180
1183
  qp_arg,
1184
+ quadrature.element_index_arg_value(device),
1181
1185
  domain_elt_arg,
1182
1186
  domain_elt_index_arg,
1183
1187
  field_arg_values,
@@ -1279,15 +1283,16 @@ def _launch_integrate_kernel(
1279
1283
  temporary_store=temporary_store,
1280
1284
  device=device,
1281
1285
  requires_grad=output.requires_grad,
1282
- shape=(quadrature.total_point_count(), test.TAYLOR_DOF_COUNT, test.value_dof_count),
1286
+ shape=(quadrature.evaluation_point_count(), test.TAYLOR_DOF_COUNT, test.value_dof_count),
1283
1287
  dtype=float,
1284
1288
  )
1285
1289
 
1286
1290
  wp.launch(
1287
1291
  kernel=kernel,
1288
- dim=(domain.element_count(), test.TAYLOR_DOF_COUNT, test.value_dof_count),
1292
+ dim=local_result.array.shape,
1289
1293
  inputs=[
1290
1294
  qp_arg,
1295
+ quadrature.element_index_arg_value(device),
1291
1296
  domain_elt_arg,
1292
1297
  domain_elt_index_arg,
1293
1298
  field_arg_values,
@@ -1389,7 +1394,7 @@ def _launch_integrate_kernel(
1389
1394
  device=device,
1390
1395
  requires_grad=False,
1391
1396
  shape=(
1392
- quadrature.total_point_count(),
1397
+ quadrature.evaluation_point_count(),
1393
1398
  test.value_dof_count,
1394
1399
  trial.value_dof_count,
1395
1400
  test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT,
@@ -1399,9 +1404,15 @@ def _launch_integrate_kernel(
1399
1404
 
1400
1405
  wp.launch(
1401
1406
  kernel=kernel,
1402
- dim=(domain.element_count(), test.value_dof_count, trial.value_dof_count, trial.TAYLOR_DOF_COUNT),
1407
+ dim=(
1408
+ quadrature.evaluation_point_count(),
1409
+ test.value_dof_count,
1410
+ trial.value_dof_count,
1411
+ trial.TAYLOR_DOF_COUNT,
1412
+ ),
1403
1413
  inputs=[
1404
1414
  qp_arg,
1415
+ quadrature.element_index_arg_value(device),
1405
1416
  domain_elt_arg,
1406
1417
  domain_elt_index_arg,
1407
1418
  field_arg_values,
@@ -1496,7 +1507,7 @@ def _launch_integrate_kernel(
1496
1507
  else:
1497
1508
  bsr_result = output
1498
1509
 
1499
- bsr_set_from_triplets(bsr_result, triplet_rows, triplet_cols, triplet_values)
1510
+ bsr_set_from_triplets(bsr_result, triplet_rows, triplet_cols, triplet_values, **(bsr_options or {}))
1500
1511
 
1501
1512
  # Do not wait for garbage collection
1502
1513
  triplet_values_temp.release()
@@ -1541,8 +1552,9 @@ def integrate(
1541
1552
  device=None,
1542
1553
  temporary_store: Optional[cache.TemporaryStore] = None,
1543
1554
  kernel_options: Optional[Dict[str, Any]] = None,
1544
- assembly: str = None,
1555
+ assembly: Optional[str] = None,
1545
1556
  add: bool = False,
1557
+ bsr_options: Optional[Dict[str, Any]] = None,
1546
1558
  ):
1547
1559
  """
1548
1560
  Integrates a constant, linear or bilinear form, and returns a scalar, array, or sparse matrix, respectively.
@@ -1566,6 +1578,7 @@ def integrate(
1566
1578
  - "dispatch": For linear or bilinear forms, first evaluate the form at quadrature points then dispatch to nodes in a second pass. More efficient for integrands that are expensive to evaluate. Incompatible with `at_node` operator on test or trial functions.
1567
1579
  - `None` (default): Automatically picks a suitable assembly strategy (either "generic" or "dispatch")
1568
1580
  add: If True and `output` is provided, add the integration result to `output` instead of replacing its content
1581
+ bsr_options: Additional options to be passed to the sparse matrix construction algorithm. See :func:`warp.sparse.bsr_set_from_triplets()`
1569
1582
  """
1570
1583
  if fields is None:
1571
1584
  fields = {}
@@ -1678,6 +1691,7 @@ def integrate(
1678
1691
  output_dtype=output_dtype,
1679
1692
  output=output,
1680
1693
  add_to_output=add,
1694
+ bsr_options=bsr_options,
1681
1695
  device=device,
1682
1696
  )
1683
1697
 
@@ -1823,53 +1837,128 @@ def get_interpolate_at_quadrature_kernel(
1823
1837
  ):
1824
1838
  def interpolate_at_quadrature_nonvalued_kernel_fn(
1825
1839
  qp_arg: quadrature.Arg,
1840
+ qp_element_index_arg: quadrature.ElementIndexArg,
1826
1841
  domain_arg: quadrature.domain.ElementArg,
1827
1842
  domain_index_arg: quadrature.domain.ElementIndexArg,
1828
1843
  fields: FieldStruct,
1829
1844
  values: ValueStruct,
1830
1845
  result: wp.array(dtype=float),
1831
1846
  ):
1832
- domain_element_index = wp.tid()
1847
+ qp_eval_index = wp.tid()
1848
+ domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
1849
+ if domain_element_index == NULL_ELEMENT_INDEX:
1850
+ return
1851
+
1833
1852
  element_index = domain.element_index(domain_index_arg, domain_element_index)
1834
1853
 
1835
1854
  test_dof_index = NULL_DOF_INDEX
1836
1855
  trial_dof_index = NULL_DOF_INDEX
1837
1856
 
1838
- qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
1839
- for k in range(qp_point_count):
1840
- qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
1841
- coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
1842
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
1857
+ coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
1858
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
1859
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
1843
1860
 
1844
- sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1845
- integrand_func(sample, fields, values)
1861
+ sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1862
+ integrand_func(sample, fields, values)
1846
1863
 
1847
1864
  def interpolate_at_quadrature_kernel_fn(
1848
1865
  qp_arg: quadrature.Arg,
1866
+ qp_element_index_arg: quadrature.ElementIndexArg,
1849
1867
  domain_arg: quadrature.domain.ElementArg,
1850
1868
  domain_index_arg: quadrature.domain.ElementIndexArg,
1851
1869
  fields: FieldStruct,
1852
1870
  values: ValueStruct,
1853
1871
  result: wp.array(dtype=value_type),
1854
1872
  ):
1855
- domain_element_index = wp.tid()
1873
+ qp_eval_index = wp.tid()
1874
+ domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
1875
+ if domain_element_index == NULL_ELEMENT_INDEX:
1876
+ return
1877
+
1856
1878
  element_index = domain.element_index(domain_index_arg, domain_element_index)
1857
1879
 
1858
1880
  test_dof_index = NULL_DOF_INDEX
1859
1881
  trial_dof_index = NULL_DOF_INDEX
1860
1882
 
1861
- qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
1862
- for k in range(qp_point_count):
1863
- qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
1864
- coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
1865
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
1883
+ coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
1884
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
1885
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
1866
1886
 
1867
- sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1868
- result[qp_index] = integrand_func(sample, fields, values)
1887
+ sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1888
+ result[qp_index] = integrand_func(sample, fields, values)
1869
1889
 
1870
1890
  return interpolate_at_quadrature_nonvalued_kernel_fn if value_type is None else interpolate_at_quadrature_kernel_fn
1871
1891
 
1872
1892
 
1893
+ def get_interpolate_jacobian_at_quadrature_kernel(
1894
+ integrand_func: wp.Function,
1895
+ domain: GeometryDomain,
1896
+ quadrature: Quadrature,
1897
+ FieldStruct: wp.codegen.Struct,
1898
+ ValueStruct: wp.codegen.Struct,
1899
+ trial: TrialField,
1900
+ value_size: int,
1901
+ value_type: type,
1902
+ ):
1903
+ MAX_NODES_PER_ELEMENT = trial.space.topology.MAX_NODES_PER_ELEMENT
1904
+ VALUE_SIZE = wp.constant(value_size)
1905
+
1906
+ def interpolate_jacobian_kernel_fn(
1907
+ qp_arg: quadrature.Arg,
1908
+ qp_element_index_arg: quadrature.ElementIndexArg,
1909
+ domain_arg: domain.ElementArg,
1910
+ domain_index_arg: domain.ElementIndexArg,
1911
+ trial_partition_arg: trial.space_partition.PartitionArg,
1912
+ trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
1913
+ fields: FieldStruct,
1914
+ values: ValueStruct,
1915
+ triplet_rows: wp.array(dtype=int),
1916
+ triplet_cols: wp.array(dtype=int),
1917
+ triplet_values: wp.array3d(dtype=value_type),
1918
+ ):
1919
+ qp_eval_index, trial_node, trial_dof = wp.tid()
1920
+ domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
1921
+
1922
+ if domain_element_index == NULL_ELEMENT_INDEX:
1923
+ return
1924
+
1925
+ element_index = domain.element_index(domain_index_arg, domain_element_index)
1926
+ if qp >= quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index):
1927
+ return
1928
+
1929
+ element_trial_node_count = trial.space.topology.element_node_count(
1930
+ domain_arg, trial_topology_arg, element_index
1931
+ )
1932
+
1933
+ qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
1934
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
1935
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
1936
+
1937
+ block_offset = qp_index * MAX_NODES_PER_ELEMENT + trial_node
1938
+
1939
+ test_dof_index = NULL_DOF_INDEX
1940
+ trial_dof_index = DofIndex(trial_node, trial_dof)
1941
+
1942
+ sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1943
+ val = integrand_func(sample, fields, values)
1944
+
1945
+ for k in range(VALUE_SIZE):
1946
+ triplet_values[block_offset, k, trial_dof] = basis_coefficient(val, k)
1947
+
1948
+ if trial_dof == 0:
1949
+ if trial_node < element_trial_node_count:
1950
+ trial_node_index = trial.space_partition.partition_node_index(
1951
+ trial_partition_arg,
1952
+ trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
1953
+ )
1954
+ else:
1955
+ trial_node_index = NULL_NODE_INDEX # will get ignored when converting to bsr
1956
+ triplet_rows[block_offset] = qp_index
1957
+ triplet_cols[block_offset] = trial_node_index
1958
+
1959
+ return interpolate_jacobian_kernel_fn
1960
+
1961
+
1873
1962
  def get_interpolate_free_kernel(
1874
1963
  integrand_func: wp.Function,
1875
1964
  domain: GeometryDomain,
@@ -1939,9 +2028,9 @@ def _generate_interpolate_kernel(
1939
2028
  dest_dtype = dest.dtype if dest else None
1940
2029
  type_str = wp.types.get_type_code(dest_dtype) if dest_dtype else ""
1941
2030
  if quadrature is None:
1942
- kernel_suffix = f"_itp_{field_names}_{type_str}"
2031
+ kernel_suffix = f"_itp_{field_names}_{domain.name}_{type_str}"
1943
2032
  else:
1944
- kernel_suffix = f"_itp_{field_names}_{quadrature.name}_{type_str}"
2033
+ kernel_suffix = f"_itp_{field_names}_{domain.name}_{quadrature.name}_{type_str}"
1945
2034
 
1946
2035
  kernel = cache.get_integrand_kernel(
1947
2036
  integrand=integrand,
@@ -1986,14 +2075,27 @@ def _generate_interpolate_kernel(
1986
2075
  ValueStruct=ValueStruct,
1987
2076
  )
1988
2077
  elif quadrature is not None:
1989
- interpolate_kernel_fn = get_interpolate_at_quadrature_kernel(
1990
- integrand_func,
1991
- domain=domain,
1992
- quadrature=quadrature,
1993
- value_type=dest_dtype,
1994
- FieldStruct=FieldStruct,
1995
- ValueStruct=ValueStruct,
1996
- )
2078
+ if arguments.trial_name:
2079
+ trial = arguments.field_args[arguments.trial_name]
2080
+ interpolate_kernel_fn = get_interpolate_jacobian_at_quadrature_kernel(
2081
+ integrand_func,
2082
+ domain=domain,
2083
+ quadrature=quadrature,
2084
+ FieldStruct=FieldStruct,
2085
+ ValueStruct=ValueStruct,
2086
+ trial=trial,
2087
+ value_size=dest.block_shape[0],
2088
+ value_type=dest.scalar_type,
2089
+ )
2090
+ else:
2091
+ interpolate_kernel_fn = get_interpolate_at_quadrature_kernel(
2092
+ integrand_func,
2093
+ domain=domain,
2094
+ quadrature=quadrature,
2095
+ value_type=dest_dtype,
2096
+ FieldStruct=FieldStruct,
2097
+ ValueStruct=ValueStruct,
2098
+ )
1997
2099
  else:
1998
2100
  interpolate_kernel_fn = get_interpolate_free_kernel(
1999
2101
  integrand_func,
@@ -2027,8 +2129,11 @@ def _launch_interpolate_kernel(
2027
2129
  dest: Optional[Union[FieldRestriction, wp.array]],
2028
2130
  quadrature: Optional[Quadrature],
2029
2131
  dim: int,
2132
+ trial: Optional[TrialField],
2030
2133
  fields: Dict[str, FieldLike],
2031
2134
  values: Dict[str, Any],
2135
+ temporary_store: Optional[cache.TemporaryStore],
2136
+ bsr_options: Optional[Dict[str, Any]],
2032
2137
  device,
2033
2138
  ) -> wp.Kernel:
2034
2139
  # Set-up launch arguments
@@ -2059,21 +2164,74 @@ def _launch_interpolate_kernel(
2059
2164
  ],
2060
2165
  device=device,
2061
2166
  )
2062
- elif quadrature is not None:
2063
- qp_arg = quadrature.arg_value(device)
2167
+ return
2168
+
2169
+ if quadrature is None:
2064
2170
  wp.launch(
2065
2171
  kernel=kernel,
2066
- dim=domain.element_count(),
2067
- inputs=[qp_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
2172
+ dim=dim,
2173
+ inputs=[dim, elt_arg, field_arg_values, value_struct_values, dest],
2068
2174
  device=device,
2069
2175
  )
2070
- else:
2176
+ return
2177
+
2178
+ qp_arg = quadrature.arg_value(device)
2179
+ qp_element_index_arg = quadrature.element_index_arg_value(device)
2180
+ if trial is None:
2071
2181
  wp.launch(
2072
2182
  kernel=kernel,
2073
- dim=dim,
2074
- inputs=[dim, elt_arg, field_arg_values, value_struct_values, dest],
2183
+ dim=quadrature.evaluation_point_count(),
2184
+ inputs=[qp_arg, qp_element_index_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
2075
2185
  device=device,
2076
2186
  )
2187
+ return
2188
+
2189
+ nnz = quadrature.total_point_count() * trial.space.topology.MAX_NODES_PER_ELEMENT
2190
+
2191
+ if dest.nrow != quadrature.total_point_count() or dest.ncol != trial.space_partition.node_count():
2192
+ raise RuntimeError(
2193
+ f"'dest' matrix must have {quadrature.total_point_count()} rows and {trial.space_partition.node_count()} columns of blocks"
2194
+ )
2195
+ if dest.block_shape[1] != trial.node_dof_count:
2196
+ raise RuntimeError(f"'dest' matrix blocks must have {trial.node_dof_count} columns")
2197
+
2198
+ triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
2199
+ triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
2200
+ triplet_values_temp = cache.borrow_temporary(
2201
+ temporary_store,
2202
+ dtype=dest.scalar_type,
2203
+ shape=(nnz, *dest.block_shape),
2204
+ device=device,
2205
+ )
2206
+ triplet_cols = triplet_cols_temp.array
2207
+ triplet_rows = triplet_rows_temp.array
2208
+ triplet_values = triplet_values_temp.array
2209
+ triplet_rows.fill_(-1)
2210
+ triplet_values.zero_()
2211
+
2212
+ trial_partition_arg = trial.space_partition.partition_arg_value(device)
2213
+ trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
2214
+
2215
+ wp.launch(
2216
+ kernel=kernel,
2217
+ dim=(quadrature.evaluation_point_count(), trial.space.topology.MAX_NODES_PER_ELEMENT, trial.node_dof_count),
2218
+ inputs=[
2219
+ qp_arg,
2220
+ qp_element_index_arg,
2221
+ elt_arg,
2222
+ elt_index_arg,
2223
+ trial_partition_arg,
2224
+ trial_topology_arg,
2225
+ field_arg_values,
2226
+ value_struct_values,
2227
+ triplet_rows,
2228
+ triplet_cols,
2229
+ triplet_values,
2230
+ ],
2231
+ device=device,
2232
+ )
2233
+
2234
+ bsr_set_from_triplets(dest, triplet_rows, triplet_cols, triplet_values, **(bsr_options or {}))
2077
2235
 
2078
2236
 
2079
2237
  @integrand
@@ -2091,6 +2249,8 @@ def interpolate(
2091
2249
  values: Optional[Dict[str, Any]] = None,
2092
2250
  device=None,
2093
2251
  kernel_options: Optional[Dict[str, Any]] = None,
2252
+ temporary_store: Optional[cache.TemporaryStore] = None,
2253
+ bsr_options: Optional[Dict[str, Any]] = None,
2094
2254
  ):
2095
2255
  """
2096
2256
  Interpolates a function at a finite set of sample points and optionally assigns the result to a discrete field or a raw warp array.
@@ -2109,6 +2269,8 @@ def interpolate(
2109
2269
  values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
2110
2270
  device: Device on which to perform the interpolation
2111
2271
  kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)
2272
+ temporary_store: shared pool from which to allocate temporary arrays
2273
+ bsr_options: Additional options to be passed to the sparse matrix construction algorithm. See :func:`warp.sparse.bsr_set_from_triplets()`
2112
2274
  """
2113
2275
 
2114
2276
  if isinstance(integrand, FieldLike):
@@ -2126,8 +2288,12 @@ def interpolate(
2126
2288
  raise ValueError("integrand must be tagged with @integrand decorator")
2127
2289
 
2128
2290
  arguments = _parse_integrand_arguments(integrand, fields)
2129
- if arguments.test_name or arguments.trial_name:
2130
- raise ValueError("Test or Trial fields should not be used for interpolation")
2291
+ if arguments.test_name:
2292
+ raise ValueError(f"Test field '{arguments.test_name}' may not be used for interpolation")
2293
+ if arguments.trial_name and (quadrature is None or not isinstance(dest, BsrMatrix)):
2294
+ raise ValueError(
2295
+ f"Interpolation using trial field '{arguments.trial_name}' requires 'quadrature' to be provided and 'dest' to be a `warp.sparse.BsrMatrix`"
2296
+ )
2131
2297
 
2132
2298
  if isinstance(dest, DiscreteField):
2133
2299
  dest = make_restriction(dest, domain=domain)
@@ -2160,7 +2326,10 @@ def interpolate(
2160
2326
  dest=dest,
2161
2327
  quadrature=quadrature,
2162
2328
  dim=dim,
2329
+ trial=fields.get(arguments.trial_name),
2163
2330
  fields=arguments.field_args,
2164
2331
  values=values,
2332
+ temporary_store=temporary_store,
2333
+ bsr_options=bsr_options,
2165
2334
  device=device,
2166
2335
  )
warp/fem/linalg.py CHANGED
@@ -172,11 +172,11 @@ def householder_qr_decomposition(A: Any):
172
172
 
173
173
  for i in range(type(x).length):
174
174
  for k in range(type(x).length):
175
- x[k] = wp.select(k < i, A[k, i], zero)
175
+ x[k] = wp.where(k < i, zero, A[k, i])
176
176
 
177
177
  alpha = wp.length(x) * wp.sign(x[i])
178
178
  x[i] += alpha
179
- two_over_x_sq = wp.select(alpha == zero, two / wp.length_sq(x), zero)
179
+ two_over_x_sq = wp.where(alpha == zero, zero, two / wp.length_sq(x))
180
180
 
181
181
  A -= wp.outer(two_over_x_sq * x, x * A)
182
182
  Q -= wp.outer(Q * x, two_over_x_sq * x)
@@ -201,11 +201,11 @@ def householder_make_hessenberg(A: Any):
201
201
 
202
202
  for i in range(1, type(x).length):
203
203
  for k in range(type(x).length):
204
- x[k] = wp.select(k < i, A[k, i - 1], zero)
204
+ x[k] = wp.where(k < i, zero, A[k, i - 1])
205
205
 
206
206
  alpha = wp.length(x) * wp.sign(x[i])
207
207
  x[i] += alpha
208
- two_over_x_sq = wp.select(alpha == zero, two / wp.length_sq(x), zero)
208
+ two_over_x_sq = wp.where(alpha == zero, zero, two / wp.length_sq(x))
209
209
 
210
210
  # apply on both sides
211
211
  A -= wp.outer(two_over_x_sq * x, x * A)
@@ -226,7 +226,7 @@ def solve_triangular(R: Any, b: Any):
226
226
  for i in range(b.length, 0, -1):
227
227
  j = i - 1
228
228
  r = b[j] - wp.dot(R[j], x)
229
- x[j] = wp.select(R[j, j] == zero, r / R[j, j], zero)
229
+ x[j] = wp.where(R[j, j] == zero, zero, r / R[j, j])
230
230
 
231
231
  return x
232
232