warp-lang 1.7.2rc1__py3-none-macosx_10_13_universal2.whl → 1.8.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; further details were available via a link on the original diff page.

Files changed (180)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +125 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +257 -101
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +657 -223
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  35. warp/examples/optim/example_drone.py +1 -1
  36. warp/examples/sim/example_cloth.py +1 -1
  37. warp/examples/sim/example_cloth_self_contact.py +48 -54
  38. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  39. warp/examples/tile/example_tile_cholesky.py +2 -1
  40. warp/examples/tile/example_tile_convolution.py +1 -1
  41. warp/examples/tile/example_tile_filtering.py +1 -1
  42. warp/examples/tile/example_tile_matmul.py +1 -1
  43. warp/examples/tile/example_tile_mlp.py +2 -0
  44. warp/fabric.py +7 -7
  45. warp/fem/__init__.py +5 -0
  46. warp/fem/adaptivity.py +1 -1
  47. warp/fem/cache.py +152 -63
  48. warp/fem/dirichlet.py +2 -2
  49. warp/fem/domain.py +136 -6
  50. warp/fem/field/field.py +141 -99
  51. warp/fem/field/nodal_field.py +85 -39
  52. warp/fem/field/virtual.py +97 -52
  53. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  54. warp/fem/geometry/closest_point.py +13 -0
  55. warp/fem/geometry/deformed_geometry.py +102 -40
  56. warp/fem/geometry/element.py +56 -2
  57. warp/fem/geometry/geometry.py +323 -22
  58. warp/fem/geometry/grid_2d.py +157 -62
  59. warp/fem/geometry/grid_3d.py +116 -20
  60. warp/fem/geometry/hexmesh.py +86 -20
  61. warp/fem/geometry/nanogrid.py +166 -86
  62. warp/fem/geometry/partition.py +59 -25
  63. warp/fem/geometry/quadmesh.py +86 -135
  64. warp/fem/geometry/tetmesh.py +47 -119
  65. warp/fem/geometry/trimesh.py +77 -270
  66. warp/fem/integrate.py +107 -52
  67. warp/fem/linalg.py +25 -58
  68. warp/fem/operator.py +124 -27
  69. warp/fem/quadrature/pic_quadrature.py +36 -14
  70. warp/fem/quadrature/quadrature.py +40 -16
  71. warp/fem/space/__init__.py +1 -1
  72. warp/fem/space/basis_function_space.py +66 -46
  73. warp/fem/space/basis_space.py +17 -4
  74. warp/fem/space/dof_mapper.py +1 -1
  75. warp/fem/space/function_space.py +2 -2
  76. warp/fem/space/grid_2d_function_space.py +4 -1
  77. warp/fem/space/hexmesh_function_space.py +4 -2
  78. warp/fem/space/nanogrid_function_space.py +3 -1
  79. warp/fem/space/partition.py +11 -2
  80. warp/fem/space/quadmesh_function_space.py +4 -1
  81. warp/fem/space/restriction.py +5 -2
  82. warp/fem/space/shape/__init__.py +10 -8
  83. warp/fem/space/tetmesh_function_space.py +4 -1
  84. warp/fem/space/topology.py +52 -21
  85. warp/fem/space/trimesh_function_space.py +4 -1
  86. warp/fem/utils.py +53 -8
  87. warp/jax.py +1 -2
  88. warp/jax_experimental/ffi.py +12 -17
  89. warp/jax_experimental/xla_ffi.py +37 -24
  90. warp/math.py +171 -1
  91. warp/native/array.h +99 -0
  92. warp/native/builtin.h +174 -31
  93. warp/native/coloring.cpp +1 -1
  94. warp/native/exports.h +118 -63
  95. warp/native/intersect.h +3 -3
  96. warp/native/mat.h +5 -10
  97. warp/native/mathdx.cpp +11 -5
  98. warp/native/matnn.h +1 -123
  99. warp/native/quat.h +28 -4
  100. warp/native/sparse.cpp +121 -258
  101. warp/native/sparse.cu +181 -274
  102. warp/native/spatial.h +305 -17
  103. warp/native/tile.h +583 -72
  104. warp/native/tile_radix_sort.h +1108 -0
  105. warp/native/tile_reduce.h +237 -2
  106. warp/native/tile_scan.h +240 -0
  107. warp/native/tuple.h +189 -0
  108. warp/native/vec.h +6 -16
  109. warp/native/warp.cpp +36 -4
  110. warp/native/warp.cu +574 -51
  111. warp/native/warp.h +47 -74
  112. warp/optim/linear.py +5 -1
  113. warp/paddle.py +7 -8
  114. warp/py.typed +0 -0
  115. warp/render/render_opengl.py +58 -29
  116. warp/render/render_usd.py +124 -61
  117. warp/sim/__init__.py +9 -0
  118. warp/sim/collide.py +252 -78
  119. warp/sim/graph_coloring.py +8 -1
  120. warp/sim/import_mjcf.py +4 -3
  121. warp/sim/import_usd.py +11 -7
  122. warp/sim/integrator.py +5 -2
  123. warp/sim/integrator_euler.py +1 -1
  124. warp/sim/integrator_featherstone.py +1 -1
  125. warp/sim/integrator_vbd.py +751 -320
  126. warp/sim/integrator_xpbd.py +1 -1
  127. warp/sim/model.py +265 -260
  128. warp/sim/utils.py +10 -7
  129. warp/sparse.py +303 -166
  130. warp/tape.py +52 -51
  131. warp/tests/cuda/test_conditional_captures.py +1046 -0
  132. warp/tests/cuda/test_streams.py +1 -1
  133. warp/tests/geometry/test_volume.py +2 -2
  134. warp/tests/interop/test_dlpack.py +9 -9
  135. warp/tests/interop/test_jax.py +0 -1
  136. warp/tests/run_coverage_serial.py +1 -1
  137. warp/tests/sim/disabled_kinematics.py +2 -2
  138. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  139. warp/tests/sim/test_collision.py +159 -51
  140. warp/tests/sim/test_coloring.py +15 -1
  141. warp/tests/test_array.py +254 -2
  142. warp/tests/test_array_reduce.py +2 -2
  143. warp/tests/test_atomic_cas.py +299 -0
  144. warp/tests/test_codegen.py +142 -19
  145. warp/tests/test_conditional.py +47 -1
  146. warp/tests/test_ctypes.py +0 -20
  147. warp/tests/test_devices.py +8 -0
  148. warp/tests/test_fabricarray.py +4 -2
  149. warp/tests/test_fem.py +58 -25
  150. warp/tests/test_func.py +42 -1
  151. warp/tests/test_grad.py +1 -1
  152. warp/tests/test_lerp.py +1 -3
  153. warp/tests/test_map.py +481 -0
  154. warp/tests/test_mat.py +1 -24
  155. warp/tests/test_quat.py +6 -15
  156. warp/tests/test_rounding.py +10 -38
  157. warp/tests/test_runlength_encode.py +7 -7
  158. warp/tests/test_smoothstep.py +1 -1
  159. warp/tests/test_sparse.py +51 -2
  160. warp/tests/test_spatial.py +507 -1
  161. warp/tests/test_struct.py +2 -2
  162. warp/tests/test_tuple.py +265 -0
  163. warp/tests/test_types.py +2 -2
  164. warp/tests/test_utils.py +24 -18
  165. warp/tests/tile/test_tile.py +420 -1
  166. warp/tests/tile/test_tile_mathdx.py +518 -14
  167. warp/tests/tile/test_tile_reduce.py +213 -0
  168. warp/tests/tile/test_tile_shared_memory.py +130 -1
  169. warp/tests/tile/test_tile_sort.py +117 -0
  170. warp/tests/unittest_suites.py +4 -6
  171. warp/types.py +462 -308
  172. warp/utils.py +647 -86
  173. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  174. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/RECORD +177 -165
  175. warp/stubs.py +0 -3381
  176. warp/tests/sim/test_xpbd.py +0 -399
  177. warp/tests/test_mlp.py +0 -282
  178. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  179. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  180. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py CHANGED
@@ -19,6 +19,7 @@ import textwrap
19
19
  from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union
20
20
 
21
21
  import warp as wp
22
+ import warp.fem.operator as operator
22
23
  from warp.codegen import get_annotations
23
24
  from warp.fem import cache
24
25
  from warp.fem.domain import GeometryDomain
@@ -35,7 +36,11 @@ from warp.fem.field import (
35
36
  )
36
37
  from warp.fem.field.virtual import make_bilinear_dispatch_kernel, make_linear_dispatch_kernel
37
38
  from warp.fem.linalg import array_axpy, basis_coefficient
38
- from warp.fem.operator import Integrand, Operator, at_node, integrand
39
+ from warp.fem.operator import (
40
+ Integrand,
41
+ Operator,
42
+ integrand,
43
+ )
39
44
  from warp.fem.quadrature import Quadrature, RegularQuadrature
40
45
  from warp.fem.types import (
41
46
  NULL_DOF_INDEX,
@@ -49,8 +54,9 @@ from warp.fem.types import (
49
54
  Sample,
50
55
  make_free_sample,
51
56
  )
57
+ from warp.fem.utils import type_zero_element
52
58
  from warp.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
53
- from warp.types import type_length
59
+ from warp.types import type_size
54
60
  from warp.utils import array_cast
55
61
 
56
62
 
@@ -111,6 +117,8 @@ class IntegrandVisitor(ast.NodeTransformer):
111
117
  def get_concrete_type(field: Union[FieldLike, Domain]):
112
118
  if isinstance(field, FieldLike):
113
119
  return field.ElementEvalArg
120
+ elif isinstance(field, GeometryDomain):
121
+ return field.DomainArg
114
122
  return field.ElementArg
115
123
 
116
124
  return {
@@ -232,7 +240,7 @@ class IntegrandOperatorParser(IntegrandVisitor):
232
240
 
233
241
  @staticmethod
234
242
  def apply(
235
- integrand: Integrand, field_args: Dict[str, FieldLike], operator_callback: Callable = None
243
+ integrand: Integrand, field_args: Dict[str, FieldLike], operator_callback: Optional[Callable] = None
236
244
  ) -> wp.Function:
237
245
  field_info = IntegrandVisitor._build_field_info(integrand, field_args)
238
246
  IntegrandOperatorParser(integrand, field_info, callback=operator_callback)._apply()
@@ -267,7 +275,11 @@ class IntegrandTransformer(IntegrandVisitor):
267
275
  setattr(field_info.concrete_type, pointer.key, pointer)
268
276
 
269
277
  # also insert callee as first argument
270
- call.args = [ast.Name(id=callee, ctx=ast.Load())] + call.args
278
+ call.args = [ast.Name(id=callee, ctx=ast.Load()), *call.args]
279
+
280
+ # replace first argument with selected attribute
281
+ if operator.attr:
282
+ call.args[0] = ast.Attribute(value=call.args[0], attr=operator.attr)
271
283
 
272
284
  def _process_integrand_call(
273
285
  self, call: ast.Call, callee: Integrand, callee_field_args: Dict[str, IntegrandVisitor.FieldInfo]
@@ -456,6 +468,7 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
456
468
  fields_var_name: str = "fields",
457
469
  values_var_name: str = "values",
458
470
  domain_var_name: str = "domain_arg",
471
+ domain_index_var_name: str = "domain_index_arg",
459
472
  sample_var_name: str = "sample",
460
473
  field_wrappers_attr: str = "_field_wrappers",
461
474
  ):
@@ -470,6 +483,7 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
470
483
  self._fields_var_name = fields_var_name
471
484
  self._values_var_name = values_var_name
472
485
  self._domain_var_name = domain_var_name
486
+ self._domain_index_var_name = domain_index_var_name
473
487
  self._sample_var_name = sample_var_name
474
488
 
475
489
  self._field_wrappers_attr = field_wrappers_attr
@@ -485,8 +499,28 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
485
499
  for name, field in fields.items():
486
500
  if isinstance(field, FieldLike):
487
501
  setattr(field_wrappers, name, field.ElementEvalArg)
502
+ elif isinstance(field, GeometryDomain):
503
+ setattr(field_wrappers, name, field.DomainArg)
488
504
  setattr(integrand_func, self._field_wrappers_attr, field_wrappers)
489
505
 
506
+ def _emit_field_wrapper_call(self, field_name, *data_arguments):
507
+ return ast.Call(
508
+ func=ast.Attribute(
509
+ value=ast.Attribute(
510
+ value=ast.Name(id=self._func_name, ctx=ast.Load()),
511
+ attr=self._field_wrappers_attr,
512
+ ctx=ast.Load(),
513
+ ),
514
+ attr=field_name,
515
+ ctx=ast.Load(),
516
+ ),
517
+ args=[
518
+ ast.Name(id=self._domain_var_name, ctx=ast.Load()),
519
+ *data_arguments,
520
+ ],
521
+ keywords=[],
522
+ )
523
+
490
524
  def visit_Call(self, call: ast.Call):
491
525
  call = self.generic_visit(call)
492
526
 
@@ -498,33 +532,25 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
498
532
  for arg in self._arg_names:
499
533
  if arg == self._domain_name:
500
534
  call.args.append(
501
- ast.Name(id=self._domain_var_name, ctx=ast.Load()),
535
+ self._emit_field_wrapper_call(
536
+ arg,
537
+ ast.Name(id=self._domain_index_var_name, ctx=ast.Load()),
538
+ )
502
539
  )
540
+
503
541
  elif arg == self._sample_name:
504
542
  call.args.append(
505
543
  ast.Name(id=self._sample_var_name, ctx=ast.Load()),
506
544
  )
507
545
  elif arg in self._field_args:
508
546
  call.args.append(
509
- ast.Call(
510
- func=ast.Attribute(
511
- value=ast.Attribute(
512
- value=ast.Name(id=self._func_name, ctx=ast.Load()),
513
- attr=self._field_wrappers_attr,
514
- ctx=ast.Load(),
515
- ),
547
+ self._emit_field_wrapper_call(
548
+ arg,
549
+ ast.Attribute(
550
+ value=ast.Name(id=self._fields_var_name, ctx=ast.Load()),
516
551
  attr=arg,
517
552
  ctx=ast.Load(),
518
553
  ),
519
- args=[
520
- ast.Name(id=self._domain_var_name, ctx=ast.Load()),
521
- ast.Attribute(
522
- value=ast.Name(id=self._fields_var_name, ctx=ast.Load()),
523
- attr=arg,
524
- ctx=ast.Load(),
525
- ),
526
- ],
527
- keywords=[],
528
554
  )
529
555
  )
530
556
  elif arg in self._value_args:
@@ -704,7 +730,7 @@ def get_integrate_linear_nodal_kernel(
704
730
 
705
731
  coords = test.space.node_coords_in_element(
706
732
  domain_arg,
707
- _get_test_arg(),
733
+ _get_test_arg().space_arg,
708
734
  element_index,
709
735
  node_element_index.node_index_in_element,
710
736
  )
@@ -712,7 +738,7 @@ def get_integrate_linear_nodal_kernel(
712
738
  if coords[0] != OUTSIDE:
713
739
  node_weight = test.space.node_quadrature_weight(
714
740
  domain_arg,
715
- _get_test_arg(),
741
+ _get_test_arg().space_arg,
716
742
  element_index,
717
743
  node_element_index.node_index_in_element,
718
744
  )
@@ -913,7 +939,7 @@ def get_integrate_bilinear_nodal_kernel(
913
939
 
914
940
  coords = test.space.node_coords_in_element(
915
941
  domain_arg,
916
- _get_test_arg(),
942
+ _get_test_arg().space_arg,
917
943
  element_index,
918
944
  node_element_index.node_index_in_element,
919
945
  )
@@ -921,7 +947,7 @@ def get_integrate_bilinear_nodal_kernel(
921
947
  if coords[0] != OUTSIDE:
922
948
  node_weight = test.space.node_quadrature_weight(
923
949
  domain_arg,
924
- _get_test_arg(),
950
+ _get_test_arg().space_arg,
925
951
  element_index,
926
952
  node_element_index.node_index_in_element,
927
953
  )
@@ -1153,7 +1179,7 @@ def _launch_integrate_kernel(
1153
1179
  field_arg_values = FieldStruct()
1154
1180
  for k, v in fields.items():
1155
1181
  if not isinstance(v, GeometryDomain):
1156
- setattr(field_arg_values, k, v.eval_arg_value(device=device))
1182
+ v.fill_eval_arg(getattr(field_arg_values, k), device=device)
1157
1183
 
1158
1184
  value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
1159
1185
 
@@ -1203,7 +1229,7 @@ def _launch_integrate_kernel(
1203
1229
  array_cast(in_array=accumulate_array, out_array=output)
1204
1230
  return output
1205
1231
 
1206
- test_arg = test.space_restriction.node_arg(device=device)
1232
+ test_arg = test.space_restriction.node_arg_value(device=device)
1207
1233
  nodal = quadrature is None
1208
1234
 
1209
1235
  # Linear form
@@ -1211,9 +1237,9 @@ def _launch_integrate_kernel(
1211
1237
  # If an output array is provided with the correct type, accumulate directly into it
1212
1238
  # Otherwise, grab a temporary array
1213
1239
  if output is None:
1214
- if type_length(output_dtype) == test.node_dof_count:
1240
+ if type_size(output_dtype) == test.node_dof_count:
1215
1241
  output_shape = (test.space_partition.node_count(),)
1216
- elif type_length(output_dtype) == 1:
1242
+ elif type_size(output_dtype) == 1:
1217
1243
  output_shape = (test.space_partition.node_count(), test.node_dof_count)
1218
1244
  else:
1219
1245
  raise RuntimeError(
@@ -1236,8 +1262,8 @@ def _launch_integrate_kernel(
1236
1262
  raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")
1237
1263
 
1238
1264
  output_dtype = output.dtype
1239
- if type_length(output_dtype) != test.node_dof_count:
1240
- if type_length(output_dtype) != 1:
1265
+ if type_size(output_dtype) != test.node_dof_count:
1266
+ if type_size(output_dtype) != 1:
1241
1267
  raise RuntimeError(
1242
1268
  f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
1243
1269
  )
@@ -1311,7 +1337,7 @@ def _launch_integrate_kernel(
1311
1337
  domain_elt_arg,
1312
1338
  domain_elt_index_arg,
1313
1339
  test_arg,
1314
- test.global_field.eval_arg_value(device),
1340
+ test.space.space_arg_value(device),
1315
1341
  local_result.array,
1316
1342
  output_view,
1317
1343
  ],
@@ -1450,10 +1476,10 @@ def _launch_integrate_kernel(
1450
1476
  domain_elt_arg,
1451
1477
  domain_elt_index_arg,
1452
1478
  test_arg,
1453
- test.global_field.eval_arg_value(device),
1479
+ test.space.space_arg_value(device),
1454
1480
  trial_partition_arg,
1455
1481
  trial_topology_arg,
1456
- trial.global_field.eval_arg_value(device),
1482
+ trial.space.space_arg_value(device),
1457
1483
  local_result_as_vec,
1458
1484
  triplet_rows,
1459
1485
  triplet_cols,
@@ -1529,21 +1555,30 @@ def _pick_assembly_strategy(
1529
1555
  if assembly not in ("generic", "nodal", "dispatch"):
1530
1556
  raise ValueError(f"Invalid assembly strategy'{assembly}'")
1531
1557
  return assembly
1532
- elif nodal:
1533
- return "nodal"
1558
+ elif nodal is not None:
1559
+ wp.utils.warn(
1560
+ "'nodal' argument of `warp.fem.integrate` is deprecated and will be removed in a future version. Please use `assembly='nodal'` instead.",
1561
+ category=DeprecationWarning,
1562
+ stacklevel=2,
1563
+ )
1564
+ if nodal:
1565
+ return "nodal"
1534
1566
 
1535
- test_operators = operators.get(arguments.test_name, {})
1536
- trial_operators = operators.get(arguments.trial_name, {})
1537
- uses_at_node = at_node in test_operators or at_node in trial_operators
1567
+ test_operators = operators.get(arguments.test_name, set())
1568
+ trial_operators = operators.get(arguments.trial_name, set())
1538
1569
 
1539
- return "generic" if uses_at_node else "dispatch"
1570
+ uses_virtual_node_operator = {operator.at_node, operator.node_count, operator.node_index} & (
1571
+ test_operators | trial_operators
1572
+ )
1573
+
1574
+ return "generic" if uses_virtual_node_operator else "dispatch"
1540
1575
 
1541
1576
 
1542
1577
  def integrate(
1543
1578
  integrand: Integrand,
1544
1579
  domain: Optional[GeometryDomain] = None,
1545
1580
  quadrature: Optional[Quadrature] = None,
1546
- nodal: bool = False,
1581
+ nodal: Optional[bool] = None,
1547
1582
  fields: Optional[Dict[str, FieldLike]] = None,
1548
1583
  values: Optional[Dict[str, Any]] = None,
1549
1584
  accumulate_dtype: type = wp.float64,
@@ -1575,7 +1610,7 @@ def integrate(
1575
1610
  assembly: Specifies the strategy for assembling the integrated vector or matrix:
1576
1611
  - "nodal": For linear or bilinear forms, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
1577
1612
  - "generic": Single-pass integration and shape-function evaluation. Makes no assumption about the integrand's content, but may lead to many redundant computations.
1578
- - "dispatch": For linear or bilinear forms, first evaluate the form at quadrature points then dispatch to nodes in a second pass. More efficient for integrands that are expensive to evaluate. Incompatible with `at_node` operator on test or trial functions.
1613
+ - "dispatch": For linear or bilinear forms, first evaluate the form at quadrature points then dispatch to nodes in a second pass. More efficient for integrands that are expensive to evaluate. Incompatible with `at_node` and `node_index` operators on test or trial functions.
1579
1614
  - `None` (default): Automatically picks a suitable assembly strategy (either "generic" or "dispatch")
1580
1615
  add: If True and `output` is provided, add the integration result to `output` instead of replacing its content
1581
1616
  bsr_options: Additional options to be passed to the sparse matrix construction algorithm. See :func:`warp.sparse.bsr_set_from_triplets()`
@@ -1622,6 +1657,9 @@ def integrate(
1622
1657
 
1623
1658
  _find_integrand_operators(integrand, arguments.field_args)
1624
1659
 
1660
+ if operator.lookup in integrand.operators.get(arguments.domain_name, []) and not domain.supports_lookup(device):
1661
+ wp.utils.warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
1662
+
1625
1663
  assembly = _pick_assembly_strategy(assembly, nodal, arguments=arguments, operators=integrand.operators)
1626
1664
  # print("assembly for ", integrand.name, ":", strategy)
1627
1665
 
@@ -1703,7 +1741,7 @@ def get_interpolate_to_field_function(
1703
1741
  ValueStruct: wp.codegen.Struct,
1704
1742
  dest: FieldRestriction,
1705
1743
  ):
1706
- value_type = dest.space.dtype
1744
+ zero_value = type_zero_element(dest.space.dtype)
1707
1745
 
1708
1746
  def interpolate_to_field_fn(
1709
1747
  local_node_index: int,
@@ -1724,7 +1762,7 @@ def get_interpolate_to_field_function(
1724
1762
  # Volume-weighted average across elements
1725
1763
  # Superfluous if the interpolated function is continuous, but helpful for visualizing discontinuous spaces
1726
1764
 
1727
- val_sum = value_type(0.0)
1765
+ val_sum = zero_value()
1728
1766
  vol_sum = float(0.0)
1729
1767
 
1730
1768
  for n in range(element_beg, element_end):
@@ -1969,6 +2007,7 @@ def get_interpolate_free_kernel(
1969
2007
  def interpolate_free_nonvalued_kernel_fn(
1970
2008
  dim: int,
1971
2009
  domain_arg: domain.ElementArg,
2010
+ domain_index_arg: domain.ElementIndexArg,
1972
2011
  fields: FieldStruct,
1973
2012
  values: ValueStruct,
1974
2013
  result: wp.array(dtype=float),
@@ -1987,6 +2026,7 @@ def get_interpolate_free_kernel(
1987
2026
  def interpolate_free_kernel_fn(
1988
2027
  dim: int,
1989
2028
  domain_arg: domain.ElementArg,
2029
+ domain_index_arg: domain.ElementIndexArg,
1990
2030
  fields: FieldStruct,
1991
2031
  values: ValueStruct,
1992
2032
  result: wp.array(dtype=value_type),
@@ -2143,12 +2183,12 @@ def _launch_interpolate_kernel(
2143
2183
  field_arg_values = FieldStruct()
2144
2184
  for k, v in fields.items():
2145
2185
  if not isinstance(v, GeometryDomain):
2146
- setattr(field_arg_values, k, v.eval_arg_value(device=device))
2186
+ v.fill_eval_arg(getattr(field_arg_values, k), device=device)
2147
2187
 
2148
2188
  value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
2149
2189
 
2150
2190
  if isinstance(dest, FieldRestriction):
2151
- dest_node_arg = dest.space_restriction.node_arg(device=device)
2191
+ dest_node_arg = dest.space_restriction.node_arg_value(device=device)
2152
2192
  dest_eval_arg = dest.field.eval_arg_value(device=device)
2153
2193
 
2154
2194
  wp.launch(
@@ -2170,7 +2210,7 @@ def _launch_interpolate_kernel(
2170
2210
  wp.launch(
2171
2211
  kernel=kernel,
2172
2212
  dim=dim,
2173
- inputs=[dim, elt_arg, field_arg_values, value_struct_values, dest],
2213
+ inputs=[dim, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
2174
2214
  device=device,
2175
2215
  )
2176
2216
  return
@@ -2193,7 +2233,7 @@ def _launch_interpolate_kernel(
2193
2233
  f"'dest' matrix must have {quadrature.total_point_count()} rows and {trial.space_partition.node_count()} columns of blocks"
2194
2234
  )
2195
2235
  if dest.block_shape[1] != trial.node_dof_count:
2196
- raise f"'dest' matrix blocks must have {trial.node_dof_count} columns"
2236
+ raise RuntimeError(f"'dest' matrix blocks must have {trial.node_dof_count} columns")
2197
2237
 
2198
2238
  triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
2199
2239
  triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
@@ -2243,7 +2283,7 @@ def interpolate(
2243
2283
  integrand: Union[Integrand, FieldLike],
2244
2284
  dest: Optional[Union[DiscreteField, FieldRestriction, wp.array]] = None,
2245
2285
  quadrature: Optional[Quadrature] = None,
2246
- dim: int = 0,
2286
+ dim: Optional[int] = None,
2247
2287
  domain: Optional[Domain] = None,
2248
2288
  fields: Optional[Dict[str, FieldLike]] = None,
2249
2289
  values: Optional[Dict[str, Any]] = None,
@@ -2290,11 +2330,13 @@ def interpolate(
2290
2330
  arguments = _parse_integrand_arguments(integrand, fields)
2291
2331
  if arguments.test_name:
2292
2332
  raise ValueError(f"Test field '{arguments.test_name}' maybe not be used for interpolation")
2293
- if arguments.trial_name and (quadrature is None or not isinstance(dest, BsrMatrix)):
2333
+ if arguments.trial_name and not isinstance(dest, BsrMatrix):
2294
2334
  raise ValueError(
2295
- f"Interpolation using trial field '{arguments.trial_name}' requires 'quadrature' to be provided and 'dest' to be a `warp.sparse.BsrMatrix`"
2335
+ f"Interpolation using trial field '{arguments.trial_name}' requires 'dest' to be a `warp.sparse.BsrMatrix`"
2296
2336
  )
2297
2337
 
2338
+ trial = arguments.field_args.get(arguments.trial_name, None)
2339
+
2298
2340
  if isinstance(dest, DiscreteField):
2299
2341
  dest = make_restriction(dest, domain=domain)
2300
2342
 
@@ -2302,12 +2344,25 @@ def interpolate(
2302
2344
  domain = dest.domain
2303
2345
  elif quadrature is not None:
2304
2346
  domain = quadrature.domain
2347
+ elif dim is None:
2348
+ if trial is not None:
2349
+ domain = trial.domain
2350
+ elif domain is None:
2351
+ raise ValueError(
2352
+ "Unable to determine interpolation domain, provide an explicit field restriction or quadrature"
2353
+ )
2354
+
2355
+ # Default to one sample per domain element
2356
+ quadrature = RegularQuadrature(domain, order=0)
2305
2357
 
2306
2358
  if arguments.domain_name:
2307
2359
  arguments.field_args[arguments.domain_name] = domain
2308
2360
 
2309
2361
  _find_integrand_operators(integrand, arguments.field_args)
2310
2362
 
2363
+ if operator.lookup in integrand.operators.get(arguments.domain_name, []) and not domain.supports_lookup(device):
2364
+ wp.utils.warn(f"{integrand.name}: using lookup() operator on a domain that does not support it")
2365
+
2311
2366
  kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
2312
2367
  integrand=integrand,
2313
2368
  domain=domain,
@@ -2326,7 +2381,7 @@ def interpolate(
2326
2381
  dest=dest,
2327
2382
  quadrature=quadrature,
2328
2383
  dim=dim,
2329
- trial=fields.get(arguments.trial_name),
2384
+ trial=trial,
2330
2385
  fields=arguments.field_args,
2331
2386
  values=values,
2332
2387
  temporary_store=temporary_store,
warp/fem/linalg.py CHANGED
@@ -16,80 +16,62 @@
16
16
  from typing import Any
17
17
 
18
18
  import warp as wp
19
+ import warp.types
19
20
 
20
21
 
21
22
  @wp.func
22
- def generalized_outer(x: Any, y: Any):
23
- """Generalized outer product allowing for the first argument to be a scalar"""
23
+ def generalized_outer(x: wp.vec(Any, wp.Scalar), y: wp.vec(Any, wp.Scalar)):
24
+ """Generalized outer product allowing for vector or scalar arguments"""
24
25
  return wp.outer(x, y)
25
26
 
26
27
 
27
28
  @wp.func
28
- def generalized_outer(x: wp.float32, y: wp.vec2):
29
+ def generalized_outer(x: wp.Scalar, y: wp.vec(Any, wp.Scalar)):
29
30
  return x * y
30
31
 
31
32
 
32
33
  @wp.func
33
- def generalized_outer(x: wp.float32, y: wp.vec3):
34
+ def generalized_outer(x: wp.vec(Any, wp.Scalar), y: wp.Scalar):
34
35
  return x * y
35
36
 
36
37
 
37
38
  @wp.func
38
- def generalized_inner(x: Any, y: Any):
39
- """Generalized inner product allowing for the first argument to be a tensor"""
40
- return wp.dot(x, y)
41
-
42
-
43
- @wp.func
44
- def generalized_inner(x: float, y: float):
45
- return x * y
39
+ def generalized_outer(x: wp.quatf, y: wp.vec(Any, wp.Scalar)):
40
+ return generalized_outer(wp.vec4(x[0], x[1], x[2], x[3]), y)
46
41
 
47
42
 
48
43
  @wp.func
49
- def generalized_inner(x: wp.mat22, y: wp.vec2):
50
- return x[0] * y[0] + x[1] * y[1]
51
-
52
-
53
- @wp.func
54
- def generalized_inner(x: wp.mat33, y: wp.vec3):
55
- return x[0] * y[0] + x[1] * y[1] + x[2] * y[2]
44
+ def generalized_inner(x: wp.vec(Any, wp.Scalar), y: wp.vec(Any, wp.Scalar)):
45
+ """Generalized inner product allowing for vector, tensor and scalar arguments"""
46
+ return wp.dot(x, y)
56
47
 
57
48
 
58
49
  @wp.func
59
- def basis_element(template_type: Any, coord: int):
60
- """Returns a instance of `template_type` with a single coordinate set to 1 in the canonical basis"""
61
-
62
- t = type(template_type)(0.0)
63
- t[coord] = 1.0
64
- return t
50
+ def generalized_inner(x: wp.Scalar, y: wp.Scalar):
51
+ return x * y
65
52
 
66
53
 
67
54
  @wp.func
68
- def basis_element(template_type: wp.float32, coord: int):
69
- return 1.0
55
+ def generalized_inner(x: wp.mat((Any, Any), wp.Scalar), y: wp.vec(Any, wp.Scalar)):
56
+ return y @ x
70
57
 
71
58
 
72
59
  @wp.func
73
- def basis_element(template_type: wp.mat22, coord: int):
74
- t = wp.mat22(0.0)
75
- row = coord // 2
76
- col = coord - 2 * row
77
- t[row, col] = 1.0
78
- return t
60
+ def generalized_inner(x: wp.vec(Any, wp.Scalar), y: wp.mat((Any, Any), wp.Scalar)):
61
+ return y @ x
79
62
 
80
63
 
81
64
  @wp.func
82
- def basis_element(template_type: wp.mat33, coord: int):
83
- t = wp.mat33(0.0)
84
- row = coord // 3
85
- col = coord - 3 * row
86
- t[row, col] = 1.0
87
- return t
65
+ def basis_coefficient(val: wp.Scalar, i: int):
66
+ return val
88
67
 
89
68
 
90
69
  @wp.func
91
- def basis_coefficient(val: wp.float32, i: int):
92
- return val
70
+ def basis_coefficient(val: wp.mat((Any, Any), wp.Scalar), i: int):
71
+ cols = int(type(val[0]).length)
72
+ row = i // cols
73
+ col = i - row * cols
74
+ return val[row, col]
93
75
 
94
76
 
95
77
  @wp.func
@@ -98,31 +80,16 @@ def basis_coefficient(val: Any, i: int):
98
80
 
99
81
 
100
82
  @wp.func
101
- def basis_coefficient(val: wp.vec2, i: int, j: int):
102
- # treat as row vector
103
- return val[j]
104
-
105
-
106
- @wp.func
107
- def basis_coefficient(val: wp.vec3, i: int, j: int):
83
+ def basis_coefficient(val: wp.vec(Any, wp.Scalar), i: int, j: int):
108
84
  # treat as row vector
109
85
  return val[j]
110
86
 
111
87
 
112
88
  @wp.func
113
- def basis_coefficient(val: Any, i: int, j: int):
89
+ def basis_coefficient(val: wp.mat((Any, Any), wp.Scalar), i: int, j: int):
114
90
  return val[i, j]
115
91
 
116
92
 
117
- @wp.func
118
- def basis_coefficient(template_type: wp.mat33, coord: int):
119
- t = wp.mat33(0.0)
120
- row = coord // 3
121
- col = coord - 3 * row
122
- t[row, col] = 1.0
123
- return t
124
-
125
-
126
93
  @wp.func
127
94
  def symmetric_part(x: Any):
128
95
  """Symmetric part of a square tensor"""