warp-lang 1.8.1__py3-none-win_amd64.whl → 1.9.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (134)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +47 -67
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +312 -206
  9. warp/config.py +1 -1
  10. warp/context.py +1249 -784
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/fabric.py +1 -1
  18. warp/fem/cache.py +27 -19
  19. warp/fem/domain.py +2 -2
  20. warp/fem/field/nodal_field.py +2 -2
  21. warp/fem/field/virtual.py +264 -166
  22. warp/fem/geometry/geometry.py +5 -5
  23. warp/fem/integrate.py +129 -51
  24. warp/fem/space/restriction.py +4 -0
  25. warp/fem/space/shape/tet_shape_function.py +3 -10
  26. warp/jax_experimental/custom_call.py +1 -1
  27. warp/jax_experimental/ffi.py +2 -1
  28. warp/marching_cubes.py +708 -0
  29. warp/native/array.h +99 -4
  30. warp/native/builtin.h +82 -5
  31. warp/native/bvh.cpp +64 -28
  32. warp/native/bvh.cu +58 -58
  33. warp/native/bvh.h +2 -2
  34. warp/native/clang/clang.cpp +7 -7
  35. warp/native/coloring.cpp +8 -2
  36. warp/native/crt.cpp +2 -2
  37. warp/native/crt.h +3 -5
  38. warp/native/cuda_util.cpp +41 -10
  39. warp/native/cuda_util.h +10 -4
  40. warp/native/exports.h +1842 -1908
  41. warp/native/fabric.h +2 -1
  42. warp/native/hashgrid.cpp +37 -37
  43. warp/native/hashgrid.cu +2 -2
  44. warp/native/initializer_array.h +1 -1
  45. warp/native/intersect.h +2 -2
  46. warp/native/mat.h +1910 -116
  47. warp/native/mathdx.cpp +43 -43
  48. warp/native/mesh.cpp +24 -24
  49. warp/native/mesh.cu +26 -26
  50. warp/native/mesh.h +4 -2
  51. warp/native/nanovdb/GridHandle.h +179 -12
  52. warp/native/nanovdb/HostBuffer.h +8 -7
  53. warp/native/nanovdb/NanoVDB.h +517 -895
  54. warp/native/nanovdb/NodeManager.h +323 -0
  55. warp/native/nanovdb/PNanoVDB.h +2 -2
  56. warp/native/quat.h +331 -14
  57. warp/native/range.h +7 -1
  58. warp/native/reduce.cpp +10 -10
  59. warp/native/reduce.cu +13 -14
  60. warp/native/runlength_encode.cpp +2 -2
  61. warp/native/runlength_encode.cu +5 -5
  62. warp/native/scan.cpp +3 -3
  63. warp/native/scan.cu +4 -4
  64. warp/native/sort.cpp +10 -10
  65. warp/native/sort.cu +22 -22
  66. warp/native/sparse.cpp +8 -8
  67. warp/native/sparse.cu +13 -13
  68. warp/native/spatial.h +366 -17
  69. warp/native/temp_buffer.h +2 -2
  70. warp/native/tile.h +283 -69
  71. warp/native/vec.h +381 -14
  72. warp/native/volume.cpp +54 -54
  73. warp/native/volume.cu +1 -1
  74. warp/native/volume.h +2 -1
  75. warp/native/volume_builder.cu +30 -37
  76. warp/native/warp.cpp +150 -149
  77. warp/native/warp.cu +323 -192
  78. warp/native/warp.h +227 -226
  79. warp/optim/linear.py +736 -271
  80. warp/render/imgui_manager.py +289 -0
  81. warp/render/render_opengl.py +85 -6
  82. warp/sim/graph_coloring.py +2 -2
  83. warp/sparse.py +558 -175
  84. warp/tests/aux_test_module_aot.py +7 -0
  85. warp/tests/cuda/test_async.py +3 -3
  86. warp/tests/cuda/test_conditional_captures.py +101 -0
  87. warp/tests/geometry/test_marching_cubes.py +233 -12
  88. warp/tests/sim/test_coloring.py +6 -6
  89. warp/tests/test_array.py +56 -5
  90. warp/tests/test_codegen.py +3 -2
  91. warp/tests/test_context.py +8 -15
  92. warp/tests/test_enum.py +136 -0
  93. warp/tests/test_examples.py +2 -2
  94. warp/tests/test_fem.py +45 -2
  95. warp/tests/test_fixedarray.py +229 -0
  96. warp/tests/test_func.py +18 -15
  97. warp/tests/test_future_annotations.py +7 -5
  98. warp/tests/test_linear_solvers.py +30 -0
  99. warp/tests/test_map.py +1 -1
  100. warp/tests/test_mat.py +1518 -378
  101. warp/tests/test_mat_assign_copy.py +178 -0
  102. warp/tests/test_mat_constructors.py +574 -0
  103. warp/tests/test_module_aot.py +287 -0
  104. warp/tests/test_print.py +69 -0
  105. warp/tests/test_quat.py +140 -34
  106. warp/tests/test_quat_assign_copy.py +145 -0
  107. warp/tests/test_reload.py +2 -1
  108. warp/tests/test_sparse.py +71 -0
  109. warp/tests/test_spatial.py +140 -34
  110. warp/tests/test_spatial_assign_copy.py +160 -0
  111. warp/tests/test_struct.py +43 -3
  112. warp/tests/test_types.py +0 -20
  113. warp/tests/test_vec.py +179 -34
  114. warp/tests/test_vec_assign_copy.py +143 -0
  115. warp/tests/tile/test_tile.py +184 -18
  116. warp/tests/tile/test_tile_cholesky.py +605 -0
  117. warp/tests/tile/test_tile_load.py +169 -0
  118. warp/tests/tile/test_tile_mathdx.py +2 -558
  119. warp/tests/tile/test_tile_matmul.py +1 -1
  120. warp/tests/tile/test_tile_mlp.py +1 -1
  121. warp/tests/tile/test_tile_shared_memory.py +5 -5
  122. warp/tests/unittest_suites.py +6 -0
  123. warp/tests/walkthrough_debug.py +1 -1
  124. warp/thirdparty/unittest_parallel.py +108 -9
  125. warp/types.py +554 -264
  126. warp/utils.py +68 -86
  127. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  128. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/RECORD +131 -121
  129. warp/native/marching.cpp +0 -19
  130. warp/native/marching.cu +0 -514
  131. warp/native/marching.h +0 -19
  132. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  133. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  134. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py CHANGED
@@ -16,7 +16,7 @@
16
16
  import ast
17
17
  import inspect
18
18
  import textwrap
19
- from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union
19
+ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union
20
20
 
21
21
  import warp as wp
22
22
  import warp.fem.operator as operator
@@ -34,7 +34,10 @@ from warp.fem.field import (
34
34
  TrialField,
35
35
  make_restriction,
36
36
  )
37
- from warp.fem.field.virtual import make_bilinear_dispatch_kernel, make_linear_dispatch_kernel
37
+ from warp.fem.field.virtual import (
38
+ make_bilinear_dispatch_kernel,
39
+ make_linear_dispatch_kernel,
40
+ )
38
41
  from warp.fem.linalg import array_axpy, basis_coefficient
39
42
  from warp.fem.operator import (
40
43
  Integrand,
@@ -101,7 +104,8 @@ class IntegrandVisitor(ast.NodeTransformer):
101
104
  field: FieldLike
102
105
  abstract_type: type
103
106
  concrete_type: type
104
- root_arg_name: type
107
+ root_arg_name: str
108
+ local_arg_name: str
105
109
 
106
110
  def __init__(
107
111
  self,
@@ -111,6 +115,7 @@ class IntegrandVisitor(ast.NodeTransformer):
111
115
  self._integrand = integrand
112
116
  self._field_symbols = field_info.copy()
113
117
  self._field_nodes = {}
118
+ self._field_arg_annotation_nodes = {}
114
119
 
115
120
  @staticmethod
116
121
  def _build_field_info(integrand: Integrand, field_args: Dict[str, FieldLike]):
@@ -127,6 +132,7 @@ class IntegrandVisitor(ast.NodeTransformer):
127
132
  abstract_type=integrand.argspec.annotations[name],
128
133
  concrete_type=get_concrete_type(field),
129
134
  root_arg_name=name,
135
+ local_arg_name=name,
130
136
  )
131
137
  for name, field in field_args.items()
132
138
  }
@@ -167,6 +173,7 @@ class IntegrandVisitor(ast.NodeTransformer):
167
173
  field=res[0],
168
174
  abstract_type=res[1],
169
175
  concrete_type=res[2],
176
+ local_arg_name=field_info.local_arg_name,
170
177
  root_arg_name=f"{field_info.root_arg_name}.{func.name}",
171
178
  )
172
179
 
@@ -191,6 +198,13 @@ class IntegrandVisitor(ast.NodeTransformer):
191
198
 
192
199
  return node
193
200
 
201
+ def visit_FunctionDef(self, node: ast.FunctionDef):
202
+ # record field arg annotation nodes
203
+ for arg in node.args.args:
204
+ self._field_arg_annotation_nodes[arg.arg] = arg.annotation
205
+
206
+ return self.generic_visit(node)
207
+
194
208
  def _get_callee_field_args(self, callee: Integrand, args: List[ast.AST]):
195
209
  # Get field types for call site arguments
196
210
  call_site_field_args: List[IntegrandVisitor.FieldInfo] = []
@@ -211,7 +225,13 @@ class IntegrandVisitor(ast.NodeTransformer):
211
225
  raise TypeError(
212
226
  f"Attempting to pass a {passed_field_info.abstract_type.__name__} to argument '{arg}' of '{callee.name}' expecting a {arg_type.__name__}"
213
227
  )
214
- callee_field_args[arg] = passed_field_info
228
+ callee_field_args[arg] = IntegrandVisitor.FieldInfo(
229
+ field=passed_field_info.field,
230
+ abstract_type=passed_field_info.abstract_type,
231
+ concrete_type=passed_field_info.concrete_type,
232
+ local_arg_name=arg,
233
+ root_arg_name=passed_field_info.root_arg_name,
234
+ )
215
235
 
216
236
  return callee_field_args
217
237
 
@@ -263,18 +283,14 @@ class IntegrandTransformer(IntegrandVisitor):
263
283
  f"Operator {operator.func.__name__} is not defined for {field_info.abstract_type.__name__} {field.name}"
264
284
  ) from e
265
285
 
266
- # Update the ast Call node to use the new function pointer
267
- call.func = ast.Attribute(value=call.func, attr=pointer.key, ctx=ast.Load())
268
-
269
286
  # Save the pointer as an attribute than can be accessed from the calling scope
270
- # For usual operator call syntax, we can use the operator itself, but for the
271
- # shortcut default operator syntax, we store it on the callee's concrete type
272
- if isinstance(callee, Operator):
273
- setattr(callee, pointer.key, pointer)
274
- else:
275
- setattr(field_info.concrete_type, pointer.key, pointer)
287
+ # (use the annotation node of the argument this field is constructed from)
288
+ callee_node = self._field_arg_annotation_nodes[field_info.local_arg_name]
289
+ setattr(self._field_symbols[field_info.local_arg_name].abstract_type, pointer.key, pointer)
290
+ call.func = ast.Attribute(value=callee_node, attr=pointer.key, ctx=ast.Load())
276
291
 
277
- # also insert callee as first argument
292
+ # For shortcut default operator syntax, insert callee as first argument
293
+ if not isinstance(callee, Operator):
278
294
  call.args = [ast.Name(id=callee, ctx=ast.Load()), *call.args]
279
295
 
280
296
  # replace first argument with selected attribute
@@ -592,6 +608,9 @@ def _combined_kernel_options(integrand_options: Optional[Dict[str, Any]], call_s
592
608
  return options
593
609
 
594
610
 
611
+ _INTEGRATE_CONSTANT_TILE_SIZE = 256
612
+
613
+
595
614
  def get_integrate_constant_kernel(
596
615
  integrand_func: wp.Function,
597
616
  domain: GeometryDomain,
@@ -599,8 +618,12 @@ def get_integrate_constant_kernel(
599
618
  FieldStruct: wp.codegen.Struct,
600
619
  ValueStruct: wp.codegen.Struct,
601
620
  accumulate_dtype,
621
+ tile_size: int = _INTEGRATE_CONSTANT_TILE_SIZE,
602
622
  ):
623
+ zero_element = type_zero_element(accumulate_dtype)
624
+
603
625
  def integrate_kernel_fn(
626
+ qp_count: int,
604
627
  qp_arg: quadrature.Arg,
605
628
  qp_element_index_arg: quadrature.ElementIndexArg,
606
629
  domain_arg: domain.ElementArg,
@@ -609,26 +632,33 @@ def get_integrate_constant_kernel(
609
632
  values: ValueStruct,
610
633
  result: wp.array(dtype=accumulate_dtype),
611
634
  ):
612
- qp_eval_index = wp.tid()
613
- domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
614
- if domain_element_index == NULL_ELEMENT_INDEX:
615
- return
635
+ block_index, lane = wp.tid()
636
+ qp_eval_index = block_index * tile_size + lane
616
637
 
617
- element_index = domain.element_index(domain_index_arg, domain_element_index)
638
+ if qp_eval_index >= qp_count:
639
+ domain_element_index, qp = NULL_ELEMENT_INDEX, 0
640
+ else:
641
+ domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
618
642
 
619
- qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
620
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
621
- qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
643
+ if domain_element_index == NULL_ELEMENT_INDEX:
644
+ val = zero_element()
645
+ else:
646
+ element_index = domain.element_index(domain_index_arg, domain_element_index)
622
647
 
623
- test_dof_index = NULL_DOF_INDEX
624
- trial_dof_index = NULL_DOF_INDEX
648
+ qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
649
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
650
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
625
651
 
626
- sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
627
- vol = domain.element_measure(domain_arg, sample)
652
+ test_dof_index = NULL_DOF_INDEX
653
+ trial_dof_index = NULL_DOF_INDEX
628
654
 
629
- val = integrand_func(sample, fields, values)
655
+ sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
656
+ vol = domain.element_measure(domain_arg, sample)
630
657
 
631
- wp.atomic_add(result, 0, accumulate_dtype(qp_weight * vol * val))
658
+ val = accumulate_dtype(qp_weight * vol * integrand_func(sample, fields, values))
659
+
660
+ tile_integral = wp.tile_sum(wp.tile(val))
661
+ wp.tile_atomic_add(result, tile_integral, offset=0)
632
662
 
633
663
  return integrate_kernel_fn
634
664
 
@@ -1020,7 +1050,7 @@ def get_integrate_bilinear_local_kernel(
1020
1050
 
1021
1051
  sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1022
1052
  val = integrand_func(sample, fields, values)
1023
- result[qp_eval_index, test_dof, trial_dof, taylor_dof] = qp_vol * val
1053
+ result[test_dof, trial_dof, qp_eval_index, taylor_dof] = qp_vol * val
1024
1054
 
1025
1055
  return integrate_kernel_fn
1026
1056
 
@@ -1150,9 +1180,46 @@ def _generate_integrate_kernel(
1150
1180
  return kernel, FieldStruct, ValueStruct
1151
1181
 
1152
1182
 
1183
+ def _generate_auxiliary_kernels(
1184
+ quadrature: Quadrature,
1185
+ test: Optional[TestField],
1186
+ trial: Optional[TrialField],
1187
+ accumulate_dtype: type,
1188
+ device,
1189
+ kernel_options: Optional[Dict[str, Any]] = None,
1190
+ ) -> List[Tuple[wp.Kernel, int]]:
1191
+ if test is None or not isinstance(test, LocalTestField):
1192
+ return ()
1193
+
1194
+ # For dispatched assembly, generate additional kernels
1195
+ # heuristic to use tiles for "long" quadratures
1196
+ dispatch_tile_size = 32
1197
+ qp_eval_count = quadrature.evaluation_point_count()
1198
+
1199
+ if trial is None:
1200
+ if (
1201
+ not device.is_cuda
1202
+ or qp_eval_count * test.space_restriction.total_node_element_count()
1203
+ < 3 * dispatch_tile_size * test.space_restriction.node_count() * test.domain.element_count()
1204
+ ):
1205
+ dispatch_tile_size = 1
1206
+ dispatch_kernel = make_linear_dispatch_kernel(
1207
+ test, quadrature, accumulate_dtype, dispatch_tile_size, kernel_options
1208
+ )
1209
+ else:
1210
+ if not device.is_cuda or qp_eval_count < 3 * dispatch_tile_size * test.domain.element_count():
1211
+ dispatch_tile_size = 1
1212
+ dispatch_kernel = make_bilinear_dispatch_kernel(
1213
+ test, trial, quadrature, accumulate_dtype, dispatch_tile_size, kernel_options
1214
+ )
1215
+
1216
+ return ((dispatch_kernel, dispatch_tile_size),)
1217
+
1218
+
1153
1219
  def _launch_integrate_kernel(
1154
1220
  integrand: Integrand,
1155
1221
  kernel: wp.Kernel,
1222
+ auxiliary_kernels: List[Tuple[wp.Kernel, int]],
1156
1223
  FieldStruct: wp.codegen.Struct,
1157
1224
  ValueStruct: wp.codegen.Struct,
1158
1225
  domain: GeometryDomain,
@@ -1202,10 +1269,15 @@ def _launch_integrate_kernel(
1202
1269
  if output != accumulate_array or not add_to_output:
1203
1270
  accumulate_array.zero_()
1204
1271
 
1272
+ qp_count = quadrature.evaluation_point_count()
1273
+ tile_size = _INTEGRATE_CONSTANT_TILE_SIZE
1274
+ block_count = (qp_count + tile_size - 1) // tile_size
1205
1275
  wp.launch(
1206
1276
  kernel=kernel,
1207
- dim=quadrature.evaluation_point_count(),
1277
+ dim=(block_count, tile_size),
1278
+ block_dim=tile_size,
1208
1279
  inputs=[
1280
+ qp_count,
1209
1281
  qp_arg,
1210
1282
  quadrature.element_index_arg_value(device),
1211
1283
  domain_elt_arg,
@@ -1335,10 +1407,11 @@ def _launch_integrate_kernel(
1335
1407
  stacklevel=2,
1336
1408
  )
1337
1409
  else:
1338
- dispatch_kernel = make_linear_dispatch_kernel(test, quadrature, accumulate_dtype)
1410
+ dispatch_kernel, dispatch_tile_size = auxiliary_kernels[0]
1339
1411
  wp.launch(
1340
1412
  kernel=dispatch_kernel,
1341
- dim=(test.space_restriction.node_count(), test.node_dof_count),
1413
+ dim=(test.space_restriction.node_count(), dispatch_tile_size),
1414
+ block_dim=dispatch_tile_size if dispatch_tile_size > 1 else 256,
1342
1415
  inputs=[
1343
1416
  qp_arg,
1344
1417
  domain_elt_arg,
@@ -1422,14 +1495,15 @@ def _launch_integrate_kernel(
1422
1495
  device=device,
1423
1496
  )
1424
1497
  elif isinstance(test, LocalTestField):
1498
+ qp_eval_count = quadrature.evaluation_point_count()
1425
1499
  local_result = cache.borrow_temporary(
1426
1500
  temporary_store=temporary_store,
1427
1501
  device=device,
1428
1502
  requires_grad=False,
1429
1503
  shape=(
1430
- quadrature.evaluation_point_count(),
1431
1504
  test.value_dof_count,
1432
1505
  trial.value_dof_count,
1506
+ qp_eval_count,
1433
1507
  test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT,
1434
1508
  ),
1435
1509
  dtype=float,
@@ -1438,7 +1512,7 @@ def _launch_integrate_kernel(
1438
1512
  wp.launch(
1439
1513
  kernel=kernel,
1440
1514
  dim=(
1441
- quadrature.evaluation_point_count(),
1515
+ qp_eval_count,
1442
1516
  test.value_dof_count,
1443
1517
  trial.value_dof_count,
1444
1518
  trial.TAYLOR_DOF_COUNT,
@@ -1455,17 +1529,6 @@ def _launch_integrate_kernel(
1455
1529
  device=device,
1456
1530
  )
1457
1531
 
1458
- vec_array_shape = (*local_result.array.shape[:-1], test.TAYLOR_DOF_COUNT)
1459
- vec_array_dtype = cache.cached_vec_type(length=trial.TAYLOR_DOF_COUNT, dtype=float)
1460
- local_result_as_vec = wp.array(
1461
- data=None,
1462
- ptr=local_result.array.ptr,
1463
- capacity=local_result.array.capacity,
1464
- device=local_result.array.device,
1465
- shape=vec_array_shape,
1466
- dtype=vec_array_dtype,
1467
- )
1468
-
1469
1532
  if test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT == 0:
1470
1533
  wp.utils.warn(
1471
1534
  f"Test and/or trial fields are never evaluated in integrand '{integrand.name}', result will be zero",
@@ -1474,18 +1537,17 @@ def _launch_integrate_kernel(
1474
1537
  )
1475
1538
  triplet_rows.fill_(-1)
1476
1539
  else:
1477
- dispatch_kernel = make_bilinear_dispatch_kernel(test, trial, quadrature, accumulate_dtype)
1478
-
1540
+ dispatch_kernel, dispatch_tile_size = auxiliary_kernels[0]
1479
1541
  trial_partition_arg = trial.space_partition.partition_arg_value(device)
1480
1542
  trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
1481
1543
  wp.launch(
1482
1544
  kernel=dispatch_kernel,
1483
1545
  dim=(
1484
- test.space_restriction.node_count(),
1485
- test.node_dof_count,
1486
- trial.node_dof_count,
1546
+ test.space_restriction.total_node_element_count(),
1487
1547
  trial.space.topology.MAX_NODES_PER_ELEMENT,
1548
+ dispatch_tile_size,
1488
1549
  ),
1550
+ block_dim=dispatch_tile_size if dispatch_tile_size > 1 else 256,
1489
1551
  inputs=[
1490
1552
  qp_arg,
1491
1553
  domain_elt_arg,
@@ -1495,7 +1557,7 @@ def _launch_integrate_kernel(
1495
1557
  trial_partition_arg,
1496
1558
  trial_topology_arg,
1497
1559
  trial.space.space_arg_value(device),
1498
- local_result_as_vec,
1560
+ local_result.array,
1499
1561
  triplet_rows,
1500
1562
  triplet_cols,
1501
1563
  triplet_values,
@@ -1636,6 +1698,9 @@ def integrate(
1636
1698
  if values is None:
1637
1699
  values = {}
1638
1700
 
1701
+ if device is None:
1702
+ device = wp.get_device()
1703
+
1639
1704
  if not isinstance(integrand, Integrand):
1640
1705
  raise ValueError("integrand must be tagged with @warp.fem.integrand decorator")
1641
1706
 
@@ -1728,9 +1793,19 @@ def integrate(
1728
1793
  kernel_options=kernel_options,
1729
1794
  )
1730
1795
 
1796
+ auxiliary_kernels = _generate_auxiliary_kernels(
1797
+ quadrature=quadrature,
1798
+ test=test,
1799
+ trial=trial,
1800
+ accumulate_dtype=accumulate_dtype,
1801
+ device=device,
1802
+ kernel_options=kernel_options,
1803
+ )
1804
+
1731
1805
  return _launch_integrate_kernel(
1732
1806
  integrand=integrand,
1733
1807
  kernel=kernel,
1808
+ auxiliary_kernels=auxiliary_kernels,
1734
1809
  FieldStruct=FieldStruct,
1735
1810
  ValueStruct=ValueStruct,
1736
1811
  domain=domain,
@@ -2355,6 +2430,9 @@ def interpolate(
2355
2430
  if values is None:
2356
2431
  values = {}
2357
2432
 
2433
+ if device is None:
2434
+ device = wp.get_device()
2435
+
2358
2436
  if not isinstance(integrand, Integrand):
2359
2437
  raise ValueError("integrand must be tagged with @integrand decorator")
2360
2438
 
@@ -159,6 +159,10 @@ class SpaceRestriction:
159
159
  def node_partition_index(args: NodeArg, restriction_node_index: int):
160
160
  return args.dof_partition_indices[restriction_node_index]
161
161
 
162
+ @wp.func
163
+ def node_partition_index_from_element_offset(args: NodeArg, element_offset: int):
164
+ return wp.lower_bound(args.dof_element_offsets, element_offset + 1) - 1
165
+
162
166
  @wp.func
163
167
  def node_element_range(args: NodeArg, partition_node_index: int):
164
168
  return args.dof_element_offsets[partition_node_index], args.dof_element_offsets[partition_node_index + 1]
@@ -168,19 +168,12 @@ class TetrahedronPolynomialShapeFunctions(TetrahedronShapeFunction):
168
168
 
169
169
  self.VERTEX_NODE_COUNT = wp.constant(1)
170
170
  self.EDGE_NODE_COUNT = wp.constant(degree - 1)
171
+ self.FACE_NODE_COUNT = wp.constant(max(0, degree - 2) * max(0, degree - 1) // 2)
172
+ self.INTERIOR_NODE_COUNT = wp.constant(max(0, degree - 1) * max(0, degree - 2) * max(0, degree - 3) // 6)
173
+
171
174
  self.NODES_PER_ELEMENT = wp.constant((degree + 1) * (degree + 2) * (degree + 3) // 6)
172
175
  self.NODES_PER_SIDE = wp.constant((degree + 1) * (degree + 2) // 2)
173
176
 
174
- self.SIDE_NODE_COUNT = wp.constant(self.NODES_PER_ELEMENT - 3 * (self.VERTEX_NODE_COUNT + self.EDGE_NODE_COUNT))
175
- self.INTERIOR_NODE_COUNT = wp.constant(
176
- self.NODES_PER_ELEMENT - 3 * (self.VERTEX_NODE_COUNT + self.EDGE_NODE_COUNT)
177
- )
178
-
179
- self.VERTEX_NODE_COUNT = wp.constant(1)
180
- self.EDGE_NODE_COUNT = wp.constant(degree - 1)
181
- self.FACE_NODE_COUNT = wp.constant(max(0, degree - 2) * max(0, degree - 1) // 2)
182
- self.INERIOR_NODE_COUNT = wp.constant(max(0, degree - 1) * max(0, degree - 2) * max(0, degree - 3) // 6)
183
-
184
177
  tet_coords = np.empty((self.NODES_PER_ELEMENT, 3), dtype=int)
185
178
 
186
179
  for tx in range(degree + 1):
@@ -107,7 +107,7 @@ def _warp_custom_callback(stream, buffers, opaque, opaque_len):
107
107
  assert hooks.forward, "Failed to find kernel entry point"
108
108
 
109
109
  # Launch the kernel.
110
- wp.context.runtime.core.cuda_launch_kernel(
110
+ wp.context.runtime.core.wp_cuda_launch_kernel(
111
111
  device.context, hooks.forward, bounds.size, 0, 256, hooks.forward_smem_bytes, kernel_params, stream
112
112
  )
113
113
 
@@ -317,7 +317,7 @@ class FfiKernel:
317
317
  assert hooks.forward, "Failed to find kernel entry point"
318
318
 
319
319
  # launch the kernel
320
- wp.context.runtime.core.cuda_launch_kernel(
320
+ wp.context.runtime.core.wp_cuda_launch_kernel(
321
321
  device.context,
322
322
  hooks.forward,
323
323
  launch_bounds.size,
@@ -381,6 +381,7 @@ class FfiCallable:
381
381
  if arg_name == "return":
382
382
  if arg_type is not None:
383
383
  raise TypeError("Function must not return a value")
384
+ continue
384
385
  else:
385
386
  arg = FfiArg(arg_name, arg_type, arg_name in in_out_argnames)
386
387
  if arg_name in in_out_argnames: