warp-lang 1.7.2rc1__py3-none-macosx_10_13_universal2.whl → 1.8.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (192)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +130 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +272 -104
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +770 -238
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_callable.py +34 -4
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/interop/example_jax_kernel.py +27 -1
  37. warp/examples/optim/example_drone.py +1 -1
  38. warp/examples/sim/example_cloth.py +1 -1
  39. warp/examples/sim/example_cloth_self_contact.py +48 -54
  40. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  41. warp/examples/tile/example_tile_cholesky.py +2 -1
  42. warp/examples/tile/example_tile_convolution.py +1 -1
  43. warp/examples/tile/example_tile_filtering.py +1 -1
  44. warp/examples/tile/example_tile_matmul.py +1 -1
  45. warp/examples/tile/example_tile_mlp.py +2 -0
  46. warp/fabric.py +7 -7
  47. warp/fem/__init__.py +5 -0
  48. warp/fem/adaptivity.py +1 -1
  49. warp/fem/cache.py +152 -63
  50. warp/fem/dirichlet.py +2 -2
  51. warp/fem/domain.py +136 -6
  52. warp/fem/field/field.py +141 -99
  53. warp/fem/field/nodal_field.py +85 -39
  54. warp/fem/field/virtual.py +99 -52
  55. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  56. warp/fem/geometry/closest_point.py +13 -0
  57. warp/fem/geometry/deformed_geometry.py +102 -40
  58. warp/fem/geometry/element.py +56 -2
  59. warp/fem/geometry/geometry.py +323 -22
  60. warp/fem/geometry/grid_2d.py +157 -62
  61. warp/fem/geometry/grid_3d.py +116 -20
  62. warp/fem/geometry/hexmesh.py +86 -20
  63. warp/fem/geometry/nanogrid.py +166 -86
  64. warp/fem/geometry/partition.py +59 -25
  65. warp/fem/geometry/quadmesh.py +86 -135
  66. warp/fem/geometry/tetmesh.py +47 -119
  67. warp/fem/geometry/trimesh.py +77 -270
  68. warp/fem/integrate.py +181 -95
  69. warp/fem/linalg.py +25 -58
  70. warp/fem/operator.py +124 -27
  71. warp/fem/quadrature/pic_quadrature.py +36 -14
  72. warp/fem/quadrature/quadrature.py +40 -16
  73. warp/fem/space/__init__.py +1 -1
  74. warp/fem/space/basis_function_space.py +66 -46
  75. warp/fem/space/basis_space.py +17 -4
  76. warp/fem/space/dof_mapper.py +1 -1
  77. warp/fem/space/function_space.py +2 -2
  78. warp/fem/space/grid_2d_function_space.py +4 -1
  79. warp/fem/space/hexmesh_function_space.py +4 -2
  80. warp/fem/space/nanogrid_function_space.py +3 -1
  81. warp/fem/space/partition.py +11 -2
  82. warp/fem/space/quadmesh_function_space.py +4 -1
  83. warp/fem/space/restriction.py +5 -2
  84. warp/fem/space/shape/__init__.py +10 -8
  85. warp/fem/space/tetmesh_function_space.py +4 -1
  86. warp/fem/space/topology.py +52 -21
  87. warp/fem/space/trimesh_function_space.py +4 -1
  88. warp/fem/utils.py +53 -8
  89. warp/jax.py +1 -2
  90. warp/jax_experimental/ffi.py +210 -67
  91. warp/jax_experimental/xla_ffi.py +37 -24
  92. warp/math.py +171 -1
  93. warp/native/array.h +103 -4
  94. warp/native/builtin.h +182 -35
  95. warp/native/coloring.cpp +6 -2
  96. warp/native/cuda_util.cpp +1 -1
  97. warp/native/exports.h +118 -63
  98. warp/native/intersect.h +5 -5
  99. warp/native/mat.h +8 -13
  100. warp/native/mathdx.cpp +11 -5
  101. warp/native/matnn.h +1 -123
  102. warp/native/mesh.h +1 -1
  103. warp/native/quat.h +34 -6
  104. warp/native/rand.h +7 -7
  105. warp/native/sparse.cpp +121 -258
  106. warp/native/sparse.cu +181 -274
  107. warp/native/spatial.h +305 -17
  108. warp/native/svd.h +23 -8
  109. warp/native/tile.h +603 -73
  110. warp/native/tile_radix_sort.h +1112 -0
  111. warp/native/tile_reduce.h +239 -13
  112. warp/native/tile_scan.h +240 -0
  113. warp/native/tuple.h +189 -0
  114. warp/native/vec.h +10 -20
  115. warp/native/warp.cpp +36 -4
  116. warp/native/warp.cu +588 -52
  117. warp/native/warp.h +47 -74
  118. warp/optim/linear.py +5 -1
  119. warp/paddle.py +7 -8
  120. warp/py.typed +0 -0
  121. warp/render/render_opengl.py +110 -80
  122. warp/render/render_usd.py +124 -62
  123. warp/sim/__init__.py +9 -0
  124. warp/sim/collide.py +253 -80
  125. warp/sim/graph_coloring.py +8 -1
  126. warp/sim/import_mjcf.py +4 -3
  127. warp/sim/import_usd.py +11 -7
  128. warp/sim/integrator.py +5 -2
  129. warp/sim/integrator_euler.py +1 -1
  130. warp/sim/integrator_featherstone.py +1 -1
  131. warp/sim/integrator_vbd.py +761 -322
  132. warp/sim/integrator_xpbd.py +1 -1
  133. warp/sim/model.py +265 -260
  134. warp/sim/utils.py +10 -7
  135. warp/sparse.py +303 -166
  136. warp/tape.py +54 -51
  137. warp/tests/cuda/test_conditional_captures.py +1046 -0
  138. warp/tests/cuda/test_streams.py +1 -1
  139. warp/tests/geometry/test_volume.py +2 -2
  140. warp/tests/interop/test_dlpack.py +9 -9
  141. warp/tests/interop/test_jax.py +0 -1
  142. warp/tests/run_coverage_serial.py +1 -1
  143. warp/tests/sim/disabled_kinematics.py +2 -2
  144. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  145. warp/tests/sim/test_collision.py +159 -51
  146. warp/tests/sim/test_coloring.py +91 -2
  147. warp/tests/test_array.py +254 -2
  148. warp/tests/test_array_reduce.py +2 -2
  149. warp/tests/test_assert.py +53 -0
  150. warp/tests/test_atomic_cas.py +312 -0
  151. warp/tests/test_codegen.py +142 -19
  152. warp/tests/test_conditional.py +47 -1
  153. warp/tests/test_ctypes.py +0 -20
  154. warp/tests/test_devices.py +8 -0
  155. warp/tests/test_fabricarray.py +4 -2
  156. warp/tests/test_fem.py +58 -25
  157. warp/tests/test_func.py +42 -1
  158. warp/tests/test_grad.py +1 -1
  159. warp/tests/test_lerp.py +1 -3
  160. warp/tests/test_map.py +481 -0
  161. warp/tests/test_mat.py +23 -24
  162. warp/tests/test_quat.py +28 -15
  163. warp/tests/test_rounding.py +10 -38
  164. warp/tests/test_runlength_encode.py +7 -7
  165. warp/tests/test_smoothstep.py +1 -1
  166. warp/tests/test_sparse.py +83 -2
  167. warp/tests/test_spatial.py +507 -1
  168. warp/tests/test_static.py +48 -0
  169. warp/tests/test_struct.py +2 -2
  170. warp/tests/test_tape.py +38 -0
  171. warp/tests/test_tuple.py +265 -0
  172. warp/tests/test_types.py +2 -2
  173. warp/tests/test_utils.py +24 -18
  174. warp/tests/test_vec.py +38 -408
  175. warp/tests/test_vec_constructors.py +325 -0
  176. warp/tests/tile/test_tile.py +438 -131
  177. warp/tests/tile/test_tile_mathdx.py +518 -14
  178. warp/tests/tile/test_tile_matmul.py +179 -0
  179. warp/tests/tile/test_tile_reduce.py +307 -5
  180. warp/tests/tile/test_tile_shared_memory.py +136 -7
  181. warp/tests/tile/test_tile_sort.py +121 -0
  182. warp/tests/unittest_suites.py +14 -6
  183. warp/types.py +462 -308
  184. warp/utils.py +647 -86
  185. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +189 -175
  187. warp/stubs.py +0 -3381
  188. warp/tests/sim/test_xpbd.py +0 -399
  189. warp/tests/test_mlp.py +0 -282
  190. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py CHANGED
@@ -19,6 +19,7 @@ import textwrap
19
19
  from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union
20
20
 
21
21
  import warp as wp
22
+ import warp.fem.operator as operator
22
23
  from warp.codegen import get_annotations
23
24
  from warp.fem import cache
24
25
  from warp.fem.domain import GeometryDomain
@@ -35,7 +36,11 @@ from warp.fem.field import (
35
36
  )
36
37
  from warp.fem.field.virtual import make_bilinear_dispatch_kernel, make_linear_dispatch_kernel
37
38
  from warp.fem.linalg import array_axpy, basis_coefficient
38
- from warp.fem.operator import Integrand, Operator, at_node, integrand
39
+ from warp.fem.operator import (
40
+ Integrand,
41
+ Operator,
42
+ integrand,
43
+ )
39
44
  from warp.fem.quadrature import Quadrature, RegularQuadrature
40
45
  from warp.fem.types import (
41
46
  NULL_DOF_INDEX,
@@ -49,8 +54,9 @@ from warp.fem.types import (
49
54
  Sample,
50
55
  make_free_sample,
51
56
  )
57
+ from warp.fem.utils import type_zero_element
52
58
  from warp.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
53
- from warp.types import type_length
59
+ from warp.types import is_array, type_size
54
60
  from warp.utils import array_cast
55
61
 
56
62
 
@@ -111,6 +117,8 @@ class IntegrandVisitor(ast.NodeTransformer):
111
117
  def get_concrete_type(field: Union[FieldLike, Domain]):
112
118
  if isinstance(field, FieldLike):
113
119
  return field.ElementEvalArg
120
+ elif isinstance(field, GeometryDomain):
121
+ return field.DomainArg
114
122
  return field.ElementArg
115
123
 
116
124
  return {
@@ -232,7 +240,7 @@ class IntegrandOperatorParser(IntegrandVisitor):
232
240
 
233
241
  @staticmethod
234
242
  def apply(
235
- integrand: Integrand, field_args: Dict[str, FieldLike], operator_callback: Callable = None
243
+ integrand: Integrand, field_args: Dict[str, FieldLike], operator_callback: Optional[Callable] = None
236
244
  ) -> wp.Function:
237
245
  field_info = IntegrandVisitor._build_field_info(integrand, field_args)
238
246
  IntegrandOperatorParser(integrand, field_info, callback=operator_callback)._apply()
@@ -267,7 +275,11 @@ class IntegrandTransformer(IntegrandVisitor):
267
275
  setattr(field_info.concrete_type, pointer.key, pointer)
268
276
 
269
277
  # also insert callee as first argument
270
- call.args = [ast.Name(id=callee, ctx=ast.Load())] + call.args
278
+ call.args = [ast.Name(id=callee, ctx=ast.Load()), *call.args]
279
+
280
+ # replace first argument with selected attribute
281
+ if operator.attr:
282
+ call.args[0] = ast.Attribute(value=call.args[0], attr=operator.attr)
271
283
 
272
284
  def _process_integrand_call(
273
285
  self, call: ast.Call, callee: Integrand, callee_field_args: Dict[str, IntegrandVisitor.FieldInfo]
@@ -456,6 +468,7 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
456
468
  fields_var_name: str = "fields",
457
469
  values_var_name: str = "values",
458
470
  domain_var_name: str = "domain_arg",
471
+ domain_index_var_name: str = "domain_index_arg",
459
472
  sample_var_name: str = "sample",
460
473
  field_wrappers_attr: str = "_field_wrappers",
461
474
  ):
@@ -470,6 +483,7 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
470
483
  self._fields_var_name = fields_var_name
471
484
  self._values_var_name = values_var_name
472
485
  self._domain_var_name = domain_var_name
486
+ self._domain_index_var_name = domain_index_var_name
473
487
  self._sample_var_name = sample_var_name
474
488
 
475
489
  self._field_wrappers_attr = field_wrappers_attr
@@ -485,8 +499,28 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
485
499
  for name, field in fields.items():
486
500
  if isinstance(field, FieldLike):
487
501
  setattr(field_wrappers, name, field.ElementEvalArg)
502
+ elif isinstance(field, GeometryDomain):
503
+ setattr(field_wrappers, name, field.DomainArg)
488
504
  setattr(integrand_func, self._field_wrappers_attr, field_wrappers)
489
505
 
506
+ def _emit_field_wrapper_call(self, field_name, *data_arguments):
507
+ return ast.Call(
508
+ func=ast.Attribute(
509
+ value=ast.Attribute(
510
+ value=ast.Name(id=self._func_name, ctx=ast.Load()),
511
+ attr=self._field_wrappers_attr,
512
+ ctx=ast.Load(),
513
+ ),
514
+ attr=field_name,
515
+ ctx=ast.Load(),
516
+ ),
517
+ args=[
518
+ ast.Name(id=self._domain_var_name, ctx=ast.Load()),
519
+ *data_arguments,
520
+ ],
521
+ keywords=[],
522
+ )
523
+
490
524
  def visit_Call(self, call: ast.Call):
491
525
  call = self.generic_visit(call)
492
526
 
@@ -498,33 +532,25 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
498
532
  for arg in self._arg_names:
499
533
  if arg == self._domain_name:
500
534
  call.args.append(
501
- ast.Name(id=self._domain_var_name, ctx=ast.Load()),
535
+ self._emit_field_wrapper_call(
536
+ arg,
537
+ ast.Name(id=self._domain_index_var_name, ctx=ast.Load()),
538
+ )
502
539
  )
540
+
503
541
  elif arg == self._sample_name:
504
542
  call.args.append(
505
543
  ast.Name(id=self._sample_var_name, ctx=ast.Load()),
506
544
  )
507
545
  elif arg in self._field_args:
508
546
  call.args.append(
509
- ast.Call(
510
- func=ast.Attribute(
511
- value=ast.Attribute(
512
- value=ast.Name(id=self._func_name, ctx=ast.Load()),
513
- attr=self._field_wrappers_attr,
514
- ctx=ast.Load(),
515
- ),
547
+ self._emit_field_wrapper_call(
548
+ arg,
549
+ ast.Attribute(
550
+ value=ast.Name(id=self._fields_var_name, ctx=ast.Load()),
516
551
  attr=arg,
517
552
  ctx=ast.Load(),
518
553
  ),
519
- args=[
520
- ast.Name(id=self._domain_var_name, ctx=ast.Load()),
521
- ast.Attribute(
522
- value=ast.Name(id=self._fields_var_name, ctx=ast.Load()),
523
- attr=arg,
524
- ctx=ast.Load(),
525
- ),
526
- ],
527
- keywords=[],
528
554
  )
529
555
  )
530
556
  elif arg in self._value_args:
@@ -704,7 +730,7 @@ def get_integrate_linear_nodal_kernel(
704
730
 
705
731
  coords = test.space.node_coords_in_element(
706
732
  domain_arg,
707
- _get_test_arg(),
733
+ _get_test_arg().space_arg,
708
734
  element_index,
709
735
  node_element_index.node_index_in_element,
710
736
  )
@@ -712,7 +738,7 @@ def get_integrate_linear_nodal_kernel(
712
738
  if coords[0] != OUTSIDE:
713
739
  node_weight = test.space.node_quadrature_weight(
714
740
  domain_arg,
715
- _get_test_arg(),
741
+ _get_test_arg().space_arg,
716
742
  element_index,
717
743
  node_element_index.node_index_in_element,
718
744
  )
@@ -913,7 +939,7 @@ def get_integrate_bilinear_nodal_kernel(
913
939
 
914
940
  coords = test.space.node_coords_in_element(
915
941
  domain_arg,
916
- _get_test_arg(),
942
+ _get_test_arg().space_arg,
917
943
  element_index,
918
944
  node_element_index.node_index_in_element,
919
945
  )
@@ -921,7 +947,7 @@ def get_integrate_bilinear_nodal_kernel(
921
947
  if coords[0] != OUTSIDE:
922
948
  node_weight = test.space.node_quadrature_weight(
923
949
  domain_arg,
924
- _get_test_arg(),
950
+ _get_test_arg().space_arg,
925
951
  element_index,
926
952
  node_element_index.node_index_in_element,
927
953
  )
@@ -1153,7 +1179,7 @@ def _launch_integrate_kernel(
1153
1179
  field_arg_values = FieldStruct()
1154
1180
  for k, v in fields.items():
1155
1181
  if not isinstance(v, GeometryDomain):
1156
- setattr(field_arg_values, k, v.eval_arg_value(device=device))
1182
+ v.fill_eval_arg(getattr(field_arg_values, k), device=device)
1157
1183
 
1158
1184
  value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
1159
1185
 
@@ -1203,7 +1229,7 @@ def _launch_integrate_kernel(
1203
1229
  array_cast(in_array=accumulate_array, out_array=output)
1204
1230
  return output
1205
1231
 
1206
- test_arg = test.space_restriction.node_arg(device=device)
1232
+ test_arg = test.space_restriction.node_arg_value(device=device)
1207
1233
  nodal = quadrature is None
1208
1234
 
1209
1235
  # Linear form
@@ -1211,9 +1237,9 @@ def _launch_integrate_kernel(
1211
1237
  # If an output array is provided with the correct type, accumulate directly into it
1212
1238
  # Otherwise, grab a temporary array
1213
1239
  if output is None:
1214
- if type_length(output_dtype) == test.node_dof_count:
1240
+ if type_size(output_dtype) == test.node_dof_count:
1215
1241
  output_shape = (test.space_partition.node_count(),)
1216
- elif type_length(output_dtype) == 1:
1242
+ elif type_size(output_dtype) == 1:
1217
1243
  output_shape = (test.space_partition.node_count(), test.node_dof_count)
1218
1244
  else:
1219
1245
  raise RuntimeError(
@@ -1236,8 +1262,8 @@ def _launch_integrate_kernel(
1236
1262
  raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")
1237
1263
 
1238
1264
  output_dtype = output.dtype
1239
- if type_length(output_dtype) != test.node_dof_count:
1240
- if type_length(output_dtype) != 1:
1265
+ if type_size(output_dtype) != test.node_dof_count:
1266
+ if type_size(output_dtype) != 1:
1241
1267
  raise RuntimeError(
1242
1268
  f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
1243
1269
  )
@@ -1302,21 +1328,28 @@ def _launch_integrate_kernel(
1302
1328
  device=device,
1303
1329
  )
1304
1330
 
1305
- dispatch_kernel = make_linear_dispatch_kernel(test, quadrature, accumulate_dtype)
1306
- wp.launch(
1307
- kernel=dispatch_kernel,
1308
- dim=(test.space_restriction.node_count(), test.node_dof_count),
1309
- inputs=[
1310
- qp_arg,
1311
- domain_elt_arg,
1312
- domain_elt_index_arg,
1313
- test_arg,
1314
- test.global_field.eval_arg_value(device),
1315
- local_result.array,
1316
- output_view,
1317
- ],
1318
- device=device,
1319
- )
1331
+ if test.TAYLOR_DOF_COUNT == 0:
1332
+ wp.utils.warn(
1333
+ f"Test field is never evaluated in integrand '{integrand.name}', result will be zero",
1334
+ category=UserWarning,
1335
+ stacklevel=2,
1336
+ )
1337
+ else:
1338
+ dispatch_kernel = make_linear_dispatch_kernel(test, quadrature, accumulate_dtype)
1339
+ wp.launch(
1340
+ kernel=dispatch_kernel,
1341
+ dim=(test.space_restriction.node_count(), test.node_dof_count),
1342
+ inputs=[
1343
+ qp_arg,
1344
+ domain_elt_arg,
1345
+ domain_elt_index_arg,
1346
+ test_arg,
1347
+ test.space.space_arg_value(device),
1348
+ local_result.array,
1349
+ output_view,
1350
+ ],
1351
+ device=device,
1352
+ )
1320
1353
 
1321
1354
  local_result.release()
1322
1355
 
@@ -1433,34 +1466,42 @@ def _launch_integrate_kernel(
1433
1466
  dtype=vec_array_dtype,
1434
1467
  )
1435
1468
 
1436
- dispatch_kernel = make_bilinear_dispatch_kernel(test, trial, quadrature, accumulate_dtype)
1469
+ if test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT == 0:
1470
+ wp.utils.warn(
1471
+ f"Test and/or trial fields are never evaluated in integrand '{integrand.name}', result will be zero",
1472
+ category=UserWarning,
1473
+ stacklevel=2,
1474
+ )
1475
+ triplet_rows.fill_(-1)
1476
+ else:
1477
+ dispatch_kernel = make_bilinear_dispatch_kernel(test, trial, quadrature, accumulate_dtype)
1437
1478
 
1438
- trial_partition_arg = trial.space_partition.partition_arg_value(device)
1439
- trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
1440
- wp.launch(
1441
- kernel=dispatch_kernel,
1442
- dim=(
1443
- test.space_restriction.node_count(),
1444
- test.node_dof_count,
1445
- trial.node_dof_count,
1446
- trial.space.topology.MAX_NODES_PER_ELEMENT,
1447
- ),
1448
- inputs=[
1449
- qp_arg,
1450
- domain_elt_arg,
1451
- domain_elt_index_arg,
1452
- test_arg,
1453
- test.global_field.eval_arg_value(device),
1454
- trial_partition_arg,
1455
- trial_topology_arg,
1456
- trial.global_field.eval_arg_value(device),
1457
- local_result_as_vec,
1458
- triplet_rows,
1459
- triplet_cols,
1460
- triplet_values,
1461
- ],
1462
- device=device,
1463
- )
1479
+ trial_partition_arg = trial.space_partition.partition_arg_value(device)
1480
+ trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
1481
+ wp.launch(
1482
+ kernel=dispatch_kernel,
1483
+ dim=(
1484
+ test.space_restriction.node_count(),
1485
+ test.node_dof_count,
1486
+ trial.node_dof_count,
1487
+ trial.space.topology.MAX_NODES_PER_ELEMENT,
1488
+ ),
1489
+ inputs=[
1490
+ qp_arg,
1491
+ domain_elt_arg,
1492
+ domain_elt_index_arg,
1493
+ test_arg,
1494
+ test.space.space_arg_value(device),
1495
+ trial_partition_arg,
1496
+ trial_topology_arg,
1497
+ trial.space.space_arg_value(device),
1498
+ local_result_as_vec,
1499
+ triplet_rows,
1500
+ triplet_cols,
1501
+ triplet_values,
1502
+ ],
1503
+ device=device,
1504
+ )
1464
1505
 
1465
1506
  local_result.release()
1466
1507
 
@@ -1529,21 +1570,30 @@ def _pick_assembly_strategy(
1529
1570
  if assembly not in ("generic", "nodal", "dispatch"):
1530
1571
  raise ValueError(f"Invalid assembly strategy'{assembly}'")
1531
1572
  return assembly
1532
- elif nodal:
1533
- return "nodal"
1573
+ elif nodal is not None:
1574
+ wp.utils.warn(
1575
+ "'nodal' argument of `warp.fem.integrate` is deprecated and will be removed in a future version. Please use `assembly='nodal'` instead.",
1576
+ category=DeprecationWarning,
1577
+ stacklevel=2,
1578
+ )
1579
+ if nodal:
1580
+ return "nodal"
1534
1581
 
1535
- test_operators = operators.get(arguments.test_name, {})
1536
- trial_operators = operators.get(arguments.trial_name, {})
1537
- uses_at_node = at_node in test_operators or at_node in trial_operators
1582
+ test_operators = operators.get(arguments.test_name, set())
1583
+ trial_operators = operators.get(arguments.trial_name, set())
1538
1584
 
1539
- return "generic" if uses_at_node else "dispatch"
1585
+ uses_virtual_node_operator = {operator.at_node, operator.node_count, operator.node_index} & (
1586
+ test_operators | trial_operators
1587
+ )
1588
+
1589
+ return "generic" if uses_virtual_node_operator else "dispatch"
1540
1590
 
1541
1591
 
1542
1592
  def integrate(
1543
1593
  integrand: Integrand,
1544
1594
  domain: Optional[GeometryDomain] = None,
1545
1595
  quadrature: Optional[Quadrature] = None,
1546
- nodal: bool = False,
1596
+ nodal: Optional[bool] = None,
1547
1597
  fields: Optional[Dict[str, FieldLike]] = None,
1548
1598
  values: Optional[Dict[str, Any]] = None,
1549
1599
  accumulate_dtype: type = wp.float64,
@@ -1575,7 +1625,7 @@ def integrate(
1575
1625
  assembly: Specifies the strategy for assembling the integrated vector or matrix:
1576
1626
  - "nodal": For linear or bilinear forms, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
1577
1627
  - "generic": Single-pass integration and shape-function evaluation. Makes no assumption about the integrand's content, but may lead to many redundant computations.
1578
- - "dispatch": For linear or bilinear forms, first evaluate the form at quadrature points then dispatch to nodes in a second pass. More efficient for integrands that are expensive to evaluate. Incompatible with `at_node` operator on test or trial functions.
1628
+ - "dispatch": For linear or bilinear forms, first evaluate the form at quadrature points then dispatch to nodes in a second pass. More efficient for integrands that are expensive to evaluate. Incompatible with `at_node` and `node_index` operators on test or trial functions.
1579
1629
  - `None` (default): Automatically picks a suitable assembly strategy (either "generic" or "dispatch")
1580
1630
  add: If True and `output` is provided, add the integration result to `output` instead of replacing its content
1581
1631
  bsr_options: Additional options to be passed to the sparse matrix construction algorithm. See :func:`warp.sparse.bsr_set_from_triplets()`
@@ -1622,6 +1672,9 @@ def integrate(
1622
1672
 
1623
1673
  _find_integrand_operators(integrand, arguments.field_args)
1624
1674
 
1675
+ if operator.lookup in integrand.operators.get(arguments.domain_name, []) and not domain.supports_lookup(device):
1676
+ wp.utils.warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
1677
+
1625
1678
  assembly = _pick_assembly_strategy(assembly, nodal, arguments=arguments, operators=integrand.operators)
1626
1679
  # print("assembly for ", integrand.name, ":", strategy)
1627
1680
 
@@ -1703,7 +1756,7 @@ def get_interpolate_to_field_function(
1703
1756
  ValueStruct: wp.codegen.Struct,
1704
1757
  dest: FieldRestriction,
1705
1758
  ):
1706
- value_type = dest.space.dtype
1759
+ zero_value = type_zero_element(dest.space.dtype)
1707
1760
 
1708
1761
  def interpolate_to_field_fn(
1709
1762
  local_node_index: int,
@@ -1724,7 +1777,7 @@ def get_interpolate_to_field_function(
1724
1777
  # Volume-weighted average across elements
1725
1778
  # Superfluous if the interpolated function is continuous, but helpful for visualizing discontinuous spaces
1726
1779
 
1727
- val_sum = value_type(0.0)
1780
+ val_sum = zero_value()
1728
1781
  vol_sum = float(0.0)
1729
1782
 
1730
1783
  for n in range(element_beg, element_end):
@@ -1969,6 +2022,7 @@ def get_interpolate_free_kernel(
1969
2022
  def interpolate_free_nonvalued_kernel_fn(
1970
2023
  dim: int,
1971
2024
  domain_arg: domain.ElementArg,
2025
+ domain_index_arg: domain.ElementIndexArg,
1972
2026
  fields: FieldStruct,
1973
2027
  values: ValueStruct,
1974
2028
  result: wp.array(dtype=float),
@@ -1987,6 +2041,7 @@ def get_interpolate_free_kernel(
1987
2041
  def interpolate_free_kernel_fn(
1988
2042
  dim: int,
1989
2043
  domain_arg: domain.ElementArg,
2044
+ domain_index_arg: domain.ElementIndexArg,
1990
2045
  fields: FieldStruct,
1991
2046
  values: ValueStruct,
1992
2047
  result: wp.array(dtype=value_type),
@@ -2143,12 +2198,12 @@ def _launch_interpolate_kernel(
2143
2198
  field_arg_values = FieldStruct()
2144
2199
  for k, v in fields.items():
2145
2200
  if not isinstance(v, GeometryDomain):
2146
- setattr(field_arg_values, k, v.eval_arg_value(device=device))
2201
+ v.fill_eval_arg(getattr(field_arg_values, k), device=device)
2147
2202
 
2148
2203
  value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
2149
2204
 
2150
2205
  if isinstance(dest, FieldRestriction):
2151
- dest_node_arg = dest.space_restriction.node_arg(device=device)
2206
+ dest_node_arg = dest.space_restriction.node_arg_value(device=device)
2152
2207
  dest_eval_arg = dest.field.eval_arg_value(device=device)
2153
2208
 
2154
2209
  wp.launch(
@@ -2167,33 +2222,49 @@ def _launch_interpolate_kernel(
2167
2222
  return
2168
2223
 
2169
2224
  if quadrature is None:
2225
+ if dest is not None and (not is_array(dest) or dest.shape[0] != dim):
2226
+ raise ValueError(f"dest must be a warp array with {dim} rows")
2227
+
2170
2228
  wp.launch(
2171
2229
  kernel=kernel,
2172
2230
  dim=dim,
2173
- inputs=[dim, elt_arg, field_arg_values, value_struct_values, dest],
2231
+ inputs=[dim, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
2174
2232
  device=device,
2175
2233
  )
2176
2234
  return
2177
2235
 
2178
2236
  qp_arg = quadrature.arg_value(device)
2237
+ qp_eval_count = quadrature.evaluation_point_count()
2238
+ qp_index_count = quadrature.total_point_count()
2239
+
2240
+ if qp_eval_count != qp_index_count:
2241
+ wp.utils.warn(
2242
+ f"Quadrature used for interpolation of {integrand.name} has different number of evaluation and indexed points, this may lead to incorrect results",
2243
+ category=UserWarning,
2244
+ stacklevel=2,
2245
+ )
2246
+
2179
2247
  qp_element_index_arg = quadrature.element_index_arg_value(device)
2180
2248
  if trial is None:
2249
+ if dest is not None and (not is_array(dest) or dest.shape[0] != qp_index_count):
2250
+ raise ValueError(f"dest must be a warp array with {qp_index_count} rows")
2251
+
2181
2252
  wp.launch(
2182
2253
  kernel=kernel,
2183
- dim=quadrature.evaluation_point_count(),
2254
+ dim=qp_eval_count,
2184
2255
  inputs=[qp_arg, qp_element_index_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
2185
2256
  device=device,
2186
2257
  )
2187
2258
  return
2188
2259
 
2189
- nnz = quadrature.total_point_count() * trial.space.topology.MAX_NODES_PER_ELEMENT
2260
+ nnz = qp_eval_count * trial.space.topology.MAX_NODES_PER_ELEMENT
2190
2261
 
2191
- if dest.nrow != quadrature.total_point_count() or dest.ncol != trial.space_partition.node_count():
2262
+ if dest.nrow != qp_index_count or dest.ncol != trial.space_partition.node_count():
2192
2263
  raise RuntimeError(
2193
- f"'dest' matrix must have {quadrature.total_point_count()} rows and {trial.space_partition.node_count()} columns of blocks"
2264
+ f"'dest' matrix must have {qp_index_count} rows and {trial.space_partition.node_count()} columns of blocks"
2194
2265
  )
2195
2266
  if dest.block_shape[1] != trial.node_dof_count:
2196
- raise f"'dest' matrix blocks must have {trial.node_dof_count} columns"
2267
+ raise RuntimeError(f"'dest' matrix blocks must have {trial.node_dof_count} columns")
2197
2268
 
2198
2269
  triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
2199
2270
  triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
@@ -2243,7 +2314,7 @@ def interpolate(
2243
2314
  integrand: Union[Integrand, FieldLike],
2244
2315
  dest: Optional[Union[DiscreteField, FieldRestriction, wp.array]] = None,
2245
2316
  quadrature: Optional[Quadrature] = None,
2246
- dim: int = 0,
2317
+ dim: Optional[int] = None,
2247
2318
  domain: Optional[Domain] = None,
2248
2319
  fields: Optional[Dict[str, FieldLike]] = None,
2249
2320
  values: Optional[Dict[str, Any]] = None,
@@ -2290,11 +2361,13 @@ def interpolate(
2290
2361
  arguments = _parse_integrand_arguments(integrand, fields)
2291
2362
  if arguments.test_name:
2292
2363
  raise ValueError(f"Test field '{arguments.test_name}' maybe not be used for interpolation")
2293
- if arguments.trial_name and (quadrature is None or not isinstance(dest, BsrMatrix)):
2364
+ if arguments.trial_name and not isinstance(dest, BsrMatrix):
2294
2365
  raise ValueError(
2295
- f"Interpolation using trial field '{arguments.trial_name}' requires 'quadrature' to be provided and 'dest' to be a `warp.sparse.BsrMatrix`"
2366
+ f"Interpolation using trial field '{arguments.trial_name}' requires 'dest' to be a `warp.sparse.BsrMatrix`"
2296
2367
  )
2297
2368
 
2369
+ trial = arguments.field_args.get(arguments.trial_name, None)
2370
+
2298
2371
  if isinstance(dest, DiscreteField):
2299
2372
  dest = make_restriction(dest, domain=domain)
2300
2373
 
@@ -2302,12 +2375,25 @@ def interpolate(
2302
2375
  domain = dest.domain
2303
2376
  elif quadrature is not None:
2304
2377
  domain = quadrature.domain
2378
+ elif dim is None:
2379
+ if trial is not None:
2380
+ domain = trial.domain
2381
+ elif domain is None:
2382
+ raise ValueError(
2383
+ "Unable to determine interpolation domain, provide an explicit field restriction or quadrature"
2384
+ )
2385
+
2386
+ # Default to one sample per domain element
2387
+ quadrature = RegularQuadrature(domain, order=0)
2305
2388
 
2306
2389
  if arguments.domain_name:
2307
2390
  arguments.field_args[arguments.domain_name] = domain
2308
2391
 
2309
2392
  _find_integrand_operators(integrand, arguments.field_args)
2310
2393
 
2394
+ if operator.lookup in integrand.operators.get(arguments.domain_name, []) and not domain.supports_lookup(device):
2395
+ wp.utils.warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
2396
+
2311
2397
  kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
2312
2398
  integrand=integrand,
2313
2399
  domain=domain,
@@ -2326,7 +2412,7 @@ def interpolate(
2326
2412
  dest=dest,
2327
2413
  quadrature=quadrature,
2328
2414
  dim=dim,
2329
- trial=fields.get(arguments.trial_name),
2415
+ trial=trial,
2330
2416
  fields=arguments.field_args,
2331
2417
  values=values,
2332
2418
  temporary_store=temporary_store,