warp-lang 1.8.0__py3-none-win_amd64.whl → 1.9.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package registry's advisory page for more details.

Files changed (153)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +48 -63
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +327 -209
  9. warp/config.py +1 -1
  10. warp/context.py +1363 -800
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_callable.py +34 -4
  18. warp/examples/interop/example_jax_kernel.py +27 -1
  19. warp/fabric.py +1 -1
  20. warp/fem/cache.py +27 -19
  21. warp/fem/domain.py +2 -2
  22. warp/fem/field/nodal_field.py +2 -2
  23. warp/fem/field/virtual.py +266 -166
  24. warp/fem/geometry/geometry.py +5 -5
  25. warp/fem/integrate.py +200 -91
  26. warp/fem/space/restriction.py +4 -0
  27. warp/fem/space/shape/tet_shape_function.py +3 -10
  28. warp/jax_experimental/custom_call.py +1 -1
  29. warp/jax_experimental/ffi.py +203 -54
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +103 -8
  32. warp/native/builtin.h +90 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +13 -3
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +42 -11
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +4 -4
  48. warp/native/mat.h +1913 -119
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +5 -3
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +337 -16
  59. warp/native/rand.h +7 -7
  60. warp/native/range.h +7 -1
  61. warp/native/reduce.cpp +10 -10
  62. warp/native/reduce.cu +13 -14
  63. warp/native/runlength_encode.cpp +2 -2
  64. warp/native/runlength_encode.cu +5 -5
  65. warp/native/scan.cpp +3 -3
  66. warp/native/scan.cu +4 -4
  67. warp/native/sort.cpp +10 -10
  68. warp/native/sort.cu +22 -22
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +14 -14
  71. warp/native/spatial.h +366 -17
  72. warp/native/svd.h +23 -8
  73. warp/native/temp_buffer.h +2 -2
  74. warp/native/tile.h +303 -70
  75. warp/native/tile_radix_sort.h +5 -1
  76. warp/native/tile_reduce.h +16 -25
  77. warp/native/tuple.h +2 -2
  78. warp/native/vec.h +385 -18
  79. warp/native/volume.cpp +54 -54
  80. warp/native/volume.cu +1 -1
  81. warp/native/volume.h +2 -1
  82. warp/native/volume_builder.cu +30 -37
  83. warp/native/warp.cpp +150 -149
  84. warp/native/warp.cu +337 -193
  85. warp/native/warp.h +227 -226
  86. warp/optim/linear.py +736 -271
  87. warp/render/imgui_manager.py +289 -0
  88. warp/render/render_opengl.py +137 -57
  89. warp/render/render_usd.py +0 -1
  90. warp/sim/collide.py +1 -2
  91. warp/sim/graph_coloring.py +2 -2
  92. warp/sim/integrator_vbd.py +10 -2
  93. warp/sparse.py +559 -176
  94. warp/tape.py +2 -0
  95. warp/tests/aux_test_module_aot.py +7 -0
  96. warp/tests/cuda/test_async.py +3 -3
  97. warp/tests/cuda/test_conditional_captures.py +101 -0
  98. warp/tests/geometry/test_marching_cubes.py +233 -12
  99. warp/tests/sim/test_cloth.py +89 -6
  100. warp/tests/sim/test_coloring.py +82 -7
  101. warp/tests/test_array.py +56 -5
  102. warp/tests/test_assert.py +53 -0
  103. warp/tests/test_atomic_cas.py +127 -114
  104. warp/tests/test_codegen.py +3 -2
  105. warp/tests/test_context.py +8 -15
  106. warp/tests/test_enum.py +136 -0
  107. warp/tests/test_examples.py +2 -2
  108. warp/tests/test_fem.py +45 -2
  109. warp/tests/test_fixedarray.py +229 -0
  110. warp/tests/test_func.py +18 -15
  111. warp/tests/test_future_annotations.py +7 -5
  112. warp/tests/test_linear_solvers.py +30 -0
  113. warp/tests/test_map.py +1 -1
  114. warp/tests/test_mat.py +1540 -378
  115. warp/tests/test_mat_assign_copy.py +178 -0
  116. warp/tests/test_mat_constructors.py +574 -0
  117. warp/tests/test_module_aot.py +287 -0
  118. warp/tests/test_print.py +69 -0
  119. warp/tests/test_quat.py +162 -34
  120. warp/tests/test_quat_assign_copy.py +145 -0
  121. warp/tests/test_reload.py +2 -1
  122. warp/tests/test_sparse.py +103 -0
  123. warp/tests/test_spatial.py +140 -34
  124. warp/tests/test_spatial_assign_copy.py +160 -0
  125. warp/tests/test_static.py +48 -0
  126. warp/tests/test_struct.py +43 -3
  127. warp/tests/test_tape.py +38 -0
  128. warp/tests/test_types.py +0 -20
  129. warp/tests/test_vec.py +216 -441
  130. warp/tests/test_vec_assign_copy.py +143 -0
  131. warp/tests/test_vec_constructors.py +325 -0
  132. warp/tests/tile/test_tile.py +206 -152
  133. warp/tests/tile/test_tile_cholesky.py +605 -0
  134. warp/tests/tile/test_tile_load.py +169 -0
  135. warp/tests/tile/test_tile_mathdx.py +2 -558
  136. warp/tests/tile/test_tile_matmul.py +179 -0
  137. warp/tests/tile/test_tile_mlp.py +1 -1
  138. warp/tests/tile/test_tile_reduce.py +100 -11
  139. warp/tests/tile/test_tile_shared_memory.py +16 -16
  140. warp/tests/tile/test_tile_sort.py +59 -55
  141. warp/tests/unittest_suites.py +16 -0
  142. warp/tests/walkthrough_debug.py +1 -1
  143. warp/thirdparty/unittest_parallel.py +108 -9
  144. warp/types.py +554 -264
  145. warp/utils.py +68 -86
  146. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  147. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
  148. warp/native/marching.cpp +0 -19
  149. warp/native/marching.cu +0 -514
  150. warp/native/marching.h +0 -19
  151. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  152. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  153. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py CHANGED
@@ -16,7 +16,7 @@
16
16
  import ast
17
17
  import inspect
18
18
  import textwrap
19
- from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union
19
+ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union
20
20
 
21
21
  import warp as wp
22
22
  import warp.fem.operator as operator
@@ -34,7 +34,10 @@ from warp.fem.field import (
34
34
  TrialField,
35
35
  make_restriction,
36
36
  )
37
- from warp.fem.field.virtual import make_bilinear_dispatch_kernel, make_linear_dispatch_kernel
37
+ from warp.fem.field.virtual import (
38
+ make_bilinear_dispatch_kernel,
39
+ make_linear_dispatch_kernel,
40
+ )
38
41
  from warp.fem.linalg import array_axpy, basis_coefficient
39
42
  from warp.fem.operator import (
40
43
  Integrand,
@@ -56,7 +59,7 @@ from warp.fem.types import (
56
59
  )
57
60
  from warp.fem.utils import type_zero_element
58
61
  from warp.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
59
- from warp.types import type_size
62
+ from warp.types import is_array, type_size
60
63
  from warp.utils import array_cast
61
64
 
62
65
 
@@ -101,7 +104,8 @@ class IntegrandVisitor(ast.NodeTransformer):
101
104
  field: FieldLike
102
105
  abstract_type: type
103
106
  concrete_type: type
104
- root_arg_name: type
107
+ root_arg_name: str
108
+ local_arg_name: str
105
109
 
106
110
  def __init__(
107
111
  self,
@@ -111,6 +115,7 @@ class IntegrandVisitor(ast.NodeTransformer):
111
115
  self._integrand = integrand
112
116
  self._field_symbols = field_info.copy()
113
117
  self._field_nodes = {}
118
+ self._field_arg_annotation_nodes = {}
114
119
 
115
120
  @staticmethod
116
121
  def _build_field_info(integrand: Integrand, field_args: Dict[str, FieldLike]):
@@ -127,6 +132,7 @@ class IntegrandVisitor(ast.NodeTransformer):
127
132
  abstract_type=integrand.argspec.annotations[name],
128
133
  concrete_type=get_concrete_type(field),
129
134
  root_arg_name=name,
135
+ local_arg_name=name,
130
136
  )
131
137
  for name, field in field_args.items()
132
138
  }
@@ -167,6 +173,7 @@ class IntegrandVisitor(ast.NodeTransformer):
167
173
  field=res[0],
168
174
  abstract_type=res[1],
169
175
  concrete_type=res[2],
176
+ local_arg_name=field_info.local_arg_name,
170
177
  root_arg_name=f"{field_info.root_arg_name}.{func.name}",
171
178
  )
172
179
 
@@ -191,6 +198,13 @@ class IntegrandVisitor(ast.NodeTransformer):
191
198
 
192
199
  return node
193
200
 
201
+ def visit_FunctionDef(self, node: ast.FunctionDef):
202
+ # record field arg annotation nodes
203
+ for arg in node.args.args:
204
+ self._field_arg_annotation_nodes[arg.arg] = arg.annotation
205
+
206
+ return self.generic_visit(node)
207
+
194
208
  def _get_callee_field_args(self, callee: Integrand, args: List[ast.AST]):
195
209
  # Get field types for call site arguments
196
210
  call_site_field_args: List[IntegrandVisitor.FieldInfo] = []
@@ -211,7 +225,13 @@ class IntegrandVisitor(ast.NodeTransformer):
211
225
  raise TypeError(
212
226
  f"Attempting to pass a {passed_field_info.abstract_type.__name__} to argument '{arg}' of '{callee.name}' expecting a {arg_type.__name__}"
213
227
  )
214
- callee_field_args[arg] = passed_field_info
228
+ callee_field_args[arg] = IntegrandVisitor.FieldInfo(
229
+ field=passed_field_info.field,
230
+ abstract_type=passed_field_info.abstract_type,
231
+ concrete_type=passed_field_info.concrete_type,
232
+ local_arg_name=arg,
233
+ root_arg_name=passed_field_info.root_arg_name,
234
+ )
215
235
 
216
236
  return callee_field_args
217
237
 
@@ -263,18 +283,14 @@ class IntegrandTransformer(IntegrandVisitor):
263
283
  f"Operator {operator.func.__name__} is not defined for {field_info.abstract_type.__name__} {field.name}"
264
284
  ) from e
265
285
 
266
- # Update the ast Call node to use the new function pointer
267
- call.func = ast.Attribute(value=call.func, attr=pointer.key, ctx=ast.Load())
268
-
269
286
  # Save the pointer as an attribute than can be accessed from the calling scope
270
- # For usual operator call syntax, we can use the operator itself, but for the
271
- # shortcut default operator syntax, we store it on the callee's concrete type
272
- if isinstance(callee, Operator):
273
- setattr(callee, pointer.key, pointer)
274
- else:
275
- setattr(field_info.concrete_type, pointer.key, pointer)
287
+ # (use the annotation node of the argument this field is constructed from)
288
+ callee_node = self._field_arg_annotation_nodes[field_info.local_arg_name]
289
+ setattr(self._field_symbols[field_info.local_arg_name].abstract_type, pointer.key, pointer)
290
+ call.func = ast.Attribute(value=callee_node, attr=pointer.key, ctx=ast.Load())
276
291
 
277
- # also insert callee as first argument
292
+ # For shortcut default operator syntax, insert callee as first argument
293
+ if not isinstance(callee, Operator):
278
294
  call.args = [ast.Name(id=callee, ctx=ast.Load()), *call.args]
279
295
 
280
296
  # replace first argument with selected attribute
@@ -592,6 +608,9 @@ def _combined_kernel_options(integrand_options: Optional[Dict[str, Any]], call_s
592
608
  return options
593
609
 
594
610
 
611
+ _INTEGRATE_CONSTANT_TILE_SIZE = 256
612
+
613
+
595
614
  def get_integrate_constant_kernel(
596
615
  integrand_func: wp.Function,
597
616
  domain: GeometryDomain,
@@ -599,8 +618,12 @@ def get_integrate_constant_kernel(
599
618
  FieldStruct: wp.codegen.Struct,
600
619
  ValueStruct: wp.codegen.Struct,
601
620
  accumulate_dtype,
621
+ tile_size: int = _INTEGRATE_CONSTANT_TILE_SIZE,
602
622
  ):
623
+ zero_element = type_zero_element(accumulate_dtype)
624
+
603
625
  def integrate_kernel_fn(
626
+ qp_count: int,
604
627
  qp_arg: quadrature.Arg,
605
628
  qp_element_index_arg: quadrature.ElementIndexArg,
606
629
  domain_arg: domain.ElementArg,
@@ -609,26 +632,33 @@ def get_integrate_constant_kernel(
609
632
  values: ValueStruct,
610
633
  result: wp.array(dtype=accumulate_dtype),
611
634
  ):
612
- qp_eval_index = wp.tid()
613
- domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
614
- if domain_element_index == NULL_ELEMENT_INDEX:
615
- return
635
+ block_index, lane = wp.tid()
636
+ qp_eval_index = block_index * tile_size + lane
616
637
 
617
- element_index = domain.element_index(domain_index_arg, domain_element_index)
638
+ if qp_eval_index >= qp_count:
639
+ domain_element_index, qp = NULL_ELEMENT_INDEX, 0
640
+ else:
641
+ domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
618
642
 
619
- qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
620
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
621
- qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
643
+ if domain_element_index == NULL_ELEMENT_INDEX:
644
+ val = zero_element()
645
+ else:
646
+ element_index = domain.element_index(domain_index_arg, domain_element_index)
622
647
 
623
- test_dof_index = NULL_DOF_INDEX
624
- trial_dof_index = NULL_DOF_INDEX
648
+ qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
649
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
650
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
625
651
 
626
- sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
627
- vol = domain.element_measure(domain_arg, sample)
652
+ test_dof_index = NULL_DOF_INDEX
653
+ trial_dof_index = NULL_DOF_INDEX
628
654
 
629
- val = integrand_func(sample, fields, values)
655
+ sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
656
+ vol = domain.element_measure(domain_arg, sample)
657
+
658
+ val = accumulate_dtype(qp_weight * vol * integrand_func(sample, fields, values))
630
659
 
631
- wp.atomic_add(result, 0, accumulate_dtype(qp_weight * vol * val))
660
+ tile_integral = wp.tile_sum(wp.tile(val))
661
+ wp.tile_atomic_add(result, tile_integral, offset=0)
632
662
 
633
663
  return integrate_kernel_fn
634
664
 
@@ -1020,7 +1050,7 @@ def get_integrate_bilinear_local_kernel(
1020
1050
 
1021
1051
  sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1022
1052
  val = integrand_func(sample, fields, values)
1023
- result[qp_eval_index, test_dof, trial_dof, taylor_dof] = qp_vol * val
1053
+ result[test_dof, trial_dof, qp_eval_index, taylor_dof] = qp_vol * val
1024
1054
 
1025
1055
  return integrate_kernel_fn
1026
1056
 
@@ -1150,9 +1180,46 @@ def _generate_integrate_kernel(
1150
1180
  return kernel, FieldStruct, ValueStruct
1151
1181
 
1152
1182
 
1183
+ def _generate_auxiliary_kernels(
1184
+ quadrature: Quadrature,
1185
+ test: Optional[TestField],
1186
+ trial: Optional[TrialField],
1187
+ accumulate_dtype: type,
1188
+ device,
1189
+ kernel_options: Optional[Dict[str, Any]] = None,
1190
+ ) -> List[Tuple[wp.Kernel, int]]:
1191
+ if test is None or not isinstance(test, LocalTestField):
1192
+ return ()
1193
+
1194
+ # For dispatched assembly, generate additional kernels
1195
+ # heuristic to use tiles for "long" quadratures
1196
+ dispatch_tile_size = 32
1197
+ qp_eval_count = quadrature.evaluation_point_count()
1198
+
1199
+ if trial is None:
1200
+ if (
1201
+ not device.is_cuda
1202
+ or qp_eval_count * test.space_restriction.total_node_element_count()
1203
+ < 3 * dispatch_tile_size * test.space_restriction.node_count() * test.domain.element_count()
1204
+ ):
1205
+ dispatch_tile_size = 1
1206
+ dispatch_kernel = make_linear_dispatch_kernel(
1207
+ test, quadrature, accumulate_dtype, dispatch_tile_size, kernel_options
1208
+ )
1209
+ else:
1210
+ if not device.is_cuda or qp_eval_count < 3 * dispatch_tile_size * test.domain.element_count():
1211
+ dispatch_tile_size = 1
1212
+ dispatch_kernel = make_bilinear_dispatch_kernel(
1213
+ test, trial, quadrature, accumulate_dtype, dispatch_tile_size, kernel_options
1214
+ )
1215
+
1216
+ return ((dispatch_kernel, dispatch_tile_size),)
1217
+
1218
+
1153
1219
  def _launch_integrate_kernel(
1154
1220
  integrand: Integrand,
1155
1221
  kernel: wp.Kernel,
1222
+ auxiliary_kernels: List[Tuple[wp.Kernel, int]],
1156
1223
  FieldStruct: wp.codegen.Struct,
1157
1224
  ValueStruct: wp.codegen.Struct,
1158
1225
  domain: GeometryDomain,
@@ -1202,10 +1269,15 @@ def _launch_integrate_kernel(
1202
1269
  if output != accumulate_array or not add_to_output:
1203
1270
  accumulate_array.zero_()
1204
1271
 
1272
+ qp_count = quadrature.evaluation_point_count()
1273
+ tile_size = _INTEGRATE_CONSTANT_TILE_SIZE
1274
+ block_count = (qp_count + tile_size - 1) // tile_size
1205
1275
  wp.launch(
1206
1276
  kernel=kernel,
1207
- dim=quadrature.evaluation_point_count(),
1277
+ dim=(block_count, tile_size),
1278
+ block_dim=tile_size,
1208
1279
  inputs=[
1280
+ qp_count,
1209
1281
  qp_arg,
1210
1282
  quadrature.element_index_arg_value(device),
1211
1283
  domain_elt_arg,
@@ -1328,21 +1400,29 @@ def _launch_integrate_kernel(
1328
1400
  device=device,
1329
1401
  )
1330
1402
 
1331
- dispatch_kernel = make_linear_dispatch_kernel(test, quadrature, accumulate_dtype)
1332
- wp.launch(
1333
- kernel=dispatch_kernel,
1334
- dim=(test.space_restriction.node_count(), test.node_dof_count),
1335
- inputs=[
1336
- qp_arg,
1337
- domain_elt_arg,
1338
- domain_elt_index_arg,
1339
- test_arg,
1340
- test.space.space_arg_value(device),
1341
- local_result.array,
1342
- output_view,
1343
- ],
1344
- device=device,
1345
- )
1403
+ if test.TAYLOR_DOF_COUNT == 0:
1404
+ wp.utils.warn(
1405
+ f"Test field is never evaluated in integrand '{integrand.name}', result will be zero",
1406
+ category=UserWarning,
1407
+ stacklevel=2,
1408
+ )
1409
+ else:
1410
+ dispatch_kernel, dispatch_tile_size = auxiliary_kernels[0]
1411
+ wp.launch(
1412
+ kernel=dispatch_kernel,
1413
+ dim=(test.space_restriction.node_count(), dispatch_tile_size),
1414
+ block_dim=dispatch_tile_size if dispatch_tile_size > 1 else 256,
1415
+ inputs=[
1416
+ qp_arg,
1417
+ domain_elt_arg,
1418
+ domain_elt_index_arg,
1419
+ test_arg,
1420
+ test.space.space_arg_value(device),
1421
+ local_result.array,
1422
+ output_view,
1423
+ ],
1424
+ device=device,
1425
+ )
1346
1426
 
1347
1427
  local_result.release()
1348
1428
 
@@ -1415,14 +1495,15 @@ def _launch_integrate_kernel(
1415
1495
  device=device,
1416
1496
  )
1417
1497
  elif isinstance(test, LocalTestField):
1498
+ qp_eval_count = quadrature.evaluation_point_count()
1418
1499
  local_result = cache.borrow_temporary(
1419
1500
  temporary_store=temporary_store,
1420
1501
  device=device,
1421
1502
  requires_grad=False,
1422
1503
  shape=(
1423
- quadrature.evaluation_point_count(),
1424
1504
  test.value_dof_count,
1425
1505
  trial.value_dof_count,
1506
+ qp_eval_count,
1426
1507
  test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT,
1427
1508
  ),
1428
1509
  dtype=float,
@@ -1431,7 +1512,7 @@ def _launch_integrate_kernel(
1431
1512
  wp.launch(
1432
1513
  kernel=kernel,
1433
1514
  dim=(
1434
- quadrature.evaluation_point_count(),
1515
+ qp_eval_count,
1435
1516
  test.value_dof_count,
1436
1517
  trial.value_dof_count,
1437
1518
  trial.TAYLOR_DOF_COUNT,
@@ -1448,45 +1529,41 @@ def _launch_integrate_kernel(
1448
1529
  device=device,
1449
1530
  )
1450
1531
 
1451
- vec_array_shape = (*local_result.array.shape[:-1], test.TAYLOR_DOF_COUNT)
1452
- vec_array_dtype = cache.cached_vec_type(length=trial.TAYLOR_DOF_COUNT, dtype=float)
1453
- local_result_as_vec = wp.array(
1454
- data=None,
1455
- ptr=local_result.array.ptr,
1456
- capacity=local_result.array.capacity,
1457
- device=local_result.array.device,
1458
- shape=vec_array_shape,
1459
- dtype=vec_array_dtype,
1460
- )
1461
-
1462
- dispatch_kernel = make_bilinear_dispatch_kernel(test, trial, quadrature, accumulate_dtype)
1463
-
1464
- trial_partition_arg = trial.space_partition.partition_arg_value(device)
1465
- trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
1466
- wp.launch(
1467
- kernel=dispatch_kernel,
1468
- dim=(
1469
- test.space_restriction.node_count(),
1470
- test.node_dof_count,
1471
- trial.node_dof_count,
1472
- trial.space.topology.MAX_NODES_PER_ELEMENT,
1473
- ),
1474
- inputs=[
1475
- qp_arg,
1476
- domain_elt_arg,
1477
- domain_elt_index_arg,
1478
- test_arg,
1479
- test.space.space_arg_value(device),
1480
- trial_partition_arg,
1481
- trial_topology_arg,
1482
- trial.space.space_arg_value(device),
1483
- local_result_as_vec,
1484
- triplet_rows,
1485
- triplet_cols,
1486
- triplet_values,
1487
- ],
1488
- device=device,
1489
- )
1532
+ if test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT == 0:
1533
+ wp.utils.warn(
1534
+ f"Test and/or trial fields are never evaluated in integrand '{integrand.name}', result will be zero",
1535
+ category=UserWarning,
1536
+ stacklevel=2,
1537
+ )
1538
+ triplet_rows.fill_(-1)
1539
+ else:
1540
+ dispatch_kernel, dispatch_tile_size = auxiliary_kernels[0]
1541
+ trial_partition_arg = trial.space_partition.partition_arg_value(device)
1542
+ trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
1543
+ wp.launch(
1544
+ kernel=dispatch_kernel,
1545
+ dim=(
1546
+ test.space_restriction.total_node_element_count(),
1547
+ trial.space.topology.MAX_NODES_PER_ELEMENT,
1548
+ dispatch_tile_size,
1549
+ ),
1550
+ block_dim=dispatch_tile_size if dispatch_tile_size > 1 else 256,
1551
+ inputs=[
1552
+ qp_arg,
1553
+ domain_elt_arg,
1554
+ domain_elt_index_arg,
1555
+ test_arg,
1556
+ test.space.space_arg_value(device),
1557
+ trial_partition_arg,
1558
+ trial_topology_arg,
1559
+ trial.space.space_arg_value(device),
1560
+ local_result.array,
1561
+ triplet_rows,
1562
+ triplet_cols,
1563
+ triplet_values,
1564
+ ],
1565
+ device=device,
1566
+ )
1490
1567
 
1491
1568
  local_result.release()
1492
1569
 
@@ -1621,6 +1698,9 @@ def integrate(
1621
1698
  if values is None:
1622
1699
  values = {}
1623
1700
 
1701
+ if device is None:
1702
+ device = wp.get_device()
1703
+
1624
1704
  if not isinstance(integrand, Integrand):
1625
1705
  raise ValueError("integrand must be tagged with @warp.fem.integrand decorator")
1626
1706
 
@@ -1713,9 +1793,19 @@ def integrate(
1713
1793
  kernel_options=kernel_options,
1714
1794
  )
1715
1795
 
1796
+ auxiliary_kernels = _generate_auxiliary_kernels(
1797
+ quadrature=quadrature,
1798
+ test=test,
1799
+ trial=trial,
1800
+ accumulate_dtype=accumulate_dtype,
1801
+ device=device,
1802
+ kernel_options=kernel_options,
1803
+ )
1804
+
1716
1805
  return _launch_integrate_kernel(
1717
1806
  integrand=integrand,
1718
1807
  kernel=kernel,
1808
+ auxiliary_kernels=auxiliary_kernels,
1719
1809
  FieldStruct=FieldStruct,
1720
1810
  ValueStruct=ValueStruct,
1721
1811
  domain=domain,
@@ -2207,6 +2297,9 @@ def _launch_interpolate_kernel(
2207
2297
  return
2208
2298
 
2209
2299
  if quadrature is None:
2300
+ if dest is not None and (not is_array(dest) or dest.shape[0] != dim):
2301
+ raise ValueError(f"dest must be a warp array with {dim} rows")
2302
+
2210
2303
  wp.launch(
2211
2304
  kernel=kernel,
2212
2305
  dim=dim,
@@ -2216,21 +2309,34 @@ def _launch_interpolate_kernel(
2216
2309
  return
2217
2310
 
2218
2311
  qp_arg = quadrature.arg_value(device)
2312
+ qp_eval_count = quadrature.evaluation_point_count()
2313
+ qp_index_count = quadrature.total_point_count()
2314
+
2315
+ if qp_eval_count != qp_index_count:
2316
+ wp.utils.warn(
2317
+ f"Quadrature used for interpolation of {integrand.name} has different number of evaluation and indexed points, this may lead to incorrect results",
2318
+ category=UserWarning,
2319
+ stacklevel=2,
2320
+ )
2321
+
2219
2322
  qp_element_index_arg = quadrature.element_index_arg_value(device)
2220
2323
  if trial is None:
2324
+ if dest is not None and (not is_array(dest) or dest.shape[0] != qp_index_count):
2325
+ raise ValueError(f"dest must be a warp array with {qp_index_count} rows")
2326
+
2221
2327
  wp.launch(
2222
2328
  kernel=kernel,
2223
- dim=quadrature.evaluation_point_count(),
2329
+ dim=qp_eval_count,
2224
2330
  inputs=[qp_arg, qp_element_index_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
2225
2331
  device=device,
2226
2332
  )
2227
2333
  return
2228
2334
 
2229
- nnz = quadrature.total_point_count() * trial.space.topology.MAX_NODES_PER_ELEMENT
2335
+ nnz = qp_eval_count * trial.space.topology.MAX_NODES_PER_ELEMENT
2230
2336
 
2231
- if dest.nrow != quadrature.total_point_count() or dest.ncol != trial.space_partition.node_count():
2337
+ if dest.nrow != qp_index_count or dest.ncol != trial.space_partition.node_count():
2232
2338
  raise RuntimeError(
2233
- f"'dest' matrix must have {quadrature.total_point_count()} rows and {trial.space_partition.node_count()} columns of blocks"
2339
+ f"'dest' matrix must have {qp_index_count} rows and {trial.space_partition.node_count()} columns of blocks"
2234
2340
  )
2235
2341
  if dest.block_shape[1] != trial.node_dof_count:
2236
2342
  raise RuntimeError(f"'dest' matrix blocks must have {trial.node_dof_count} columns")
@@ -2324,6 +2430,9 @@ def interpolate(
2324
2430
  if values is None:
2325
2431
  values = {}
2326
2432
 
2433
+ if device is None:
2434
+ device = wp.get_device()
2435
+
2327
2436
  if not isinstance(integrand, Integrand):
2328
2437
  raise ValueError("integrand must be tagged with @integrand decorator")
2329
2438
 
@@ -159,6 +159,10 @@ class SpaceRestriction:
159
159
  def node_partition_index(args: NodeArg, restriction_node_index: int):
160
160
  return args.dof_partition_indices[restriction_node_index]
161
161
 
162
+ @wp.func
163
+ def node_partition_index_from_element_offset(args: NodeArg, element_offset: int):
164
+ return wp.lower_bound(args.dof_element_offsets, element_offset + 1) - 1
165
+
162
166
  @wp.func
163
167
  def node_element_range(args: NodeArg, partition_node_index: int):
164
168
  return args.dof_element_offsets[partition_node_index], args.dof_element_offsets[partition_node_index + 1]
@@ -168,19 +168,12 @@ class TetrahedronPolynomialShapeFunctions(TetrahedronShapeFunction):
168
168
 
169
169
  self.VERTEX_NODE_COUNT = wp.constant(1)
170
170
  self.EDGE_NODE_COUNT = wp.constant(degree - 1)
171
+ self.FACE_NODE_COUNT = wp.constant(max(0, degree - 2) * max(0, degree - 1) // 2)
172
+ self.INTERIOR_NODE_COUNT = wp.constant(max(0, degree - 1) * max(0, degree - 2) * max(0, degree - 3) // 6)
173
+
171
174
  self.NODES_PER_ELEMENT = wp.constant((degree + 1) * (degree + 2) * (degree + 3) // 6)
172
175
  self.NODES_PER_SIDE = wp.constant((degree + 1) * (degree + 2) // 2)
173
176
 
174
- self.SIDE_NODE_COUNT = wp.constant(self.NODES_PER_ELEMENT - 3 * (self.VERTEX_NODE_COUNT + self.EDGE_NODE_COUNT))
175
- self.INTERIOR_NODE_COUNT = wp.constant(
176
- self.NODES_PER_ELEMENT - 3 * (self.VERTEX_NODE_COUNT + self.EDGE_NODE_COUNT)
177
- )
178
-
179
- self.VERTEX_NODE_COUNT = wp.constant(1)
180
- self.EDGE_NODE_COUNT = wp.constant(degree - 1)
181
- self.FACE_NODE_COUNT = wp.constant(max(0, degree - 2) * max(0, degree - 1) // 2)
182
- self.INERIOR_NODE_COUNT = wp.constant(max(0, degree - 1) * max(0, degree - 2) * max(0, degree - 3) // 6)
183
-
184
177
  tet_coords = np.empty((self.NODES_PER_ELEMENT, 3), dtype=int)
185
178
 
186
179
  for tx in range(degree + 1):
@@ -107,7 +107,7 @@ def _warp_custom_callback(stream, buffers, opaque, opaque_len):
107
107
  assert hooks.forward, "Failed to find kernel entry point"
108
108
 
109
109
  # Launch the kernel.
110
- wp.context.runtime.core.cuda_launch_kernel(
110
+ wp.context.runtime.core.wp_cuda_launch_kernel(
111
111
  device.context, hooks.forward, bounds.size, 0, 256, hooks.forward_smem_bytes, kernel_params, stream
112
112
  )
113
113