warp-lang 1.0.0b5__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. docs/conf.py +3 -4
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/example_dem.py +28 -26
  6. examples/example_diffray.py +37 -30
  7. examples/example_fluid.py +7 -3
  8. examples/example_jacobian_ik.py +1 -1
  9. examples/example_mesh_intersect.py +10 -7
  10. examples/example_nvdb.py +3 -3
  11. examples/example_render_opengl.py +19 -10
  12. examples/example_sim_cartpole.py +9 -5
  13. examples/example_sim_cloth.py +29 -25
  14. examples/example_sim_fk_grad.py +2 -2
  15. examples/example_sim_fk_grad_torch.py +3 -3
  16. examples/example_sim_grad_bounce.py +11 -8
  17. examples/example_sim_grad_cloth.py +12 -9
  18. examples/example_sim_granular.py +2 -2
  19. examples/example_sim_granular_collision_sdf.py +13 -13
  20. examples/example_sim_neo_hookean.py +3 -3
  21. examples/example_sim_particle_chain.py +2 -2
  22. examples/example_sim_quadruped.py +8 -5
  23. examples/example_sim_rigid_chain.py +8 -5
  24. examples/example_sim_rigid_contact.py +13 -10
  25. examples/example_sim_rigid_fem.py +2 -2
  26. examples/example_sim_rigid_gyroscopic.py +2 -2
  27. examples/example_sim_rigid_kinematics.py +1 -1
  28. examples/example_sim_trajopt.py +3 -2
  29. examples/fem/example_apic_fluid.py +5 -7
  30. examples/fem/example_diffusion_mgpu.py +18 -16
  31. warp/__init__.py +3 -2
  32. warp/bin/warp.so +0 -0
  33. warp/build_dll.py +29 -9
  34. warp/builtins.py +206 -7
  35. warp/codegen.py +58 -38
  36. warp/config.py +3 -1
  37. warp/context.py +234 -128
  38. warp/fem/__init__.py +2 -2
  39. warp/fem/cache.py +2 -1
  40. warp/fem/field/nodal_field.py +18 -17
  41. warp/fem/geometry/hexmesh.py +11 -6
  42. warp/fem/geometry/quadmesh_2d.py +16 -12
  43. warp/fem/geometry/tetmesh.py +19 -8
  44. warp/fem/geometry/trimesh_2d.py +18 -7
  45. warp/fem/integrate.py +341 -196
  46. warp/fem/quadrature/__init__.py +1 -1
  47. warp/fem/quadrature/pic_quadrature.py +138 -53
  48. warp/fem/quadrature/quadrature.py +81 -9
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_space.py +169 -51
  51. warp/fem/space/grid_2d_function_space.py +2 -2
  52. warp/fem/space/grid_3d_function_space.py +2 -2
  53. warp/fem/space/hexmesh_function_space.py +2 -2
  54. warp/fem/space/partition.py +9 -6
  55. warp/fem/space/quadmesh_2d_function_space.py +2 -2
  56. warp/fem/space/shape/cube_shape_function.py +27 -15
  57. warp/fem/space/shape/square_shape_function.py +29 -18
  58. warp/fem/space/tetmesh_function_space.py +2 -2
  59. warp/fem/space/topology.py +10 -0
  60. warp/fem/space/trimesh_2d_function_space.py +2 -2
  61. warp/fem/utils.py +10 -5
  62. warp/native/array.h +49 -8
  63. warp/native/builtin.h +31 -14
  64. warp/native/cuda_util.cpp +8 -3
  65. warp/native/cuda_util.h +1 -0
  66. warp/native/exports.h +1177 -1108
  67. warp/native/intersect.h +4 -4
  68. warp/native/intersect_adj.h +8 -8
  69. warp/native/mat.h +65 -6
  70. warp/native/mesh.h +126 -5
  71. warp/native/quat.h +28 -4
  72. warp/native/vec.h +76 -14
  73. warp/native/warp.cu +1 -6
  74. warp/render/render_opengl.py +261 -109
  75. warp/sim/import_mjcf.py +13 -7
  76. warp/sim/import_urdf.py +14 -14
  77. warp/sim/inertia.py +17 -18
  78. warp/sim/model.py +67 -67
  79. warp/sim/render.py +1 -1
  80. warp/sparse.py +6 -6
  81. warp/stubs.py +19 -81
  82. warp/tape.py +1 -1
  83. warp/tests/__main__.py +3 -6
  84. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  85. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  86. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  87. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  88. warp/tests/aux_test_unresolved_func.py +14 -0
  89. warp/tests/aux_test_unresolved_symbol.py +14 -0
  90. warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
  91. warp/tests/run_coverage_serial.py +31 -0
  92. warp/tests/test_adam.py +102 -106
  93. warp/tests/test_arithmetic.py +39 -40
  94. warp/tests/test_array.py +46 -48
  95. warp/tests/test_array_reduce.py +25 -19
  96. warp/tests/test_atomic.py +62 -26
  97. warp/tests/test_bool.py +16 -11
  98. warp/tests/test_builtins_resolution.py +1292 -0
  99. warp/tests/test_bvh.py +9 -12
  100. warp/tests/test_closest_point_edge_edge.py +53 -57
  101. warp/tests/test_codegen.py +164 -134
  102. warp/tests/test_compile_consts.py +13 -19
  103. warp/tests/test_conditional.py +30 -32
  104. warp/tests/test_copy.py +9 -12
  105. warp/tests/test_ctypes.py +90 -98
  106. warp/tests/test_dense.py +20 -14
  107. warp/tests/test_devices.py +34 -35
  108. warp/tests/test_dlpack.py +74 -75
  109. warp/tests/test_examples.py +215 -97
  110. warp/tests/test_fabricarray.py +15 -21
  111. warp/tests/test_fast_math.py +14 -11
  112. warp/tests/test_fem.py +280 -97
  113. warp/tests/test_fp16.py +19 -15
  114. warp/tests/test_func.py +177 -194
  115. warp/tests/test_generics.py +71 -77
  116. warp/tests/test_grad.py +83 -32
  117. warp/tests/test_grad_customs.py +7 -9
  118. warp/tests/test_hash_grid.py +6 -10
  119. warp/tests/test_import.py +9 -23
  120. warp/tests/test_indexedarray.py +19 -21
  121. warp/tests/test_intersect.py +15 -9
  122. warp/tests/test_large.py +17 -19
  123. warp/tests/test_launch.py +14 -17
  124. warp/tests/test_lerp.py +63 -63
  125. warp/tests/test_lvalue.py +84 -35
  126. warp/tests/test_marching_cubes.py +9 -13
  127. warp/tests/test_mat.py +388 -3004
  128. warp/tests/test_mat_lite.py +9 -12
  129. warp/tests/test_mat_scalar_ops.py +2889 -0
  130. warp/tests/test_math.py +10 -11
  131. warp/tests/test_matmul.py +104 -100
  132. warp/tests/test_matmul_lite.py +72 -98
  133. warp/tests/test_mesh.py +35 -32
  134. warp/tests/test_mesh_query_aabb.py +18 -25
  135. warp/tests/test_mesh_query_point.py +39 -23
  136. warp/tests/test_mesh_query_ray.py +9 -21
  137. warp/tests/test_mlp.py +8 -9
  138. warp/tests/test_model.py +89 -93
  139. warp/tests/test_modules_lite.py +15 -25
  140. warp/tests/test_multigpu.py +87 -114
  141. warp/tests/test_noise.py +10 -12
  142. warp/tests/test_operators.py +14 -21
  143. warp/tests/test_options.py +10 -11
  144. warp/tests/test_pinned.py +16 -18
  145. warp/tests/test_print.py +16 -20
  146. warp/tests/test_quat.py +121 -88
  147. warp/tests/test_rand.py +12 -13
  148. warp/tests/test_reload.py +27 -32
  149. warp/tests/test_rounding.py +7 -10
  150. warp/tests/test_runlength_encode.py +105 -106
  151. warp/tests/test_smoothstep.py +8 -9
  152. warp/tests/test_snippet.py +13 -22
  153. warp/tests/test_sparse.py +30 -29
  154. warp/tests/test_spatial.py +179 -174
  155. warp/tests/test_streams.py +100 -107
  156. warp/tests/test_struct.py +98 -67
  157. warp/tests/test_tape.py +11 -17
  158. warp/tests/test_torch.py +89 -86
  159. warp/tests/test_transient_module.py +9 -12
  160. warp/tests/test_types.py +328 -50
  161. warp/tests/test_utils.py +217 -218
  162. warp/tests/test_vec.py +133 -2133
  163. warp/tests/test_vec_lite.py +8 -11
  164. warp/tests/test_vec_scalar_ops.py +2099 -0
  165. warp/tests/test_volume.py +391 -382
  166. warp/tests/test_volume_write.py +122 -135
  167. warp/tests/unittest_serial.py +35 -0
  168. warp/tests/unittest_suites.py +291 -0
  169. warp/tests/{test_base.py → unittest_utils.py} +138 -25
  170. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  171. warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
  172. warp/thirdparty/unittest_parallel.py +257 -54
  173. warp/types.py +119 -98
  174. warp/utils.py +14 -0
  175. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
  176. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
  177. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  178. warp/tests/test_all.py +0 -239
  179. warp/tests/test_conditional_unequal_types_kernels.py +0 -14
  180. warp/tests/test_coverage.py +0 -38
  181. warp/tests/test_unresolved_func.py +0 -7
  182. warp/tests/test_unresolved_symbol.py +0 -7
  183. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  184. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  185. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  186. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  187. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py CHANGED
@@ -420,36 +420,6 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
420
420
  return call
421
421
 
422
422
 
423
- def get_integrate_null_kernel(
424
- integrand_func: wp.Function,
425
- domain: GeometryDomain,
426
- quadrature: Quadrature,
427
- FieldStruct: wp.codegen.Struct,
428
- ValueStruct: wp.codegen.Struct,
429
- ):
430
- def integrate_kernel_fn(
431
- qp_arg: quadrature.Arg,
432
- domain_arg: domain.ElementArg,
433
- domain_index_arg: domain.ElementIndexArg,
434
- fields: FieldStruct,
435
- values: ValueStruct,
436
- ):
437
- element_index = domain.element_index(domain_index_arg, wp.tid())
438
-
439
- test_dof_index = NULL_DOF_INDEX
440
- trial_dof_index = NULL_DOF_INDEX
441
-
442
- qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
443
- for k in range(qp_point_count):
444
- qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
445
- qp_coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
446
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
447
- sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
448
- integrand_func(sample, fields, values)
449
-
450
- return integrate_kernel_fn
451
-
452
-
453
423
  def get_integrate_constant_kernel(
454
424
  integrand_func: wp.Function,
455
425
  domain: GeometryDomain,
@@ -477,7 +447,7 @@ def get_integrate_constant_kernel(
477
447
  qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
478
448
  coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
479
449
  qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
480
-
450
+
481
451
  sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
482
452
  vol = domain.element_measure(domain_arg, sample)
483
453
 
@@ -497,6 +467,7 @@ def get_integrate_linear_kernel(
497
467
  FieldStruct: wp.codegen.Struct,
498
468
  ValueStruct: wp.codegen.Struct,
499
469
  test: TestField,
470
+ output_dtype,
500
471
  accumulate_dtype,
501
472
  ):
502
473
  def integrate_kernel_fn(
@@ -506,32 +477,36 @@ def get_integrate_linear_kernel(
506
477
  test_arg: test.space_restriction.NodeArg,
507
478
  fields: FieldStruct,
508
479
  values: ValueStruct,
509
- result: wp.array2d(dtype=accumulate_dtype),
480
+ result: wp.array2d(dtype=output_dtype),
510
481
  ):
511
- local_node_index = wp.tid()
482
+ local_node_index, test_dof = wp.tid()
512
483
  node_index = test.space_restriction.node_partition_index(test_arg, local_node_index)
513
484
  element_count = test.space_restriction.node_element_count(test_arg, local_node_index)
514
485
 
515
486
  trial_dof_index = NULL_DOF_INDEX
516
487
 
488
+ val_sum = accumulate_dtype(0.0)
489
+
517
490
  for n in range(element_count):
518
491
  node_element_index = test.space_restriction.node_element_index(test_arg, local_node_index, n)
519
492
  element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
520
493
 
494
+ test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)
495
+
521
496
  qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
522
497
  for k in range(qp_point_count):
523
498
  qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
524
- coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
525
-
499
+ qp_coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
526
500
  qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
527
- vol = domain.element_measure(domain_arg, make_free_sample(element_index, coords))
528
501
 
529
- for i in range(test.space.VALUE_DOF_COUNT):
530
- test_dof_index = DofIndex(node_element_index.node_index_in_element, i)
531
- sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
532
- val = integrand_func(sample, fields, values)
502
+ vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
503
+
504
+ sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
505
+ val = integrand_func(sample, fields, values)
533
506
 
534
- result[node_index, i] = result[node_index, i] + accumulate_dtype(qp_weight * vol * val)
507
+ val_sum += accumulate_dtype(qp_weight * vol * val)
508
+
509
+ result[node_index, test_dof] = output_dtype(val_sum)
535
510
 
536
511
  return integrate_kernel_fn
537
512
 
@@ -542,6 +517,7 @@ def get_integrate_linear_nodal_kernel(
542
517
  FieldStruct: wp.codegen.Struct,
543
518
  ValueStruct: wp.codegen.Struct,
544
519
  test: TestField,
520
+ output_dtype,
545
521
  accumulate_dtype,
546
522
  ):
547
523
  def integrate_kernel_fn(
@@ -550,7 +526,7 @@ def get_integrate_linear_nodal_kernel(
550
526
  test_restriction_arg: test.space_restriction.NodeArg,
551
527
  fields: FieldStruct,
552
528
  values: ValueStruct,
553
- result: wp.array2d(dtype=accumulate_dtype),
529
+ result: wp.array2d(dtype=output_dtype),
554
530
  ):
555
531
  local_node_index, dof = wp.tid()
556
532
 
@@ -595,7 +571,7 @@ def get_integrate_linear_nodal_kernel(
595
571
 
596
572
  val_sum += accumulate_dtype(node_weight * vol * val)
597
573
 
598
- result[node_index, dof] = val_sum
574
+ result[node_index, dof] = output_dtype(val_sum)
599
575
 
600
576
  return integrate_kernel_fn
601
577
 
@@ -608,6 +584,7 @@ def get_integrate_bilinear_kernel(
608
584
  ValueStruct: wp.codegen.Struct,
609
585
  test: TestField,
610
586
  trial: TrialField,
587
+ output_dtype,
611
588
  accumulate_dtype,
612
589
  ):
613
590
  NODES_PER_ELEMENT = trial.space.topology.NODES_PER_ELEMENT
@@ -624,19 +601,26 @@ def get_integrate_bilinear_kernel(
624
601
  row_offsets: wp.array(dtype=int),
625
602
  triplet_rows: wp.array(dtype=int),
626
603
  triplet_cols: wp.array(dtype=int),
627
- triplet_values: wp.array3d(dtype=accumulate_dtype),
604
+ triplet_values: wp.array3d(dtype=output_dtype),
628
605
  ):
629
- test_local_node_index = wp.tid()
606
+ test_local_node_index, trial_node, test_dof, trial_dof = wp.tid()
630
607
 
631
608
  element_count = test.space_restriction.node_element_count(test_arg, test_local_node_index)
632
609
  test_node_index = test.space_restriction.node_partition_index(test_arg, test_local_node_index)
633
610
 
611
+ trial_dof_index = DofIndex(trial_node, trial_dof)
612
+
634
613
  for element in range(element_count):
635
614
  test_element_index = test.space_restriction.node_element_index(test_arg, test_local_node_index, element)
636
615
  element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
637
616
  qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
638
617
 
639
- start_offset = (row_offsets[test_node_index] + element) * NODES_PER_ELEMENT
618
+ test_dof_index = DofIndex(
619
+ test_element_index.node_index_in_element,
620
+ test_dof,
621
+ )
622
+
623
+ val_sum = accumulate_dtype(0.0)
640
624
 
641
625
  for k in range(qp_point_count):
642
626
  qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
@@ -645,42 +629,28 @@ def get_integrate_bilinear_kernel(
645
629
  qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
646
630
  vol = domain.element_measure(domain_arg, make_free_sample(element_index, coords))
647
631
 
648
- offset_cur = start_offset
649
-
650
- for trial_n in range(NODES_PER_ELEMENT):
651
- for i in range(test.space.VALUE_DOF_COUNT):
652
- for j in range(trial.space.VALUE_DOF_COUNT):
653
- test_dof_index = DofIndex(
654
- test_element_index.node_index_in_element,
655
- i,
656
- )
657
- trial_dof_index = DofIndex(trial_n, j)
658
- sample = Sample(
659
- element_index,
660
- coords,
661
- qp_index,
662
- qp_weight,
663
- test_dof_index,
664
- trial_dof_index,
665
- )
666
- val = integrand_func(sample, fields, values)
667
- triplet_values[offset_cur, i, j] = triplet_values[offset_cur, i, j] + accumulate_dtype(
668
- qp_weight * vol * val
669
- )
670
-
671
- offset_cur += 1
672
-
673
- # Set column indices
674
- offset_cur = start_offset
675
- for trial_n in range(NODES_PER_ELEMENT):
632
+ sample = Sample(
633
+ element_index,
634
+ coords,
635
+ qp_index,
636
+ qp_weight,
637
+ test_dof_index,
638
+ trial_dof_index,
639
+ )
640
+ val = integrand_func(sample, fields, values)
641
+ val_sum += accumulate_dtype(qp_weight * vol * val)
642
+
643
+ block_offset = (row_offsets[test_node_index] + element) * NODES_PER_ELEMENT + trial_node
644
+ triplet_values[block_offset, test_dof, trial_dof] = output_dtype(val_sum)
645
+
646
+ # Set row and column indices
647
+ if test_dof == 0 and trial_dof == 0:
676
648
  trial_node_index = trial.space_partition.partition_node_index(
677
649
  trial_partition_arg,
678
- trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_n),
650
+ trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
679
651
  )
680
-
681
- triplet_rows[offset_cur] = test_node_index
682
- triplet_cols[offset_cur] = trial_node_index
683
- offset_cur += 1
652
+ triplet_rows[block_offset] = test_node_index
653
+ triplet_cols[block_offset] = trial_node_index
684
654
 
685
655
  return integrate_kernel_fn
686
656
 
@@ -691,6 +661,7 @@ def get_integrate_bilinear_nodal_kernel(
691
661
  FieldStruct: wp.codegen.Struct,
692
662
  ValueStruct: wp.codegen.Struct,
693
663
  test: TestField,
664
+ output_dtype,
694
665
  accumulate_dtype,
695
666
  ):
696
667
  def integrate_kernel_fn(
@@ -701,7 +672,7 @@ def get_integrate_bilinear_nodal_kernel(
701
672
  values: ValueStruct,
702
673
  triplet_rows: wp.array(dtype=int),
703
674
  triplet_cols: wp.array(dtype=int),
704
- triplet_values: wp.array3d(dtype=accumulate_dtype),
675
+ triplet_values: wp.array3d(dtype=output_dtype),
705
676
  ):
706
677
  local_node_index, test_dof, trial_dof = wp.tid()
707
678
 
@@ -729,7 +700,6 @@ def get_integrate_bilinear_nodal_kernel(
729
700
  node_element_index.node_index_in_element,
730
701
  )
731
702
 
732
-
733
703
  test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)
734
704
  trial_dof_index = DofIndex(node_element_index.node_index_in_element, trial_dof)
735
705
 
@@ -746,7 +716,7 @@ def get_integrate_bilinear_nodal_kernel(
746
716
 
747
717
  val_sum += accumulate_dtype(node_weight * vol * val)
748
718
 
749
- triplet_values[local_node_index, test_dof, trial_dof] = val_sum
719
+ triplet_values[local_node_index, test_dof, trial_dof] = output_dtype(val_sum)
750
720
  triplet_rows[local_node_index] = node_index
751
721
  triplet_cols[local_node_index] = node_index
752
722
 
@@ -763,9 +733,12 @@ def _generate_integrate_kernel(
763
733
  trial: Optional[TrialField],
764
734
  trial_name: str,
765
735
  fields: Dict[str, FieldLike],
736
+ output_dtype: type,
766
737
  accumulate_dtype: type,
767
738
  kernel_options: Dict[str, Any] = {},
768
739
  ) -> wp.Kernel:
740
+ output_dtype = wp.types.type_scalar_type(output_dtype)
741
+
769
742
  # Extract field arguments from integrand
770
743
  field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
771
744
  integrand, fields=fields, domain=domain
@@ -775,7 +748,7 @@ def _generate_integrate_kernel(
775
748
  ValueStruct = _gen_value_struct(value_args)
776
749
 
777
750
  # Check if kernel exist in cache
778
- kernel_suffix = f"_itg_{wp.types.type_typestr(accumulate_dtype)}_{domain.name}_{FieldStruct.key}"
751
+ kernel_suffix = f"_itg_{wp.types.type_typestr(output_dtype)}{wp.types.type_typestr(accumulate_dtype)}_{domain.name}_{FieldStruct.key}"
779
752
  if nodal:
780
753
  kernel_suffix += "_nodal"
781
754
  else:
@@ -819,6 +792,7 @@ def _generate_integrate_kernel(
819
792
  FieldStruct,
820
793
  ValueStruct,
821
794
  test=test,
795
+ output_dtype=output_dtype,
822
796
  accumulate_dtype=accumulate_dtype,
823
797
  )
824
798
  else:
@@ -829,6 +803,7 @@ def _generate_integrate_kernel(
829
803
  FieldStruct,
830
804
  ValueStruct,
831
805
  test=test,
806
+ output_dtype=output_dtype,
832
807
  accumulate_dtype=accumulate_dtype,
833
808
  )
834
809
  else:
@@ -839,6 +814,7 @@ def _generate_integrate_kernel(
839
814
  FieldStruct,
840
815
  ValueStruct,
841
816
  test=test,
817
+ output_dtype=output_dtype,
842
818
  accumulate_dtype=accumulate_dtype,
843
819
  )
844
820
  else:
@@ -850,6 +826,7 @@ def _generate_integrate_kernel(
850
826
  ValueStruct,
851
827
  test=test,
852
828
  trial=trial,
829
+ output_dtype=output_dtype,
853
830
  accumulate_dtype=accumulate_dtype,
854
831
  )
855
832
 
@@ -949,32 +926,46 @@ def _launch_integrate_kernel(
949
926
 
950
927
  # Linear form
951
928
  if trial is None:
952
- if test.space.VALUE_DOF_COUNT == 1:
953
- accumulate_array_dtype = accumulate_dtype
954
- else:
955
- accumulate_array_dtype = cache.cached_vec_type(length=test.space.VALUE_DOF_COUNT, dtype=accumulate_dtype)
956
-
957
- if output is not None and output.size < test.space_partition.node_count():
958
- raise RuntimeError(f"Output array must be of size at least {test.space_partition.node_count()}")
959
-
960
- accumulate_in_place = wp.types.types_equal(accumulate_array_dtype, output_dtype)
961
-
962
929
  # If an output array is provided with the correct type, accumulate directly into it
963
930
  # Otherwise, grab a temporary array
964
- if output is not None and accumulate_in_place:
965
- accumulate_array = output
966
- else:
967
- accumulate_temporary = cache.borrow_temporary(
931
+ if output is None:
932
+ if type_length(output_dtype) == test.space.VALUE_DOF_COUNT:
933
+ output_shape = (test.space_partition.node_count(),)
934
+ elif type_length(output_dtype) == 1:
935
+ output_shape = (test.space_partition.node_count(), test.space.VALUE_DOF_COUNT)
936
+ else:
937
+ raise RuntimeError(
938
+ f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
939
+ )
940
+
941
+ output_temporary = cache.borrow_temporary(
968
942
  temporary_store=temporary_store,
969
- shape=test.space_partition.node_count(),
970
- dtype=accumulate_array_dtype,
943
+ shape=output_shape,
944
+ dtype=output_dtype,
971
945
  device=device,
972
- requires_grad=output is not None and output.requires_grad,
973
946
  )
974
- accumulate_array = accumulate_temporary.array
947
+
948
+ output = output_temporary.array
949
+
950
+ else:
951
+ output_temporary = None
952
+
953
+ if output.shape[0] < test.space_partition.node_count():
954
+ raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")
955
+
956
+ output_dtype = output.dtype
957
+ if type_length(output_dtype) != test.space.VALUE_DOF_COUNT:
958
+ if type_length(output_dtype) != 1:
959
+ raise RuntimeError(
960
+ f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
961
+ )
962
+ if output.ndim != 2 and output.shape[1] != test.space.VALUE_DOF_COUNT:
963
+ raise RuntimeError(
964
+ f"Incompatible output array shape, last dimension must be of size {test.space.VALUE_DOF_COUNT}"
965
+ )
975
966
 
976
967
  # Launch the integration on the kernel on a 2d scalar view of the actual array
977
- accumulate_array.zero_()
968
+ output.zero_()
978
969
 
979
970
  def as_2d_array(array):
980
971
  return wp.array(
@@ -984,11 +975,11 @@ def _launch_integrate_kernel(
984
975
  owner=False,
985
976
  device=array.device,
986
977
  shape=(test.space_partition.node_count(), test.space.VALUE_DOF_COUNT),
987
- dtype=accumulate_dtype,
978
+ dtype=wp.types.type_scalar_type(output_dtype),
988
979
  grad=None if array.grad is None else as_2d_array(array.grad),
989
980
  )
990
981
 
991
- accumulate_2d_view = as_2d_array(accumulate_array)
982
+ output_view = output if output.ndim == 2 else as_2d_array(output)
992
983
 
993
984
  if nodal:
994
985
  wp.launch(
@@ -1000,14 +991,14 @@ def _launch_integrate_kernel(
1000
991
  test_arg,
1001
992
  field_arg_values,
1002
993
  value_struct_values,
1003
- accumulate_2d_view,
994
+ output_view,
1004
995
  ],
1005
996
  device=device,
1006
997
  )
1007
998
  else:
1008
999
  wp.launch(
1009
1000
  kernel=kernel,
1010
- dim=test.space_restriction.node_count(),
1001
+ dim=(test.space_restriction.node_count(), test.space.VALUE_DOF_COUNT),
1011
1002
  inputs=[
1012
1003
  qp_arg,
1013
1004
  domain_elt_arg,
@@ -1015,47 +1006,23 @@ def _launch_integrate_kernel(
1015
1006
  test_arg,
1016
1007
  field_arg_values,
1017
1008
  value_struct_values,
1018
- accumulate_2d_view,
1009
+ output_view,
1019
1010
  ],
1020
1011
  device=device,
1021
1012
  )
1022
1013
 
1023
- # Accumulate and output dtypes are simular -- return accumulation array directly
1024
- if output == accumulate_array:
1025
- return output
1026
-
1027
- if accumulate_in_place:
1028
- return accumulate_temporary.detach()
1014
+ if output_temporary is not None:
1015
+ return output_temporary.detach()
1029
1016
 
1030
- # Different types -- need to cast result
1031
- if output is not None:
1032
- cast_result = output
1033
- elif type_length(output_dtype) == test.space.VALUE_DOF_COUNT:
1034
- cast_result = wp.empty(
1035
- dtype=output_dtype,
1036
- shape=accumulate_array.shape,
1037
- device=device,
1038
- requires_grad=accumulate_array.requires_grad,
1039
- )
1040
- else:
1041
- cast_result = wp.empty(
1042
- dtype=output_dtype,
1043
- shape=accumulate_2d_view.shape,
1044
- device=device,
1045
- requires_grad=accumulate_array.requires_grad,
1046
- )
1047
-
1048
- array_cast(in_array=accumulate_array, out_array=cast_result)
1049
- accumulate_temporary.release() # Do not wait for garbage collection
1050
- return cast_result
1017
+ return output
1051
1018
 
1052
1019
  # Bilinear form
1053
1020
 
1054
1021
  if test.space.VALUE_DOF_COUNT == 1 and trial.space.VALUE_DOF_COUNT == 1:
1055
- block_type = accumulate_dtype
1022
+ block_type = output_dtype
1056
1023
  else:
1057
1024
  block_type = cache.cached_mat_type(
1058
- shape=(test.space.VALUE_DOF_COUNT, trial.space.VALUE_DOF_COUNT), dtype=accumulate_dtype
1025
+ shape=(test.space.VALUE_DOF_COUNT, trial.space.VALUE_DOF_COUNT), dtype=output_dtype
1059
1026
  )
1060
1027
 
1061
1028
  if nodal:
@@ -1072,7 +1039,7 @@ def _launch_integrate_kernel(
1072
1039
  test.space.VALUE_DOF_COUNT,
1073
1040
  trial.space.VALUE_DOF_COUNT,
1074
1041
  ),
1075
- dtype=accumulate_dtype,
1042
+ dtype=output_dtype,
1076
1043
  device=device,
1077
1044
  )
1078
1045
  triplet_cols = triplet_cols_temp.array
@@ -1105,7 +1072,12 @@ def _launch_integrate_kernel(
1105
1072
  trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
1106
1073
  wp.launch(
1107
1074
  kernel=kernel,
1108
- dim=test.space_restriction.node_count(),
1075
+ dim=(
1076
+ test.space_restriction.node_count(),
1077
+ trial.space.topology.NODES_PER_ELEMENT,
1078
+ test.space.VALUE_DOF_COUNT,
1079
+ trial.space.VALUE_DOF_COUNT,
1080
+ ),
1109
1081
  inputs=[
1110
1082
  qp_arg,
1111
1083
  domain_elt_arg,
@@ -1123,38 +1095,27 @@ def _launch_integrate_kernel(
1123
1095
  device=device,
1124
1096
  )
1125
1097
 
1126
- compress_in_place = accumulate_dtype == output_dtype
1127
-
1128
1098
  if output is not None:
1129
1099
  if output.nrow != test.space_partition.node_count() or output.ncol != trial.space_partition.node_count():
1130
1100
  raise RuntimeError(
1131
1101
  f"Output matrix must have {test.space_partition.node_count()} rows and {trial.space_partition.node_count()} columns of blocks"
1132
1102
  )
1133
1103
 
1134
- if output is not None and compress_in_place:
1135
- bsr_matrix = output
1136
1104
  else:
1137
- bsr_matrix = bsr_zeros(
1105
+ output = bsr_zeros(
1138
1106
  rows_of_blocks=test.space_partition.node_count(),
1139
1107
  cols_of_blocks=trial.space_partition.node_count(),
1140
1108
  block_type=block_type,
1141
1109
  device=device,
1142
1110
  )
1143
1111
 
1144
- bsr_set_from_triplets(bsr_matrix, triplet_rows, triplet_cols, triplet_values)
1112
+ bsr_set_from_triplets(output, triplet_rows, triplet_cols, triplet_values)
1145
1113
 
1146
1114
  # Do not wait for garbage collection
1147
1115
  triplet_values_temp.release()
1148
1116
  triplet_rows_temp.release()
1149
1117
  triplet_cols_temp.release()
1150
1118
 
1151
- if compress_in_place:
1152
- return bsr_matrix
1153
-
1154
- if output is None:
1155
- return bsr_copy(bsr_matrix, scalar_type=output_dtype)
1156
-
1157
- bsr_assign(src=bsr_matrix, dest=output)
1158
1119
  return output
1159
1120
 
1160
1121
 
@@ -1181,7 +1142,7 @@ def integrate(
1181
1142
  quadrature: Quadrature formula. If None, deduced from domain and fields degree.
1182
1143
  nodal: For linear or bilinear form only, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
1183
1144
  fields: Discrete, test, and trial fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names.
1184
- values: Additional variable values to be passed to the integrand, can by of any type accepted by warp kernel launchs. Keys in the dictionary must match integrand parameter names.
1145
+ values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launchs. Keys in the dictionary must match integrand parameter names.
1185
1146
  temporary_store: shared pool from which to allocate temporary arrays
1186
1147
  accumulate_dtype: Scalar type to be used for accumulating integration samples
1187
1148
  output: Sparse matrix or warp array into which to store the result of the integration
@@ -1246,6 +1207,7 @@ def integrate(
1246
1207
  trial_name=trial_name,
1247
1208
  fields=fields,
1248
1209
  accumulate_dtype=accumulate_dtype,
1210
+ output_dtype=output_dtype,
1249
1211
  kernel_options=kernel_options,
1250
1212
  )
1251
1213
 
@@ -1268,7 +1230,7 @@ def integrate(
1268
1230
  )
1269
1231
 
1270
1232
 
1271
- def get_interpolate_kernel(
1233
+ def get_interpolate_to_field_function(
1272
1234
  integrand_func: wp.Function,
1273
1235
  domain: GeometryDomain,
1274
1236
  FieldStruct: wp.codegen.Struct,
@@ -1277,7 +1239,8 @@ def get_interpolate_kernel(
1277
1239
  ):
1278
1240
  value_type = dest.space.dtype
1279
1241
 
1280
- def interpolate_kernel_fn(
1242
+ def interpolate_to_field_fn(
1243
+ local_node_index: int,
1281
1244
  domain_arg: domain.ElementArg,
1282
1245
  domain_index_arg: domain.ElementIndexArg,
1283
1246
  dest_node_arg: dest.space_restriction.NodeArg,
@@ -1285,19 +1248,15 @@ def get_interpolate_kernel(
1285
1248
  fields: FieldStruct,
1286
1249
  values: ValueStruct,
1287
1250
  ):
1288
- local_node_index = wp.tid()
1289
1251
  node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
1290
-
1291
1252
  element_count = dest.space_restriction.node_element_count(dest_node_arg, local_node_index)
1292
- if element_count == 0:
1293
- return
1294
1253
 
1295
1254
  test_dof_index = NULL_DOF_INDEX
1296
1255
  trial_dof_index = NULL_DOF_INDEX
1297
1256
  node_weight = 1.0
1298
1257
 
1299
- # Volume-weighted average accross elements
1300
- # Superfluous if the function is continuous, but we might as well
1258
+ # Volume-weighted average across elements
1259
+ # Superfluous if the interpolated function is continuous, but helpful for visualizing discontinuous spaces
1301
1260
 
1302
1261
  val_sum = value_type(0.0)
1303
1262
  vol_sum = float(0.0)
@@ -1328,15 +1287,112 @@ def get_interpolate_kernel(
1328
1287
  vol_sum += vol
1329
1288
  val_sum += vol * val
1330
1289
 
1290
+ return val_sum, vol_sum
1291
+
1292
+ return interpolate_to_field_fn
1293
+
1294
+
1295
+ def get_interpolate_to_field_kernel(
1296
+ interpolate_to_field_fn: wp.Function,
1297
+ domain: GeometryDomain,
1298
+ FieldStruct: wp.codegen.Struct,
1299
+ ValueStruct: wp.codegen.Struct,
1300
+ dest: FieldRestriction,
1301
+ ):
1302
+ def interpolate_to_field_kernel_fn(
1303
+ domain_arg: domain.ElementArg,
1304
+ domain_index_arg: domain.ElementIndexArg,
1305
+ dest_node_arg: dest.space_restriction.NodeArg,
1306
+ dest_eval_arg: dest.field.EvalArg,
1307
+ fields: FieldStruct,
1308
+ values: ValueStruct,
1309
+ ):
1310
+ local_node_index = wp.tid()
1311
+
1312
+ val_sum, vol_sum = interpolate_to_field_fn(
1313
+ local_node_index, domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, fields, values
1314
+ )
1315
+
1331
1316
  if vol_sum > 0.0:
1317
+ node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
1332
1318
  dest.field.set_node_value(dest_eval_arg, node_index, val_sum / vol_sum)
1333
1319
 
1334
- return interpolate_kernel_fn
1320
+ return interpolate_to_field_kernel_fn
1321
+
1322
+
1323
+ def get_interpolate_to_array_kernel(
1324
+ integrand_func: wp.Function,
1325
+ domain: GeometryDomain,
1326
+ quadrature: Quadrature,
1327
+ FieldStruct: wp.codegen.Struct,
1328
+ ValueStruct: wp.codegen.Struct,
1329
+ value_type: type,
1330
+ ):
1331
+ def interpolate_to_array_kernel_fn(
1332
+ qp_arg: quadrature.Arg,
1333
+ domain_arg: quadrature.domain.ElementArg,
1334
+ domain_index_arg: quadrature.domain.ElementIndexArg,
1335
+ fields: FieldStruct,
1336
+ values: ValueStruct,
1337
+ result: wp.array(dtype=value_type),
1338
+ ):
1339
+ element_index = domain.element_index(domain_index_arg, wp.tid())
1340
+
1341
+ test_dof_index = NULL_DOF_INDEX
1342
+ trial_dof_index = NULL_DOF_INDEX
1343
+
1344
+ qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
1345
+ for k in range(qp_point_count):
1346
+ qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
1347
+ coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
1348
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
1349
+
1350
+ sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1351
+
1352
+ result[qp_index] = integrand_func(sample, fields, values)
1353
+
1354
+ return interpolate_to_array_kernel_fn
1355
+
1356
+
1357
+ def get_interpolate_nonvalued_kernel(
1358
+ integrand_func: wp.Function,
1359
+ domain: GeometryDomain,
1360
+ quadrature: Quadrature,
1361
+ FieldStruct: wp.codegen.Struct,
1362
+ ValueStruct: wp.codegen.Struct,
1363
+ ):
1364
+ def interpolate_nonvalued_kernel_fn(
1365
+ qp_arg: quadrature.Arg,
1366
+ domain_arg: quadrature.domain.ElementArg,
1367
+ domain_index_arg: quadrature.domain.ElementIndexArg,
1368
+ fields: FieldStruct,
1369
+ values: ValueStruct,
1370
+ ):
1371
+ element_index = domain.element_index(domain_index_arg, wp.tid())
1335
1372
 
1373
+ test_dof_index = NULL_DOF_INDEX
1374
+ trial_dof_index = NULL_DOF_INDEX
1336
1375
 
1337
- def _generate_interpolate_kernel(integrand: Integrand, dest: FieldLike, fields: Dict[str, FieldLike]) -> wp.Kernel:
1338
- domain = dest.domain
1376
+ qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
1377
+ for k in range(qp_point_count):
1378
+ qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
1379
+ coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
1380
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
1339
1381
 
1382
+ sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1383
+ integrand_func(sample, fields, values)
1384
+
1385
+ return interpolate_nonvalued_kernel_fn
1386
+
1387
+
1388
+ def _generate_interpolate_kernel(
1389
+ integrand: Integrand,
1390
+ domain: GeometryDomain,
1391
+ dest: Optional[Union[FieldLike, wp.array]],
1392
+ quadrature: Optional[Quadrature],
1393
+ fields: Dict[str, FieldLike],
1394
+ kernel_options: Dict[str, Any] = {},
1395
+ ) -> wp.Kernel:
1340
1396
  # Extract field arguments from integrand
1341
1397
  field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
1342
1398
  integrand, fields=fields, domain=domain
@@ -1354,9 +1410,14 @@ def _generate_interpolate_kernel(integrand: Integrand, dest: FieldLike, fields:
1354
1410
  ValueStruct = _gen_value_struct(value_args)
1355
1411
 
1356
1412
  # Check if kernel exists in cache
1357
- kernel_suffix = (
1358
- f"_itp_{FieldStruct.key}_{dest.domain.name}_{dest.space_restriction.space_partition.name}_{dest.space.name}"
1359
- )
1413
+ if isinstance(dest, FieldRestriction):
1414
+ kernel_suffix = (
1415
+ f"_itp_{FieldStruct.key}_{dest.domain.name}_{dest.space_restriction.space_partition.name}_{dest.space.name}"
1416
+ )
1417
+ elif wp.types.is_array(dest):
1418
+ kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}_{wp.types.type_repr(dest.dtype)}"
1419
+ else:
1420
+ kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}"
1360
1421
 
1361
1422
  kernel = cache.get_integrand_kernel(
1362
1423
  integrand=integrand,
@@ -1366,18 +1427,61 @@ def _generate_interpolate_kernel(integrand: Integrand, dest: FieldLike, fields:
1366
1427
  return kernel, FieldStruct, ValueStruct
1367
1428
 
1368
1429
  # Generate interpolation kernel
1369
- interpolate_kernel_fn = get_interpolate_kernel(
1370
- integrand_func,
1371
- domain,
1372
- dest=dest,
1373
- FieldStruct=FieldStruct,
1374
- ValueStruct=ValueStruct,
1375
- )
1430
+ if isinstance(dest, FieldRestriction):
1431
+ # need to split into kernel + function for differentiability
1432
+ interpolate_fn = get_interpolate_to_field_function(
1433
+ integrand_func,
1434
+ domain,
1435
+ dest=dest,
1436
+ FieldStruct=FieldStruct,
1437
+ ValueStruct=ValueStruct,
1438
+ )
1439
+
1440
+ interpolate_fn = cache.get_integrand_function(
1441
+ integrand=integrand,
1442
+ func=interpolate_fn,
1443
+ suffix=kernel_suffix,
1444
+ code_transformers=[
1445
+ PassFieldArgsToIntegrand(
1446
+ arg_names=integrand.argspec.args,
1447
+ field_args=field_args.keys(),
1448
+ value_args=value_args.keys(),
1449
+ sample_name=sample_name,
1450
+ domain_name=domain_name,
1451
+ )
1452
+ ],
1453
+ )
1454
+
1455
+ interpolate_kernel_fn = get_interpolate_to_field_kernel(
1456
+ interpolate_fn,
1457
+ domain,
1458
+ dest=dest,
1459
+ FieldStruct=FieldStruct,
1460
+ ValueStruct=ValueStruct,
1461
+ )
1462
+ elif wp.types.is_array(dest):
1463
+ interpolate_kernel_fn = get_interpolate_to_array_kernel(
1464
+ integrand_func,
1465
+ domain=domain,
1466
+ quadrature=quadrature,
1467
+ value_type=dest.dtype,
1468
+ FieldStruct=FieldStruct,
1469
+ ValueStruct=ValueStruct,
1470
+ )
1471
+ else:
1472
+ interpolate_kernel_fn = get_interpolate_nonvalued_kernel(
1473
+ integrand_func,
1474
+ domain=domain,
1475
+ quadrature=quadrature,
1476
+ FieldStruct=FieldStruct,
1477
+ ValueStruct=ValueStruct,
1478
+ )
1376
1479
 
1377
1480
  kernel = cache.get_integrand_kernel(
1378
1481
  integrand=integrand,
1379
1482
  kernel_fn=interpolate_kernel_fn,
1380
1483
  suffix=kernel_suffix,
1484
+ kernel_options=kernel_options,
1381
1485
  code_transformers=[
1382
1486
  PassFieldArgsToIntegrand(
1383
1487
  arg_names=integrand.argspec.args,
@@ -1396,16 +1500,16 @@ def _launch_interpolate_kernel(
1396
1500
  kernel: wp.kernel,
1397
1501
  FieldStruct: wp.codegen.Struct,
1398
1502
  ValueStruct: wp.codegen.Struct,
1399
- dest: FieldLike,
1503
+ domain: GeometryDomain,
1504
+ dest: Optional[Union[FieldRestriction, wp.array]],
1505
+ quadrature: Optional[Quadrature],
1400
1506
  fields: Dict[str, FieldLike],
1401
1507
  values: Dict[str, Any],
1402
1508
  device,
1403
1509
  ) -> wp.Kernel:
1404
1510
  # Set-up launch arguments
1405
- elt_arg = dest.domain.element_arg_value(device=device)
1406
- elt_index_arg = dest.domain.element_index_arg_value(device=device)
1407
- dest_node_arg = dest.space_restriction.node_arg(device=device)
1408
- dest_eval_arg = dest.field.eval_arg_value(device=device)
1511
+ elt_arg = domain.element_arg_value(device=device)
1512
+ elt_index_arg = domain.element_index_arg_value(device=device)
1409
1513
 
1410
1514
  field_arg_values = FieldStruct()
1411
1515
  for k, v in fields.items():
@@ -1415,37 +1519,65 @@ def _launch_interpolate_kernel(
1415
1519
  for k, v in values.items():
1416
1520
  setattr(value_struct_values, k, v)
1417
1521
 
1418
- wp.launch(
1419
- kernel=kernel,
1420
- dim=dest.space_restriction.node_count(),
1421
- inputs=[
1422
- elt_arg,
1423
- elt_index_arg,
1424
- dest_node_arg,
1425
- dest_eval_arg,
1426
- field_arg_values,
1427
- value_struct_values,
1428
- ],
1429
- device=device,
1430
- )
1522
+ if isinstance(dest, FieldRestriction):
1523
+ dest_node_arg = dest.space_restriction.node_arg(device=device)
1524
+ dest_eval_arg = dest.field.eval_arg_value(device=device)
1525
+
1526
+ wp.launch(
1527
+ kernel=kernel,
1528
+ dim=dest.space_restriction.node_count(),
1529
+ inputs=[
1530
+ elt_arg,
1531
+ elt_index_arg,
1532
+ dest_node_arg,
1533
+ dest_eval_arg,
1534
+ field_arg_values,
1535
+ value_struct_values,
1536
+ ],
1537
+ device=device,
1538
+ )
1539
+ elif wp.types.is_array(dest):
1540
+ qp_arg = quadrature.arg_value(device)
1541
+ wp.launch(
1542
+ kernel=kernel,
1543
+ dim=domain.element_count(),
1544
+ inputs=[qp_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
1545
+ device=device,
1546
+ )
1547
+ else:
1548
+ qp_arg = quadrature.arg_value(device)
1549
+ wp.launch(
1550
+ kernel=kernel,
1551
+ dim=domain.element_count(),
1552
+ inputs=[qp_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values],
1553
+ device=device,
1554
+ )
1431
1555
 
1432
1556
 
1433
1557
  def interpolate(
1434
1558
  integrand: Integrand,
1435
- dest: Union[DiscreteField, FieldRestriction],
1559
+ dest: Optional[Union[DiscreteField, FieldRestriction, wp.array]] = None,
1560
+ quadrature: Optional[Quadrature] = None,
1436
1561
  fields: Dict[str, FieldLike] = {},
1437
1562
  values: Dict[str, Any] = {},
1438
1563
  device=None,
1564
+ kernel_options: Dict[str, Any] = {},
1439
1565
  ):
1440
1566
  """
1441
- Interpolates a function and assigns the result to a discrete field.
1567
+ Interpolates a function at a finite set of sample points and optionally assigns the result to a discrete field or a raw warp array.
1442
1568
 
1443
1569
  Args:
1444
1570
  integrand: Function to be interpolated, must have :func:`integrand` decorator
1445
- dest: Discrete field, or restriction of a discrete field to a domain, to which the interpolation result will be assigned
1571
+ dest: Where to store the interpolation result. Can be either
1572
+
1573
+ - a :class:`DiscreteField`, or restriction of a discrete field to a domain (from :func:`make_restriction`). In this case, interpolation will be performed at each node.
1574
+ - a normal warp array. In this case, the `quadrature` argument defining the interpolation locations must be provided and the result of the `integrand` at each quadrature point will be assigned to the array.
1575
+ - ``None``. In this case, the `quadrature` argument must also be provided and the `integrand` function is responsible for dealing with the interpolation result.
1576
+ quadrature: Quadrature formula defining the interpolation samples if `dest` is not a discrete field or field restriction.
1446
1577
  fields: Discrete fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names.
1447
- values: Additional variable values to be passed to the integrand, can by of any type accepted by warp kernel launchs. Keys in the dictionary must match integrand parameter names.
1578
+ values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
1448
1579
  device: Device on which to perform the interpolation
1580
+ kernel_options: Overloaded options to be passed to the kernel builder (e.g., ``{"enable_backward": True}``)
1449
1581
  """
1450
1582
  if not isinstance(integrand, Integrand):
1451
1583
  raise ValueError("integrand must be tagged with @integrand decorator")
@@ -1454,20 +1586,33 @@ def interpolate(
1454
1586
  if test is not None or trial is not None:
1455
1587
  raise ValueError("Test or Trial fields should not be used for interpolation")
1456
1588
 
1457
- if not isinstance(dest, FieldRestriction):
1589
+ if isinstance(dest, DiscreteField):
1458
1590
  dest = make_restriction(dest)
1459
1591
 
1592
+ if isinstance(dest, FieldRestriction):
1593
+ domain = dest.domain
1594
+ else:
1595
+ if quadrature is None:
1596
+ raise ValueError("When not interpolating to a field, a quadrature formula must be provided")
1597
+
1598
+ domain = quadrature.domain
1599
+
1460
1600
  kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
1461
1601
  integrand=integrand,
1602
+ domain=domain,
1462
1603
  dest=dest,
1604
+ quadrature=quadrature,
1463
1605
  fields=fields,
1606
+ kernel_options=kernel_options,
1464
1607
  )
1465
1608
 
1466
1609
  return _launch_interpolate_kernel(
1467
1610
  kernel=kernel,
1468
1611
  FieldStruct=FieldStruct,
1469
1612
  ValueStruct=ValueStruct,
1613
+ domain=domain,
1470
1614
  dest=dest,
1615
+ quadrature=quadrature,
1471
1616
  fields=fields,
1472
1617
  values=values,
1473
1618
  device=device,