warp-lang 1.4.2 → 1.5.1 (py3-none-win_amd64 wheel)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (166):
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1819 -7
  8. warp/codegen.py +197 -61
  9. warp/config.py +2 -2
  10. warp/context.py +379 -107
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +4 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -7
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +604 -0
  82. warp/native/cuda_util.cpp +68 -51
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1854 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +140 -67
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/import_urdf.py +8 -8
  114. warp/sim/integrator_euler.py +4 -2
  115. warp/sim/integrator_featherstone.py +115 -44
  116. warp/sim/integrator_vbd.py +6 -0
  117. warp/sim/model.py +109 -32
  118. warp/sparse.py +1 -1
  119. warp/stubs.py +569 -4
  120. warp/tape.py +12 -7
  121. warp/tests/assets/pixel.npy +0 -0
  122. warp/tests/aux_test_instancing_gc.py +18 -0
  123. warp/tests/test_array.py +39 -0
  124. warp/tests/test_codegen.py +81 -1
  125. warp/tests/test_codegen_instancing.py +30 -0
  126. warp/tests/test_collision.py +110 -0
  127. warp/tests/test_coloring.py +251 -0
  128. warp/tests/test_context.py +34 -0
  129. warp/tests/test_examples.py +21 -5
  130. warp/tests/test_fem.py +453 -113
  131. warp/tests/test_func.py +34 -4
  132. warp/tests/test_generics.py +52 -0
  133. warp/tests/test_iter.py +68 -0
  134. warp/tests/test_lerp.py +13 -87
  135. warp/tests/test_mat_scalar_ops.py +1 -1
  136. warp/tests/test_matmul.py +6 -9
  137. warp/tests/test_matmul_lite.py +6 -11
  138. warp/tests/test_mesh_query_point.py +1 -1
  139. warp/tests/test_module_hashing.py +23 -0
  140. warp/tests/test_overwrite.py +45 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +56 -1
  143. warp/tests/test_smoothstep.py +17 -83
  144. warp/tests/test_spatial.py +1 -1
  145. warp/tests/test_static.py +3 -3
  146. warp/tests/test_tile.py +744 -0
  147. warp/tests/test_tile_mathdx.py +144 -0
  148. warp/tests/test_tile_mlp.py +383 -0
  149. warp/tests/test_tile_reduce.py +374 -0
  150. warp/tests/test_tile_shared_memory.py +190 -0
  151. warp/tests/test_vbd.py +12 -20
  152. warp/tests/test_volume.py +43 -0
  153. warp/tests/unittest_suites.py +19 -2
  154. warp/tests/unittest_utils.py +4 -2
  155. warp/types.py +340 -74
  156. warp/utils.py +23 -3
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +161 -134
  159. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
  160. warp/fem/field/test.py +0 -180
  161. warp/fem/field/trial.py +0 -183
  162. warp/fem/space/collocated_function_space.py +0 -102
  163. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  164. warp/fem/space/trimesh_2d_function_space.py +0 -153
  165. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
  166. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py CHANGED
@@ -1,5 +1,7 @@
1
1
  import ast
2
- from typing import Any, Dict, List, Optional, Set, Union
2
+ import inspect
3
+ import textwrap
4
+ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union
3
5
 
4
6
  import warp as wp
5
7
  from warp.codegen import get_annotations
@@ -10,11 +12,15 @@ from warp.fem.field import (
10
12
  FieldLike,
11
13
  FieldRestriction,
12
14
  GeometryField,
15
+ LocalTestField,
16
+ LocalTrialField,
13
17
  TestField,
14
18
  TrialField,
15
19
  make_restriction,
16
20
  )
17
- from warp.fem.operator import Integrand, Operator, integrand
21
+ from warp.fem.field.virtual import make_bilinear_dispatch_kernel, make_linear_dispatch_kernel
22
+ from warp.fem.linalg import array_axpy
23
+ from warp.fem.operator import Integrand, Operator, at_node, integrand
18
24
  from warp.fem.quadrature import Quadrature, RegularQuadrature
19
25
  from warp.fem.types import (
20
26
  NULL_DOF_INDEX,
@@ -69,28 +75,58 @@ def _resolve_path(func, node):
69
75
  return None, path
70
76
 
71
77
 
72
- class IntegrandTransformer(ast.NodeTransformer):
73
- def __init__(self, integrand: Integrand, field_args: Dict[str, FieldLike], annotations: Dict[str, Any]):
78
+ class IntegrandVisitor(ast.NodeTransformer):
79
+ class FieldInfo(NamedTuple):
80
+ field: FieldLike
81
+ abstract_type: type
82
+ concrete_type: type
83
+ root_arg_name: type
84
+
85
+ def __init__(
86
+ self,
87
+ integrand: Integrand,
88
+ field_info: Dict[str, FieldInfo],
89
+ ):
74
90
  self._integrand = integrand
75
- self._field_args = field_args
76
- self._annotations = annotations
91
+ self._field_symbols = field_info.copy()
92
+ self._field_nodes = {}
93
+
94
+ @staticmethod
95
+ def _build_field_info(integrand: Integrand, field_args: Dict[str, FieldLike]):
96
+ def get_concrete_type(field: Union[FieldLike, Domain]):
97
+ if isinstance(field, FieldLike):
98
+ return field.ElementEvalArg
99
+ return field.ElementArg
100
+
101
+ return {
102
+ name: IntegrandVisitor.FieldInfo(
103
+ field=field,
104
+ abstract_type=integrand.argspec.annotations[name],
105
+ concrete_type=get_concrete_type(field),
106
+ root_arg_name=name,
107
+ )
108
+ for name, field in field_args.items()
109
+ }
110
+
111
+ def _get_field_info(self, node: ast.expr):
112
+ field_info = self._field_nodes.get(node)
113
+ if field_info is None and isinstance(node, ast.Name):
114
+ field_info = self._field_symbols.get(node.id)
115
+
116
+ return field_info
77
117
 
78
118
  def visit_Call(self, call: ast.Call):
79
119
  call = self.generic_visit(call)
80
120
 
81
121
  callee = getattr(call.func, "id", None)
82
- if callee in self._field_args:
122
+ if callee in self._field_symbols:
83
123
  # Shortcut for evaluating fields as f(x...)
84
- field = self._field_args[callee]
124
+ field_info = self._field_symbols[callee]
85
125
 
86
126
  # Replace with default call operator
87
- abstract_arg_type = self._integrand.argspec.annotations[callee]
88
- default_operator = abstract_arg_type.call_operator
89
- concrete_arg_type = self._annotations[callee]
90
- self._replace_call_func(call, concrete_arg_type, default_operator, field)
127
+ default_operator = field_info.abstract_type.call_operator
91
128
 
92
- # insert callee as first argument
93
- call.args = [ast.Name(id=callee, ctx=ast.Load())] + call.args
129
+ self._process_operator_call(call, callee, default_operator, field_info)
94
130
 
95
131
  return call
96
132
 
@@ -98,44 +134,47 @@ class IntegrandTransformer(ast.NodeTransformer):
98
134
 
99
135
  if isinstance(func, Operator) and len(call.args) > 0:
100
136
  # Evaluating operators as op(field, x, ...)
101
- callee = getattr(call.args[0], "id", None)
102
- if callee in self._field_args:
103
- field = self._field_args[callee]
104
- self._replace_call_func(call, func, func, field)
137
+ field_info = self._get_field_info(call.args[0])
138
+ if field_info is not None:
139
+ self._process_operator_call(call, func, func, field_info)
140
+
141
+ if func.field_result:
142
+ res = func.field_result(field_info.field)
143
+ self._field_nodes[call] = IntegrandVisitor.FieldInfo(
144
+ field=res[0],
145
+ abstract_type=res[1],
146
+ concrete_type=res[2],
147
+ root_arg_name=f"{field_info.root_arg_name}.{func.name}",
148
+ )
105
149
 
106
150
  if isinstance(func, Integrand):
107
- key = self._translate_callee(func, call.args)
108
- call.func = ast.Attribute(
109
- value=call.func,
110
- attr=key,
111
- ctx=ast.Load(),
112
- )
151
+ callee_field_args = self._get_callee_field_args(func, call.args)
152
+ self._process_integrand_call(call, func, callee_field_args)
113
153
 
114
154
  # print(ast.dump(call, indent=4))
115
155
 
116
156
  return call
117
157
 
118
- def _replace_call_func(self, call: ast.Call, callee: Union[type, Operator], operator: Operator, field: FieldLike):
119
- try:
120
- # Retrieve the function pointer corresponding to the operator implementation for the field type
121
- pointer = operator.resolver(field)
122
- if pointer is None:
123
- raise NotImplementedError(operator.resolver.__name__)
158
+ def visit_Assign(self, node: ast.Assign):
159
+ node = self.generic_visit(node)
124
160
 
125
- except (AttributeError, NotImplementedError) as e:
126
- raise ValueError(f"Operator {operator.func.__name__} is not defined for field {field.name}") from e
127
- # Save the pointer as an attribute than can be accessed from the callee scope
128
- setattr(callee, pointer.key, pointer)
129
- # Update the ast Call node to use the new function pointer
130
- call.func = ast.Attribute(value=call.func, attr=pointer.key, ctx=ast.Load())
161
+ # Check if we're assigning a field
162
+ src_field_info = self._get_field_info(node.value)
163
+ if src_field_info is not None:
164
+ if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
165
+ raise NotImplementedError("warp.fem Fields and Domains may only be assigned to simple variables")
166
+
167
+ self._field_symbols[node.targets[0].id] = src_field_info
168
+
169
+ return node
131
170
 
132
- def _translate_callee(self, callee: Integrand, args: List[ast.AST]):
171
+ def _get_callee_field_args(self, callee: Integrand, args: List[ast.AST]):
133
172
  # Get field types for call site arguments
134
- call_site_field_args = []
173
+ call_site_field_args: List[IntegrandVisitor.FieldInfo] = []
135
174
  for arg in args:
136
- name = getattr(arg, "id", None)
137
- if name in self._field_args:
138
- call_site_field_args.append(self._field_args[name])
175
+ field_info = self._get_field_info(arg)
176
+ if field_info is not None:
177
+ call_site_field_args.append(field_info)
139
178
 
140
179
  call_site_field_args.reverse()
141
180
 
@@ -144,46 +183,129 @@ class IntegrandTransformer(ast.NodeTransformer):
144
183
  for arg in callee.argspec.args:
145
184
  arg_type = callee.argspec.annotations[arg]
146
185
  if arg_type in (Field, Domain):
147
- callee_field_args[arg] = call_site_field_args.pop()
186
+ passed_field_info = call_site_field_args.pop()
187
+ if passed_field_info.abstract_type != arg_type:
188
+ raise TypeError(
189
+ f"Attempting to pass a {passed_field_info.abstract_type.__name__} to argument '{arg}' of '{callee.name}' expecting a {arg_type.__name__}"
190
+ )
191
+ callee_field_args[arg] = passed_field_info
148
192
 
149
- return _translate_integrand(callee, callee_field_args).key
193
+ return callee_field_args
150
194
 
151
195
 
152
- def _translate_integrand(integrand: Integrand, field_args: Dict[str, FieldLike]) -> wp.Function:
153
- # Specialize field argument types
154
- argspec = integrand.argspec
155
- annotations = {}
156
- for arg in argspec.args:
157
- arg_type = argspec.annotations[arg]
158
- if arg_type == Field:
159
- annotations[arg] = field_args[arg].ElementEvalArg
160
- elif arg_type == Domain:
161
- annotations[arg] = field_args[arg].ElementArg
196
+ class IntegrandOperatorParser(IntegrandVisitor):
197
+ def __init__(self, integrand: Integrand, field_info: Dict[str, IntegrandVisitor.FieldInfo], callback: Callable):
198
+ super().__init__(integrand, field_info)
199
+ self._operator_callback = callback
200
+
201
+ def _process_operator_call(
202
+ self, call: ast.Call, callee: Union[str, Operator], operator: Operator, field_info: IntegrandVisitor.FieldInfo
203
+ ):
204
+ self._operator_callback(field_info, operator)
205
+
206
+ def _process_integrand_call(
207
+ self, call: ast.Call, callee: Integrand, callee_field_args: Dict[str, IntegrandVisitor.FieldInfo]
208
+ ):
209
+ callee_field_args = self._get_callee_field_args(callee, call.args)
210
+ callee_parser = IntegrandOperatorParser(callee, callee_field_args, callback=self._operator_callback)
211
+ callee_parser._apply()
212
+
213
+ def _apply(self):
214
+ source = textwrap.dedent(inspect.getsource(self._integrand.func))
215
+ tree = ast.parse(source)
216
+ self.visit(tree)
217
+
218
+ @staticmethod
219
+ def apply(
220
+ integrand: Integrand, field_args: Dict[str, FieldLike], operator_callback: Callable = None
221
+ ) -> wp.Function:
222
+ field_info = IntegrandVisitor._build_field_info(integrand, field_args)
223
+ IntegrandOperatorParser(integrand, field_info, callback=operator_callback)._apply()
224
+
225
+
226
+ class IntegrandTransformer(IntegrandVisitor):
227
+ def _process_operator_call(
228
+ self, call: ast.Call, callee: Union[str, Operator], operator: Operator, field_info: IntegrandVisitor.FieldInfo
229
+ ):
230
+ field = field_info.field
231
+
232
+ try:
233
+ # Retrieve the function pointer corresponding to the operator implementation for the field type
234
+ pointer = operator.resolver(field)
235
+ if not isinstance(pointer, wp.context.Function):
236
+ raise NotImplementedError(operator.resolver.__name__)
237
+
238
+ except (AttributeError, NotImplementedError) as e:
239
+ raise TypeError(
240
+ f"Operator {operator.func.__name__} is not defined for {field_info.abstract_type.__name__} {field.name}"
241
+ ) from e
242
+
243
+ # Update the ast Call node to use the new function pointer
244
+ call.func = ast.Attribute(value=call.func, attr=pointer.key, ctx=ast.Load())
245
+
246
+ # Save the pointer as an attribute than can be accessed from the calling scope
247
+ # For usual operator call syntax, we can use the operator itself, but for the
248
+ # shortcut default operator syntax, we store it on the callee's concrete type
249
+ if isinstance(callee, Operator):
250
+ setattr(callee, pointer.key, pointer)
162
251
  else:
163
- annotations[arg] = arg_type
252
+ setattr(field_info.concrete_type, pointer.key, pointer)
253
+
254
+ # also insert callee as first argument
255
+ call.args = [ast.Name(id=callee, ctx=ast.Load())] + call.args
256
+
257
+ def _process_integrand_call(
258
+ self, call: ast.Call, callee: Integrand, callee_field_args: Dict[str, IntegrandVisitor.FieldInfo]
259
+ ):
260
+ callee_field_args = self._get_callee_field_args(callee, call.args)
261
+ transformer = IntegrandTransformer(callee, callee_field_args)
262
+ key = transformer._apply().key
263
+ call.func = ast.Attribute(
264
+ value=call.func,
265
+ attr=key,
266
+ ctx=ast.Load(),
267
+ )
164
268
 
165
- # Transform field evaluation calls
166
- transformer = IntegrandTransformer(integrand, field_args, annotations)
269
+ def _apply(self) -> wp.Function:
270
+ # Transform field evaluation calls
271
+ field_info = self._field_symbols
272
+
273
+ # Specialize field argument types
274
+ argspec = self._integrand.argspec
275
+ annotations = argspec.annotations.copy()
276
+ annotations.update({name: f.concrete_type for name, f in field_info.items()})
277
+
278
+ suffix = "_".join([f.field.name for f in field_info.values()])
279
+ func = cache.get_integrand_function(
280
+ integrand=self._integrand,
281
+ suffix=suffix,
282
+ annotations=annotations,
283
+ code_transformers=[self],
284
+ )
167
285
 
168
- suffix = "_".join([f.name for f in field_args.values()])
286
+ # func = self._integrand.module.functions[func.key] #no longer needed?
287
+ setattr(self._integrand, func.key, func)
169
288
 
170
- func = cache.get_integrand_function(
171
- integrand=integrand,
172
- suffix=suffix,
173
- annotations=annotations,
174
- code_transformers=[transformer],
175
- )
289
+ return func
176
290
 
177
- key = func.key
178
- setattr(integrand, key, integrand.module.functions[key])
291
+ @staticmethod
292
+ def apply(integrand: Integrand, field_args: Dict[str, FieldLike]) -> wp.Function:
293
+ field_info = IntegrandVisitor._build_field_info(integrand, field_args)
294
+ return IntegrandTransformer(integrand, field_info)._apply()
179
295
 
180
- return getattr(integrand, key)
181
296
 
297
+ class IntegrandArguments(NamedTuple):
298
+ field_args: Dict[str, Union[FieldLike, GeometryDomain]]
299
+ value_args: Dict[str, Any]
300
+ domain_name: str
301
+ sample_name: str
302
+ test_name: str
303
+ trial_name: str
182
304
 
183
- def _get_integrand_field_arguments(
305
+
306
+ def _parse_integrand_arguments(
184
307
  integrand: Integrand,
185
308
  fields: Dict[str, FieldLike],
186
- domain: GeometryDomain = None,
187
309
  ):
188
310
  # parse argument types
189
311
  field_args = {}
@@ -191,38 +313,57 @@ def _get_integrand_field_arguments(
191
313
 
192
314
  domain_name = None
193
315
  sample_name = None
316
+ test_name = None
317
+ trial_name = None
194
318
 
195
319
  argspec = integrand.argspec
196
320
  for arg in argspec.args:
197
321
  arg_type = argspec.annotations[arg]
198
322
  if arg_type == Field:
199
- if arg not in fields:
200
- raise ValueError(f"Missing field for argument '{arg}' of integrand '{integrand.name}'")
201
- field_args[arg] = fields[arg]
323
+ try:
324
+ field = fields[arg]
325
+ except KeyError as err:
326
+ raise ValueError(f"Missing field for argument '{arg}' of integrand '{integrand.name}'") from err
327
+ if not isinstance(field, FieldLike):
328
+ raise ValueError(f"Passed field argument '{arg}' is not a proper Field")
329
+ if isinstance(field, TestField):
330
+ if test_name is not None:
331
+ raise ValueError(f"More than one test field argument: '{test_name}' and '{arg}'")
332
+ test_name = arg
333
+ elif isinstance(field, TrialField):
334
+ if trial_name is not None:
335
+ raise ValueError(f"More than one trial field argument: '{trial_name}' and '{arg}'")
336
+ trial_name = arg
337
+ field_args[arg] = field
202
338
  elif arg_type == Domain:
339
+ if domain_name is not None:
340
+ raise SyntaxError(f"Integrand '{integrand.name}' must have at most one argument of type Domain")
341
+ if arg in fields:
342
+ raise ValueError(
343
+ f"Domain argument '{arg}' of '{integrand.name}' will be automatically populated and must not be passed as a field argument."
344
+ )
203
345
  domain_name = arg
204
- field_args[arg] = domain
205
346
  elif arg_type == Sample:
347
+ if sample_name is not None:
348
+ raise SyntaxError(f"Integrand '{integrand.name}' must have at most one argument of type Sample")
349
+ if arg in fields:
350
+ raise ValueError(
351
+ f"Sample argument '{arg}' of '{integrand.name}' will be automatically populated and must not be passed as a field argument."
352
+ )
206
353
  sample_name = arg
207
354
  else:
355
+ if arg in fields:
356
+ raise ValueError(
357
+ f"Cannot pass a field argument to '{arg}' of '{integrand.name}' with is not of type 'Field'"
358
+ )
208
359
  value_args[arg] = arg_type
209
360
 
210
- return field_args, value_args, domain_name, sample_name
361
+ return IntegrandArguments(field_args, value_args, domain_name, sample_name, test_name, trial_name)
211
362
 
212
363
 
213
- def _check_field_compat(
214
- integrand: Integrand,
215
- fields: Dict[str, FieldLike],
216
- field_args: Dict[str, FieldLike],
217
- domain: GeometryDomain = None,
218
- ):
364
+ def _check_field_compat(integrand: Integrand, arguments: IntegrandArguments, domain: GeometryDomain):
219
365
  # Check field compatibility
220
- for name, field in fields.items():
221
- if name not in field_args:
222
- raise ValueError(
223
- f"Passed field argument '{name}' does not match any parameter of integrand '{integrand.name}'"
224
- )
225
-
366
+ for name, field in arguments.field_args.items():
226
367
  if isinstance(field, GeometryField) and domain is not None:
227
368
  if field.geometry != domain.geometry:
228
369
  raise ValueError(f"Field '{name}' must be defined on the same geometry as the integration domain")
@@ -232,37 +373,32 @@ def _check_field_compat(
232
373
  )
233
374
 
234
375
 
235
- def _get_test_and_trial_fields(
236
- fields: Dict[str, FieldLike],
237
- ):
238
- test = None
239
- trial = None
240
- test_name = None
241
- trial_name = None
376
+ def _find_integrand_operators(integrand: Integrand, field_args: Dict[str, FieldLike]):
377
+ if integrand.operators is None:
378
+ # Integrands operator dictionary does not depend on concrete field type,
379
+ # so only needs to be built once per integrand
242
380
 
243
- for name, field in fields.items():
244
- if not isinstance(field, FieldLike):
245
- raise ValueError(f"Passed field argument '{name}' is not a proper Field")
246
-
247
- if isinstance(field, TestField):
248
- if test is not None:
249
- raise ValueError(f"More than one test field argument: '{test_name}' and '{name}'")
250
- test = field
251
- test_name = name
252
- elif isinstance(field, TrialField):
253
- if trial is not None:
254
- raise ValueError(f"More than one trial field argument: '{trial_name}' and '{name}'")
255
- trial = field
256
- trial_name = name
257
-
258
- if trial is not None:
259
- if test is None:
260
- raise ValueError("A trial field cannot be provided without a test field")
381
+ operators = {}
261
382
 
262
- if test.domain != trial.domain:
263
- raise ValueError("Incompatible test and trial domains")
383
+ def operator_callback(field: IntegrandVisitor.FieldInfo, op: Operator):
384
+ if field.root_arg_name in operators:
385
+ operators[field.root_arg_name].add(op)
386
+ else:
387
+ operators[field.root_arg_name] = {op}
388
+
389
+ IntegrandOperatorParser.apply(integrand, field_args, operator_callback=operator_callback)
390
+
391
+ integrand.operators = operators
264
392
 
265
- return test, test_name, trial, trial_name
393
+
394
+ def _notify_operator_usage(
395
+ integrand: Integrand,
396
+ field_args: Dict[str, FieldLike],
397
+ ):
398
+ for arg, field_ops in integrand.operators.items():
399
+ if arg in field_args:
400
+ # print(f"{arg} {field_args[arg].name} : {', '.join(op.name for op in field_ops)}")
401
+ field_args[arg].notify_operator_usage(field_ops)
266
402
 
267
403
 
268
404
  def _gen_field_struct(field_args: Dict[str, FieldLike]):
@@ -295,26 +431,12 @@ def _get_test_arg():
295
431
  pass
296
432
 
297
433
 
298
- class _FieldWrappers:
299
- pass
300
-
301
-
302
- def _register_integrand_field_wrappers(integrand_func: wp.Function, fields: Dict[str, FieldLike]):
303
- integrand_func._field_wrappers = _FieldWrappers()
304
- for name, field in fields.items():
305
- setattr(integrand_func._field_wrappers, name, field.ElementEvalArg)
306
-
307
-
308
434
  class PassFieldArgsToIntegrand(ast.NodeTransformer):
309
435
  def __init__(
310
436
  self,
311
437
  arg_names: List[str],
312
- field_args: Set[str],
313
- value_args: Set[str],
314
- sample_name: str,
315
- domain_name: str,
316
- test_name: str = None,
317
- trial_name: str = None,
438
+ parsed_args: IntegrandArguments,
439
+ integrand_func: wp.Function,
318
440
  func_name: str = "integrand_func",
319
441
  fields_var_name: str = "fields",
320
442
  values_var_name: str = "values",
@@ -323,18 +445,32 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
323
445
  field_wrappers_attr: str = "_field_wrappers",
324
446
  ):
325
447
  self._arg_names = arg_names
326
- self._field_args = field_args
327
- self._value_args = value_args
328
- self._domain_name = domain_name
329
- self._sample_name = sample_name
448
+ self._field_args = parsed_args.field_args
449
+ self._value_args = parsed_args.value_args
450
+ self._domain_name = parsed_args.domain_name
451
+ self._sample_name = parsed_args.sample_name
452
+ self._test_name = parsed_args.test_name
453
+ self._trial_name = parsed_args.trial_name
330
454
  self._func_name = func_name
331
- self._test_name = test_name
332
- self._trial_name = trial_name
333
455
  self._fields_var_name = fields_var_name
334
456
  self._values_var_name = values_var_name
335
457
  self._domain_var_name = domain_var_name
336
458
  self._sample_var_name = sample_var_name
459
+
337
460
  self._field_wrappers_attr = field_wrappers_attr
461
+ self._register_integrand_field_wrappers(integrand_func, parsed_args.field_args)
462
+
463
+ class _FieldWrappers:
464
+ pass
465
+
466
+ def _register_integrand_field_wrappers(self, integrand_func: wp.Function, fields: Dict[str, FieldLike]):
467
+ # Mechanism to pass the geometry argument only once to the root kernel
468
+ # Field wrappers are used to forward it to all fields in nested integrand calls
469
+ field_wrappers = PassFieldArgsToIntegrand._FieldWrappers()
470
+ for name, field in fields.items():
471
+ if isinstance(field, FieldLike):
472
+ setattr(field_wrappers, name, field.ElementEvalArg)
473
+ setattr(integrand_func, self._field_wrappers_attr, field_wrappers)
338
474
 
339
475
  def visit_Call(self, call: ast.Call):
340
476
  call = self.generic_visit(call)
@@ -405,6 +541,16 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
405
541
  return call
406
542
 
407
543
 
544
+ def _combined_kernel_options(integrand_options: Optional[Dict[str, Any]], call_site_options: Optional[Dict[str, Any]]):
545
+ if integrand_options is None:
546
+ return {} if call_site_options is None else call_site_options
547
+
548
+ options = integrand_options.copy()
549
+ if call_site_options is not None:
550
+ options.update(call_site_options)
551
+ return options
552
+
553
+
408
554
  def get_integrate_constant_kernel(
409
555
  integrand_func: wp.Function,
410
556
  domain: GeometryDomain,
@@ -500,7 +646,7 @@ def get_integrate_linear_kernel(
500
646
 
501
647
  val_sum += accumulate_dtype(qp_weight * vol * val)
502
648
 
503
- result[node_index, test_dof] = output_dtype(val_sum)
649
+ result[node_index, test_dof] += output_dtype(val_sum)
504
650
 
505
651
  return integrate_kernel_fn
506
652
 
@@ -571,7 +717,48 @@ def get_integrate_linear_nodal_kernel(
571
717
 
572
718
  val_sum += accumulate_dtype(node_weight * vol * val)
573
719
 
574
- result[partition_node_index, dof] = output_dtype(val_sum)
720
+ result[partition_node_index, dof] += output_dtype(val_sum)
721
+
722
+ return integrate_kernel_fn
723
+
724
+
725
+ def get_integrate_linear_local_kernel(
726
+ integrand_func: wp.Function,
727
+ domain: GeometryDomain,
728
+ quadrature: Quadrature,
729
+ FieldStruct: wp.codegen.Struct,
730
+ ValueStruct: wp.codegen.Struct,
731
+ test: LocalTestField,
732
+ ):
733
+ TAYLOR_DOF_COUNT = test.TAYLOR_DOF_COUNT
734
+
735
+ def integrate_kernel_fn(
736
+ qp_arg: quadrature.Arg,
737
+ domain_arg: domain.ElementArg,
738
+ domain_index_arg: domain.ElementIndexArg,
739
+ fields: FieldStruct,
740
+ values: ValueStruct,
741
+ result: wp.array3d(dtype=float),
742
+ ):
743
+ domain_element_index, taylor_dof, test_dof = wp.tid()
744
+ element_index = domain.element_index(domain_index_arg, domain_element_index)
745
+
746
+ trial_dof_index = NULL_DOF_INDEX
747
+ test_dof_offset = test_dof * TAYLOR_DOF_COUNT
748
+
749
+ qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
750
+ for qp in range(qp_point_count):
751
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
752
+ qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
753
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
754
+
755
+ vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
756
+
757
+ test_dof_index = DofIndex(qp_index, test_dof_offset + taylor_dof)
758
+
759
+ sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
760
+ val = integrand_func(sample, fields, values)
761
+ result[qp_index, taylor_dof, test_dof] = qp_weight * vol * val
575
762
 
576
763
  return integrate_kernel_fn
577
764
 
@@ -747,62 +934,89 @@ def get_integrate_bilinear_nodal_kernel(
747
934
  return integrate_kernel_fn
748
935
 
749
936
 
937
+ def get_integrate_bilinear_local_kernel(
938
+ integrand_func: wp.Function,
939
+ domain: GeometryDomain,
940
+ quadrature: Quadrature,
941
+ FieldStruct: wp.codegen.Struct,
942
+ ValueStruct: wp.codegen.Struct,
943
+ test: LocalTestField,
944
+ trial: LocalTrialField,
945
+ ):
946
+ TEST_TAYLOR_DOF_COUNT = test.TAYLOR_DOF_COUNT
947
+ TRIAL_TAYLOR_DOF_COUNT = trial.TAYLOR_DOF_COUNT
948
+
949
+ def integrate_kernel_fn(
950
+ qp_arg: quadrature.Arg,
951
+ domain_arg: domain.ElementArg,
952
+ domain_index_arg: domain.ElementIndexArg,
953
+ fields: FieldStruct,
954
+ values: ValueStruct,
955
+ result: wp.array4d(dtype=float),
956
+ ):
957
+ domain_element_index, test_dof, trial_dof, trial_taylor_dof = wp.tid()
958
+ element_index = domain.element_index(domain_index_arg, domain_element_index)
959
+
960
+ test_dof_offset = TEST_TAYLOR_DOF_COUNT * test_dof
961
+ trial_dof_offset = TRIAL_TAYLOR_DOF_COUNT * trial_dof
962
+
963
+ qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
964
+ for k in range(qp_point_count):
965
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
966
+ qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
967
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
968
+
969
+ vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
970
+ qp_vol = vol * qp_weight
971
+
972
+ for test_taylor_dof in range(TEST_TAYLOR_DOF_COUNT):
973
+ taylor_dof = test_taylor_dof * TRIAL_TAYLOR_DOF_COUNT + trial_taylor_dof
974
+
975
+ test_dof_index = DofIndex(qp_index, test_dof_offset + test_taylor_dof)
976
+ trial_dof_index = DofIndex(qp_index, trial_dof_offset + trial_taylor_dof)
977
+
978
+ sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
979
+ val = integrand_func(sample, fields, values)
980
+ result[qp_index, test_dof, trial_dof, taylor_dof] = qp_vol * val
981
+
982
+ return integrate_kernel_fn
983
+
984
+
750
985
  def _generate_integrate_kernel(
751
986
  integrand: Integrand,
752
987
  domain: GeometryDomain,
753
- nodal: bool,
754
988
  quadrature: Quadrature,
989
+ arguments: IntegrandArguments,
755
990
  test: Optional[TestField],
756
- test_name: str,
757
991
  trial: Optional[TrialField],
758
- trial_name: str,
759
- fields: Dict[str, FieldLike],
760
992
  output_dtype: type,
761
993
  accumulate_dtype: type,
762
994
  kernel_options: Optional[Dict[str, Any]] = None,
763
995
  ) -> wp.Kernel:
764
- if kernel_options is None:
765
- kernel_options = {}
766
-
767
996
  output_dtype = wp.types.type_scalar_type(output_dtype)
768
997
 
769
- # Extract field arguments from integrand
770
- field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
771
- integrand, fields=fields, domain=domain
772
- )
998
+ FieldStruct = _gen_field_struct(arguments.field_args)
999
+ ValueStruct = cache.get_argument_struct(arguments.value_args)
773
1000
 
774
- FieldStruct = _gen_field_struct(field_args)
775
- ValueStruct = cache.get_argument_struct(value_args)
1001
+ _notify_operator_usage(integrand, arguments.field_args)
776
1002
 
777
1003
  # Check if kernel exist in cache
778
- kernel_suffix = f"_itg_{wp.types.type_typestr(output_dtype)}{wp.types.type_typestr(accumulate_dtype)}_{domain.name}_{FieldStruct.key}"
779
- if nodal:
780
- kernel_suffix += "_nodal"
781
- else:
782
- kernel_suffix += quadrature.name
1004
+ field_names = "_".join(f"{k}{f.name}" for k, f in arguments.field_args.items())
1005
+ kernel_suffix = f"_itg_{wp.types.type_typestr(output_dtype)}{wp.types.type_typestr(accumulate_dtype)}_{field_names}"
783
1006
 
784
- if test:
785
- kernel_suffix += f"_test_{test.space_partition.name}_{test.space.name}"
786
- if trial:
787
- kernel_suffix += f"_trial_{trial.space_partition.name}_{trial.space.name}"
1007
+ if quadrature is not None:
1008
+ kernel_suffix += quadrature.name
788
1009
 
789
- kernel = cache.get_integrand_kernel(
790
- integrand=integrand,
791
- suffix=kernel_suffix,
792
- )
1010
+ kernel = cache.get_integrand_kernel(integrand=integrand, suffix=kernel_suffix, kernel_options=kernel_options)
793
1011
  if kernel is not None:
794
1012
  return kernel, FieldStruct, ValueStruct
795
1013
 
796
- # Not found in cache, transform integrand and generate kernel
1014
+ # Not found in cache, transform integrand and generate kernel
1015
+ _check_field_compat(integrand, arguments, domain)
797
1016
 
798
- _check_field_compat(integrand, fields, field_args, domain)
1017
+ integrand_func = IntegrandTransformer.apply(integrand, arguments.field_args)
799
1018
 
800
- integrand_func = _translate_integrand(
801
- integrand,
802
- field_args,
803
- )
804
-
805
- _register_integrand_field_wrappers(integrand_func, fields)
1019
+ nodal = quadrature is None
806
1020
 
807
1021
  if test is None and trial is None:
808
1022
  integrate_kernel_fn = get_integrate_constant_kernel(
@@ -824,6 +1038,15 @@ def _generate_integrate_kernel(
824
1038
  output_dtype=output_dtype,
825
1039
  accumulate_dtype=accumulate_dtype,
826
1040
  )
1041
+ elif isinstance(test, LocalTestField):
1042
+ integrate_kernel_fn = get_integrate_linear_local_kernel(
1043
+ integrand_func,
1044
+ domain,
1045
+ quadrature,
1046
+ FieldStruct,
1047
+ ValueStruct,
1048
+ test=test,
1049
+ )
827
1050
  else:
828
1051
  integrate_kernel_fn = get_integrate_linear_kernel(
829
1052
  integrand_func,
@@ -846,6 +1069,16 @@ def _generate_integrate_kernel(
846
1069
  output_dtype=output_dtype,
847
1070
  accumulate_dtype=accumulate_dtype,
848
1071
  )
1072
+ elif isinstance(test, LocalTestField):
1073
+ integrate_kernel_fn = get_integrate_bilinear_local_kernel(
1074
+ integrand_func,
1075
+ domain,
1076
+ quadrature,
1077
+ FieldStruct,
1078
+ ValueStruct,
1079
+ test=test,
1080
+ trial=trial,
1081
+ )
849
1082
  else:
850
1083
  integrate_kernel_fn = get_integrate_bilinear_kernel(
851
1084
  integrand_func,
@@ -866,13 +1099,7 @@ def _generate_integrate_kernel(
866
1099
  kernel_options=kernel_options,
867
1100
  code_transformers=[
868
1101
  PassFieldArgsToIntegrand(
869
- arg_names=integrand.argspec.args,
870
- field_args=field_args.keys(),
871
- value_args=value_args.keys(),
872
- sample_name=sample_name,
873
- domain_name=domain_name,
874
- test_name=test_name,
875
- trial_name=trial_name,
1102
+ arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
876
1103
  )
877
1104
  ],
878
1105
  )
@@ -886,7 +1113,6 @@ def _launch_integrate_kernel(
886
1113
  FieldStruct: wp.codegen.Struct,
887
1114
  ValueStruct: wp.codegen.Struct,
888
1115
  domain: GeometryDomain,
889
- nodal: bool,
890
1116
  quadrature: Quadrature,
891
1117
  test: Optional[TestField],
892
1118
  trial: Optional[TrialField],
@@ -896,6 +1122,7 @@ def _launch_integrate_kernel(
896
1122
  temporary_store: Optional[cache.TemporaryStore],
897
1123
  output_dtype: type,
898
1124
  output: Optional[Union[wp.array, BsrMatrix]],
1125
+ add_to_output: bool,
899
1126
  device,
900
1127
  ):
901
1128
  # Set-up launch arguments
@@ -907,7 +1134,8 @@ def _launch_integrate_kernel(
907
1134
 
908
1135
  field_arg_values = FieldStruct()
909
1136
  for k, v in fields.items():
910
- setattr(field_arg_values, k, v.eval_arg_value(device=device))
1137
+ if not isinstance(v, GeometryDomain):
1138
+ setattr(field_arg_values, k, v.eval_arg_value(device=device))
911
1139
 
912
1140
  value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
913
1141
 
@@ -927,7 +1155,9 @@ def _launch_integrate_kernel(
927
1155
  )
928
1156
  accumulate_array = accumulate_temporary.array
929
1157
 
930
- accumulate_array.zero_()
1158
+ if output != accumulate_array or not add_to_output:
1159
+ accumulate_array.zero_()
1160
+
931
1161
  wp.launch(
932
1162
  kernel=kernel,
933
1163
  dim=domain.element_count(),
@@ -944,26 +1174,31 @@ def _launch_integrate_kernel(
944
1174
 
945
1175
  if output == accumulate_array:
946
1176
  return output
947
- elif output is None:
1177
+ if output is None:
948
1178
  return accumulate_array.numpy()[0]
1179
+
1180
+ if add_to_output:
1181
+ # accumulate dtype is distinct from output dtype
1182
+ array_axpy(x=accumulate_array, y=output)
949
1183
  else:
950
1184
  array_cast(in_array=accumulate_array, out_array=output)
951
- return output
1185
+ return output
952
1186
 
953
1187
  test_arg = test.space_restriction.node_arg(device=device)
1188
+ nodal = quadrature is None
954
1189
 
955
1190
  # Linear form
956
1191
  if trial is None:
957
1192
  # If an output array is provided with the correct type, accumulate directly into it
958
1193
  # Otherwise, grab a temporary array
959
1194
  if output is None:
960
- if type_length(output_dtype) == test.space.VALUE_DOF_COUNT:
1195
+ if type_length(output_dtype) == test.node_dof_count:
961
1196
  output_shape = (test.space_partition.node_count(),)
962
1197
  elif type_length(output_dtype) == 1:
963
- output_shape = (test.space_partition.node_count(), test.space.VALUE_DOF_COUNT)
1198
+ output_shape = (test.space_partition.node_count(), test.node_dof_count)
964
1199
  else:
965
1200
  raise RuntimeError(
966
- f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
1201
+ f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
967
1202
  )
968
1203
 
969
1204
  output_temporary = cache.borrow_temporary(
@@ -982,18 +1217,19 @@ def _launch_integrate_kernel(
982
1217
  raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")
983
1218
 
984
1219
  output_dtype = output.dtype
985
- if type_length(output_dtype) != test.space.VALUE_DOF_COUNT:
1220
+ if type_length(output_dtype) != test.node_dof_count:
986
1221
  if type_length(output_dtype) != 1:
987
1222
  raise RuntimeError(
988
- f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
1223
+ f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.node_dof_count}"
989
1224
  )
990
- if output.ndim != 2 and output.shape[1] != test.space.VALUE_DOF_COUNT:
1225
+ if output.ndim != 2 and output.shape[1] != test.node_dof_count:
991
1226
  raise RuntimeError(
992
- f"Incompatible output array shape, last dimension must be of size {test.space.VALUE_DOF_COUNT}"
1227
+ f"Incompatible output array shape, last dimension must be of size {test.node_dof_count}"
993
1228
  )
994
1229
 
995
1230
  # Launch the integration on the kernel on a 2d scalar view of the actual array
996
- output.zero_()
1231
+ if not add_to_output:
1232
+ output.zero_()
997
1233
 
998
1234
  def as_2d_array(array):
999
1235
  return wp.array(
@@ -1001,7 +1237,7 @@ def _launch_integrate_kernel(
1001
1237
  ptr=array.ptr,
1002
1238
  capacity=array.capacity,
1003
1239
  device=array.device,
1004
- shape=(test.space_partition.node_count(), test.space.VALUE_DOF_COUNT),
1240
+ shape=(test.space_partition.node_count(), test.node_dof_count),
1005
1241
  dtype=wp.types.type_scalar_type(output_dtype),
1006
1242
  grad=None if array.grad is None else as_2d_array(array.grad),
1007
1243
  )
@@ -1011,7 +1247,7 @@ def _launch_integrate_kernel(
1011
1247
  if nodal:
1012
1248
  wp.launch(
1013
1249
  kernel=kernel,
1014
- dim=(test.space_restriction.node_count(), test.space.VALUE_DOF_COUNT),
1250
+ dim=(test.space_restriction.node_count(), test.node_dof_count),
1015
1251
  inputs=[
1016
1252
  domain_elt_arg,
1017
1253
  domain_elt_index_arg,
@@ -1023,10 +1259,51 @@ def _launch_integrate_kernel(
1023
1259
  ],
1024
1260
  device=device,
1025
1261
  )
1262
+ elif isinstance(test, LocalTestField):
1263
+ local_result = cache.borrow_temporary(
1264
+ temporary_store=temporary_store,
1265
+ device=device,
1266
+ requires_grad=output.requires_grad,
1267
+ shape=(quadrature.total_point_count(), test.TAYLOR_DOF_COUNT, test.value_dof_count),
1268
+ dtype=float,
1269
+ )
1270
+
1271
+ wp.launch(
1272
+ kernel=kernel,
1273
+ dim=(domain.element_count(), test.TAYLOR_DOF_COUNT, test.value_dof_count),
1274
+ inputs=[
1275
+ qp_arg,
1276
+ domain_elt_arg,
1277
+ domain_elt_index_arg,
1278
+ field_arg_values,
1279
+ value_struct_values,
1280
+ local_result.array,
1281
+ ],
1282
+ device=device,
1283
+ )
1284
+
1285
+ dispatch_kernel = make_linear_dispatch_kernel(test, quadrature, accumulate_dtype)
1286
+ wp.launch(
1287
+ kernel=dispatch_kernel,
1288
+ dim=(test.space_restriction.node_count(), test.node_dof_count),
1289
+ inputs=[
1290
+ qp_arg,
1291
+ domain_elt_arg,
1292
+ domain_elt_index_arg,
1293
+ test_arg,
1294
+ test.global_field.eval_arg_value(device),
1295
+ local_result.array,
1296
+ output_view,
1297
+ ],
1298
+ device=device,
1299
+ )
1300
+
1301
+ local_result.release()
1302
+
1026
1303
  else:
1027
1304
  wp.launch(
1028
1305
  kernel=kernel,
1029
- dim=(test.space_restriction.node_count(), test.space.VALUE_DOF_COUNT),
1306
+ dim=(test.space_restriction.node_count(), test.node_dof_count),
1030
1307
  inputs=[
1031
1308
  qp_arg,
1032
1309
  domain_elt_arg,
@@ -1046,12 +1323,10 @@ def _launch_integrate_kernel(
1046
1323
 
1047
1324
  # Bilinear form
1048
1325
 
1049
- if test.space.VALUE_DOF_COUNT == 1 and trial.space.VALUE_DOF_COUNT == 1:
1326
+ if test.node_dof_count == 1 and trial.node_dof_count == 1:
1050
1327
  block_type = output_dtype
1051
1328
  else:
1052
- block_type = cache.cached_mat_type(
1053
- shape=(test.space.VALUE_DOF_COUNT, trial.space.VALUE_DOF_COUNT), dtype=output_dtype
1054
- )
1329
+ block_type = cache.cached_mat_type(shape=(test.node_dof_count, trial.node_dof_count), dtype=output_dtype)
1055
1330
 
1056
1331
  if nodal:
1057
1332
  nnz = test.space_restriction.node_count()
@@ -1064,8 +1339,8 @@ def _launch_integrate_kernel(
1064
1339
  temporary_store,
1065
1340
  shape=(
1066
1341
  nnz,
1067
- test.space.VALUE_DOF_COUNT,
1068
- trial.space.VALUE_DOF_COUNT,
1342
+ test.node_dof_count,
1343
+ trial.node_dof_count,
1069
1344
  ),
1070
1345
  dtype=output_dtype,
1071
1346
  device=device,
@@ -1093,6 +1368,75 @@ def _launch_integrate_kernel(
1093
1368
  ],
1094
1369
  device=device,
1095
1370
  )
1371
+ elif isinstance(test, LocalTestField):
1372
+ local_result = cache.borrow_temporary(
1373
+ temporary_store=temporary_store,
1374
+ device=device,
1375
+ requires_grad=False,
1376
+ shape=(
1377
+ quadrature.total_point_count(),
1378
+ test.value_dof_count,
1379
+ trial.value_dof_count,
1380
+ test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT,
1381
+ ),
1382
+ dtype=float,
1383
+ )
1384
+
1385
+ wp.launch(
1386
+ kernel=kernel,
1387
+ dim=(domain.element_count(), test.value_dof_count, trial.value_dof_count, trial.TAYLOR_DOF_COUNT),
1388
+ inputs=[
1389
+ qp_arg,
1390
+ domain_elt_arg,
1391
+ domain_elt_index_arg,
1392
+ field_arg_values,
1393
+ value_struct_values,
1394
+ local_result.array,
1395
+ ],
1396
+ device=device,
1397
+ )
1398
+
1399
+ vec_array_shape = (*local_result.array.shape[:-1], test.TAYLOR_DOF_COUNT)
1400
+ vec_array_dtype = cache.cached_vec_type(length=trial.TAYLOR_DOF_COUNT, dtype=float)
1401
+ local_result_as_vec = wp.array(
1402
+ data=None,
1403
+ ptr=local_result.array.ptr,
1404
+ capacity=local_result.array.capacity,
1405
+ device=local_result.array.device,
1406
+ shape=vec_array_shape,
1407
+ dtype=vec_array_dtype,
1408
+ )
1409
+
1410
+ dispatch_kernel = make_bilinear_dispatch_kernel(test, trial, quadrature, accumulate_dtype)
1411
+
1412
+ trial_partition_arg = trial.space_partition.partition_arg_value(device)
1413
+ trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
1414
+ wp.launch(
1415
+ kernel=dispatch_kernel,
1416
+ dim=(
1417
+ test.space_restriction.node_count(),
1418
+ test.node_dof_count,
1419
+ trial.node_dof_count,
1420
+ trial.space.topology.MAX_NODES_PER_ELEMENT,
1421
+ ),
1422
+ inputs=[
1423
+ qp_arg,
1424
+ domain_elt_arg,
1425
+ domain_elt_index_arg,
1426
+ test_arg,
1427
+ test.global_field.eval_arg_value(device),
1428
+ trial_partition_arg,
1429
+ trial_topology_arg,
1430
+ trial.global_field.eval_arg_value(device),
1431
+ local_result_as_vec,
1432
+ triplet_rows,
1433
+ triplet_cols,
1434
+ triplet_values,
1435
+ ],
1436
+ device=device,
1437
+ )
1438
+
1439
+ local_result.release()
1096
1440
 
1097
1441
  else:
1098
1442
  trial_partition_arg = trial.space_partition.partition_arg_value(device)
@@ -1102,8 +1446,8 @@ def _launch_integrate_kernel(
1102
1446
  dim=(
1103
1447
  test.space_restriction.node_count(),
1104
1448
  trial.space.topology.MAX_NODES_PER_ELEMENT,
1105
- test.space.VALUE_DOF_COUNT,
1106
- trial.space.VALUE_DOF_COUNT,
1449
+ test.node_dof_count,
1450
+ trial.node_dof_count,
1107
1451
  ),
1108
1452
  inputs=[
1109
1453
  qp_arg,
@@ -1127,24 +1471,48 @@ def _launch_integrate_kernel(
1127
1471
  f"Output matrix must have {test.space_partition.node_count()} rows and {trial.space_partition.node_count()} columns of blocks"
1128
1472
  )
1129
1473
 
1130
- else:
1131
- output = bsr_zeros(
1474
+ if output is None or add_to_output:
1475
+ bsr_result = bsr_zeros(
1132
1476
  rows_of_blocks=test.space_partition.node_count(),
1133
1477
  cols_of_blocks=trial.space_partition.node_count(),
1134
1478
  block_type=block_type,
1135
1479
  device=device,
1136
1480
  )
1481
+ else:
1482
+ bsr_result = output
1137
1483
 
1138
- bsr_set_from_triplets(output, triplet_rows, triplet_cols, triplet_values)
1484
+ bsr_set_from_triplets(bsr_result, triplet_rows, triplet_cols, triplet_values)
1139
1485
 
1140
1486
  # Do not wait for garbage collection
1141
1487
  triplet_values_temp.release()
1142
1488
  triplet_rows_temp.release()
1143
1489
  triplet_cols_temp.release()
1144
1490
 
1491
+ if add_to_output:
1492
+ output += bsr_result
1493
+ else:
1494
+ output = bsr_result
1495
+
1145
1496
  return output
1146
1497
 
1147
1498
 
1499
+ def _pick_assembly_strategy(
1500
+ assembly: Optional[str], nodal: bool, operators: Dict[str, Set[Operator]], arguments: IntegrandArguments
1501
+ ):
1502
+ if assembly is not None:
1503
+ if assembly not in ("generic", "nodal", "dispatch"):
1504
+ raise ValueError(f"Invalid assembly strategy'{assembly}'")
1505
+ return assembly
1506
+ elif nodal:
1507
+ return "nodal"
1508
+
1509
+ test_operators = operators.get(arguments.test_name, {})
1510
+ trial_operators = operators.get(arguments.trial_name, {})
1511
+ uses_at_node = at_node in test_operators or at_node in trial_operators
1512
+
1513
+ return "generic" if uses_at_node else "dispatch"
1514
+
1515
+
1148
1516
  def integrate(
1149
1517
  integrand: Integrand,
1150
1518
  domain: Optional[GeometryDomain] = None,
@@ -1158,6 +1526,8 @@ def integrate(
1158
1526
  device=None,
1159
1527
  temporary_store: Optional[cache.TemporaryStore] = None,
1160
1528
  kernel_options: Optional[Dict[str, Any]] = None,
1529
+ assembly: str = None,
1530
+ add: bool = False,
1161
1531
  ):
1162
1532
  """
1163
1533
  Integrates a constant, linear or bilinear form, and returns a scalar, array, or sparse matrix, respectively.
@@ -1166,7 +1536,7 @@ def integrate(
1166
1536
  integrand: Form to be integrated, must have :func:`integrand` decorator
1167
1537
  domain: Integration domain. If None, deduced from fields
1168
1538
  quadrature: Quadrature formula. If None, deduced from domain and fields degree.
1169
- nodal: For linear or bilinear form only, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
1539
+ nodal: Deprecated. Use the equivalent assembly="nodal" instead.
1170
1540
  fields: Discrete, test, and trial fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names.
1171
1541
  values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
1172
1542
  temporary_store: shared pool from which to allocate temporary arrays
@@ -1175,6 +1545,12 @@ def integrate(
1175
1545
  output_dtype: Scalar type for returned results in `output` is not provided. If None, defaults to `accumulate_dtype`
1176
1546
  device: Device on which to perform the integration
1177
1547
  kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)
1548
+ assembly: Specifies the strategy for assembling the integrated vector or matrix:
1549
+ - "nodal": For linear or bilinear forms, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
1550
+ - "generic": Single-pass integration and shape-function evaluation. Makes no assumption about the integrand's content, but may lead to many redundant computations.
1551
+ - "dispatch": For linear or bilinear forms, first evaluate the form at quadrature points then dispatch to nodes in a second pass. More efficient for integrands that are expensive to evaluate. Incompatible with `at_node` operator on test or trial functions.
1552
+ - `None` (default): Automatically picks a suitable assembly strategy (either "generic" or "dispatch")
1553
+ add: If True and `output` is provided, add the integration result to `output` instead of replacing its content
1178
1554
  """
1179
1555
  if fields is None:
1180
1556
  fields = {}
@@ -1182,13 +1558,22 @@ def integrate(
1182
1558
  if values is None:
1183
1559
  values = {}
1184
1560
 
1185
- if kernel_options is None:
1186
- kernel_options = {}
1187
-
1188
1561
  if not isinstance(integrand, Integrand):
1189
1562
  raise ValueError("integrand must be tagged with @warp.fem.integrand decorator")
1190
1563
 
1191
- test, test_name, trial, trial_name = _get_test_and_trial_fields(fields)
1564
+ # test, test_name, trial, trial_name = _get_test_and_trial_fields(fields)
1565
+ arguments = _parse_integrand_arguments(integrand, fields)
1566
+
1567
+ test = None
1568
+ if arguments.test_name:
1569
+ test = arguments.field_args[arguments.test_name]
1570
+ trial = None
1571
+ if arguments.trial_name:
1572
+ if test is None:
1573
+ raise ValueError("A trial field cannot be provided without a test field")
1574
+ trial = arguments.field_args[arguments.trial_name]
1575
+ if test.domain != trial.domain:
1576
+ raise ValueError("Incompatible test and trial domains")
1192
1577
 
1193
1578
  if domain is None:
1194
1579
  if quadrature is not None:
@@ -1201,7 +1586,26 @@ def integrate(
1201
1586
  if test is not None and domain != test.domain:
1202
1587
  raise NotImplementedError("Mixing integration and test domain is not supported yet")
1203
1588
 
1204
- if nodal:
1589
+ if add and output is None:
1590
+ raise ValueError("An 'output' array or matrix needs to be provided for add=True")
1591
+
1592
+ if arguments.domain_name is not None:
1593
+ arguments.field_args[arguments.domain_name] = domain
1594
+
1595
+ _find_integrand_operators(integrand, arguments.field_args)
1596
+
1597
+ assembly = _pick_assembly_strategy(assembly, nodal, arguments=arguments, operators=integrand.operators)
1598
+ # print("assembly for ", integrand.name, ":", strategy)
1599
+
1600
+ if assembly == "dispatch":
1601
+ if test is not None:
1602
+ test = LocalTestField(test)
1603
+ arguments.field_args[arguments.test_name] = test
1604
+ if trial is not None:
1605
+ trial = LocalTrialField(trial)
1606
+ arguments.field_args[arguments.trial_name] = trial
1607
+
1608
+ if assembly == "nodal":
1205
1609
  if quadrature is not None:
1206
1610
  raise ValueError("Cannot specify quadrature for nodal integration")
1207
1611
 
@@ -1234,13 +1638,10 @@ def integrate(
1234
1638
  kernel, FieldStruct, ValueStruct = _generate_integrate_kernel(
1235
1639
  integrand=integrand,
1236
1640
  domain=domain,
1237
- nodal=nodal,
1238
1641
  quadrature=quadrature,
1642
+ arguments=arguments,
1239
1643
  test=test,
1240
- test_name=test_name,
1241
1644
  trial=trial,
1242
- trial_name=trial_name,
1243
- fields=fields,
1244
1645
  accumulate_dtype=accumulate_dtype,
1245
1646
  output_dtype=output_dtype,
1246
1647
  kernel_options=kernel_options,
@@ -1252,16 +1653,16 @@ def integrate(
1252
1653
  FieldStruct=FieldStruct,
1253
1654
  ValueStruct=ValueStruct,
1254
1655
  domain=domain,
1255
- nodal=nodal,
1256
1656
  quadrature=quadrature,
1257
1657
  test=test,
1258
1658
  trial=trial,
1259
- fields=fields,
1659
+ fields=arguments.field_args,
1260
1660
  values=values,
1261
1661
  accumulate_dtype=accumulate_dtype,
1262
1662
  temporary_store=temporary_store,
1263
1663
  output_dtype=output_dtype,
1264
1664
  output=output,
1665
+ add_to_output=add,
1265
1666
  device=device,
1266
1667
  )
1267
1668
 
@@ -1340,6 +1741,30 @@ def get_interpolate_to_field_kernel(
1340
1741
  ValueStruct: wp.codegen.Struct,
1341
1742
  dest: FieldRestriction,
1342
1743
  ):
1744
+ @wp.func
1745
+ def _find_node_in_element(
1746
+ domain_arg: domain.ElementArg,
1747
+ domain_index_arg: domain.ElementIndexArg,
1748
+ dest_node_arg: dest.space_restriction.NodeArg,
1749
+ dest_eval_arg: dest.field.EvalArg,
1750
+ partition_node_index: int,
1751
+ ):
1752
+ element_beg, element_end = dest.space_restriction.node_element_range(dest_node_arg, partition_node_index)
1753
+
1754
+ for n in range(element_beg, element_end):
1755
+ node_element_index = dest.space_restriction.node_element_index(dest_node_arg, n)
1756
+ element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
1757
+ coords = dest.space.node_coords_in_element(
1758
+ domain_arg,
1759
+ dest_eval_arg.space_arg,
1760
+ element_index,
1761
+ node_element_index.node_index_in_element,
1762
+ )
1763
+ if coords[0] != OUTSIDE:
1764
+ return element_index, node_element_index.node_index_in_element
1765
+
1766
+ return NULL_ELEMENT_INDEX, NULL_NODE_INDEX
1767
+
1343
1768
  def interpolate_to_field_kernel_fn(
1344
1769
  domain_arg: domain.ElementArg,
1345
1770
  domain_index_arg: domain.ElementIndexArg,
@@ -1355,8 +1780,20 @@ def get_interpolate_to_field_kernel(
1355
1780
  )
1356
1781
 
1357
1782
  if vol_sum > 0.0:
1358
- node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
1359
- dest.field.set_node_value(dest_eval_arg, node_index, val_sum / vol_sum)
1783
+ partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
1784
+
1785
+ # Grab first element containing node; there must be at least one since vol_sum != 0
1786
+ element_index, node_index_in_element = _find_node_in_element(
1787
+ domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, partition_node_index
1788
+ )
1789
+ dest.field.set_node_value(
1790
+ domain_arg,
1791
+ dest_eval_arg,
1792
+ element_index,
1793
+ node_index_in_element,
1794
+ partition_node_index,
1795
+ val_sum / vol_sum,
1796
+ )
1360
1797
 
1361
1798
  return interpolate_to_field_kernel_fn
1362
1799
 
@@ -1470,53 +1907,43 @@ def _generate_interpolate_kernel(
1470
1907
  domain: GeometryDomain,
1471
1908
  dest: Optional[Union[FieldLike, wp.array]],
1472
1909
  quadrature: Optional[Quadrature],
1473
- fields: Dict[str, FieldLike],
1910
+ arguments: IntegrandArguments,
1474
1911
  kernel_options: Optional[Dict[str, Any]] = None,
1475
1912
  ) -> wp.Kernel:
1476
- if kernel_options is None:
1477
- kernel_options = {}
1478
-
1479
- # Extract field arguments from integrand
1480
- field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
1481
- integrand, fields=fields, domain=domain
1482
- )
1483
-
1484
1913
  # Generate field struct
1485
- integrand_func = _translate_integrand(
1486
- integrand,
1487
- field_args,
1488
- )
1489
-
1490
- _register_integrand_field_wrappers(integrand_func, fields)
1914
+ FieldStruct = _gen_field_struct(arguments.field_args)
1915
+ ValueStruct = cache.get_argument_struct(arguments.value_args)
1491
1916
 
1492
- FieldStruct = _gen_field_struct(field_args)
1493
- ValueStruct = cache.get_argument_struct(value_args)
1917
+ _notify_operator_usage(integrand, arguments.field_args)
1494
1918
 
1495
1919
  # Check if kernel exist in cache
1920
+ field_names = "_".join(f"{k}{f.name}" for k, f in arguments.field_args.items())
1496
1921
  if isinstance(dest, FieldRestriction):
1497
- kernel_suffix = (
1498
- f"_itp_{FieldStruct.key}_{dest.domain.name}_{dest.space_restriction.space_partition.name}_{dest.space.name}"
1499
- )
1922
+ kernel_suffix = f"_itp_{field_names}_{dest.domain.name}_{dest.space_restriction.space_partition.name}"
1500
1923
  else:
1501
1924
  dest_dtype = dest.dtype if dest else None
1502
1925
  type_str = wp.types.get_type_code(dest_dtype) if dest_dtype else ""
1503
1926
  if quadrature is None:
1504
- kernel_suffix = f"_itp_{FieldStruct.key}_{type_str}"
1927
+ kernel_suffix = f"_itp_{field_names}_{type_str}"
1505
1928
  else:
1506
- kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}_{type_str}"
1929
+ kernel_suffix = f"_itp_{field_names}_{quadrature.name}_{type_str}"
1507
1930
 
1508
1931
  kernel = cache.get_integrand_kernel(
1509
1932
  integrand=integrand,
1510
1933
  suffix=kernel_suffix,
1934
+ kernel_options=kernel_options,
1511
1935
  )
1512
1936
  if kernel is not None:
1513
1937
  return kernel, FieldStruct, ValueStruct
1514
1938
 
1515
- _check_field_compat(integrand, fields, field_args, domain)
1939
+ # Not found in cache, transform integrand and generate kernel
1940
+ _check_field_compat(integrand, arguments, domain)
1941
+
1942
+ integrand_func = IntegrandTransformer.apply(integrand, arguments.field_args)
1516
1943
 
1517
1944
  # Generate interpolation kernel
1518
1945
  if isinstance(dest, FieldRestriction):
1519
- # need to split into kernel + function for diffferentiability
1946
+ # need to split into kernel + function for differentiability
1520
1947
  interpolate_fn = get_interpolate_to_field_function(
1521
1948
  integrand_func,
1522
1949
  domain,
@@ -1531,11 +1958,7 @@ def _generate_interpolate_kernel(
1531
1958
  suffix=kernel_suffix,
1532
1959
  code_transformers=[
1533
1960
  PassFieldArgsToIntegrand(
1534
- arg_names=integrand.argspec.args,
1535
- field_args=field_args.keys(),
1536
- value_args=value_args.keys(),
1537
- sample_name=sample_name,
1538
- domain_name=domain_name,
1961
+ arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
1539
1962
  )
1540
1963
  ],
1541
1964
  )
@@ -1572,11 +1995,7 @@ def _generate_interpolate_kernel(
1572
1995
  kernel_options=kernel_options,
1573
1996
  code_transformers=[
1574
1997
  PassFieldArgsToIntegrand(
1575
- arg_names=integrand.argspec.args,
1576
- field_args=field_args.keys(),
1577
- value_args=value_args.keys(),
1578
- sample_name=sample_name,
1579
- domain_name=domain_name,
1998
+ arg_names=integrand.argspec.args, parsed_args=arguments, integrand_func=integrand_func
1580
1999
  )
1581
2000
  ],
1582
2001
  )
@@ -1603,7 +2022,8 @@ def _launch_interpolate_kernel(
1603
2022
 
1604
2023
  field_arg_values = FieldStruct()
1605
2024
  for k, v in fields.items():
1606
- setattr(field_arg_values, k, v.eval_arg_value(device=device))
2025
+ if not isinstance(v, GeometryDomain):
2026
+ setattr(field_arg_values, k, v.eval_arg_value(device=device))
1607
2027
 
1608
2028
  value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
1609
2029
 
@@ -1687,14 +2107,11 @@ def interpolate(
1687
2107
  if values is None:
1688
2108
  values = {}
1689
2109
 
1690
- if kernel_options is None:
1691
- kernel_options = {}
1692
-
1693
2110
  if not isinstance(integrand, Integrand):
1694
2111
  raise ValueError("integrand must be tagged with @integrand decorator")
1695
2112
 
1696
- test, _, trial, __ = _get_test_and_trial_fields(fields)
1697
- if test is not None or trial is not None:
2113
+ arguments = _parse_integrand_arguments(integrand, fields)
2114
+ if arguments.test_name or arguments.trial_name:
1698
2115
  raise ValueError("Test or Trial fields should not be used for interpolation")
1699
2116
 
1700
2117
  if isinstance(dest, DiscreteField):
@@ -1705,12 +2122,17 @@ def interpolate(
1705
2122
  elif quadrature is not None:
1706
2123
  domain = quadrature.domain
1707
2124
 
2125
+ if arguments.domain_name:
2126
+ arguments.field_args[arguments.domain_name] = domain
2127
+
2128
+ _find_integrand_operators(integrand, arguments.field_args)
2129
+
1708
2130
  kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
1709
2131
  integrand=integrand,
1710
2132
  domain=domain,
1711
2133
  dest=dest,
1712
2134
  quadrature=quadrature,
1713
- fields=fields,
2135
+ arguments=arguments,
1714
2136
  kernel_options=kernel_options,
1715
2137
  )
1716
2138
 
@@ -1723,7 +2145,7 @@ def interpolate(
1723
2145
  dest=dest,
1724
2146
  quadrature=quadrature,
1725
2147
  dim=dim,
1726
- fields=fields,
2148
+ fields=arguments.field_args,
1727
2149
  values=values,
1728
2150
  device=device,
1729
2151
  )