warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (271)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.dll +0 -0
  57. warp/bin/warp.dll +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/bin/warp-clang.so +0 -0
  257. warp/bin/warp.so +0 -0
  258. warp/fem/field/discrete_field.py +0 -80
  259. warp/fem/space/nodal_function_space.py +0 -233
  260. warp/tests/test_all.py +0 -223
  261. warp/tests/test_array_scan.py +0 -60
  262. warp/tests/test_base.py +0 -208
  263. warp/tests/test_unresolved_func.py +0 -7
  264. warp/tests/test_unresolved_symbol.py +0 -7
  265. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  266. warp_lang-1.0.0b2.dist-info/RECORD +0 -380
  267. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  268. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  269. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  270. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  271. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py CHANGED
@@ -5,13 +5,12 @@ import warp as wp
5
5
  import re
6
6
  import ast
7
7
 
8
- from warp.sparse import BsrMatrix, bsr_zeros, bsr_set_from_triplets, bsr_copy, bsr_diag
8
+ from warp.sparse import BsrMatrix, bsr_zeros, bsr_set_from_triplets, bsr_copy, bsr_assign
9
9
  from warp.types import type_length
10
10
  from warp.utils import array_cast
11
11
  from warp.codegen import get_annotations
12
12
 
13
13
  from warp.fem.domain import GeometryDomain
14
- from warp.fem.space import SpaceRestriction
15
14
  from warp.fem.field import (
16
15
  TestField,
17
16
  TrialField,
@@ -23,7 +22,7 @@ from warp.fem.field import (
23
22
  from warp.fem.quadrature import Quadrature, RegularQuadrature
24
23
  from warp.fem.operator import Operator, Integrand
25
24
  from warp.fem import cache
26
- from warp.fem.types import Domain, Field, Sample, DofIndex, NULL_DOF_INDEX, OUTSIDE
25
+ from warp.fem.types import Domain, Field, Sample, DofIndex, NULL_DOF_INDEX, OUTSIDE, make_free_sample
27
26
 
28
27
 
29
28
  def _resolve_path(func, node):
@@ -98,7 +97,7 @@ class IntegrandTransformer(ast.NodeTransformer):
98
97
  operator = arg_type.call_operator
99
98
 
100
99
  call.func = ast.Attribute(
101
- value=_path_to_ast_attribute(arg_type.__qualname__),
100
+ value=_path_to_ast_attribute(f"{arg_type.__module__}.{arg_type.__qualname__}"),
102
101
  attr="call_operator",
103
102
  ctx=ast.Load(),
104
103
  )
@@ -164,7 +163,7 @@ def _translate_integrand(integrand: Integrand, field_args: Dict[str, FieldLike])
164
163
  for arg in argspec.args:
165
164
  arg_type = argspec.annotations[arg]
166
165
  if arg_type == Field:
167
- annotations[arg] = field_args[arg].EvalArg
166
+ annotations[arg] = field_args[arg].ElementEvalArg
168
167
  elif arg_type == Domain:
169
168
  annotations[arg] = field_args[arg].ElementArg
170
169
  else:
@@ -174,11 +173,9 @@ def _translate_integrand(integrand: Integrand, field_args: Dict[str, FieldLike])
174
173
  transformer = IntegrandTransformer(integrand, field_args)
175
174
 
176
175
  def is_field_like(f):
177
- # WAR for isinstance not supporting Union in Python < 3.10
178
- return any(isinstance(f, field_class) for field_class in FieldLike.__args__)
176
+ return isinstance(f, FieldLike)
179
177
 
180
178
  suffix = "_".join([f.name for f in field_args.values() if is_field_like(f)])
181
- key = integrand.name + suffix
182
179
 
183
180
  func = cache.get_integrand_function(
184
181
  integrand=integrand,
@@ -265,18 +262,14 @@ def _gen_field_struct(field_args: Dict[str, FieldLike]):
265
262
  setattr(Fields, name, arg.EvalArg())
266
263
  annotations[name] = arg.EvalArg
267
264
 
268
- Fields.__qualname__ = (
269
- Fields.__name__
270
- + "_"
271
- + "_".join([f"{name}_{arg_struct.cls.__qualname__}" for name, arg_struct in annotations.items()])
272
- )
273
-
274
265
  try:
275
266
  Fields.__annotations__ = annotations
276
267
  except AttributeError:
277
268
  setattr(Fields.__dict__, "__annotations__", annotations)
278
269
 
279
- return cache.get_struct(Fields)
270
+ suffix = "_".join([f"{name}_{arg_struct.cls.__qualname__}" for name, arg_struct in annotations.items()])
271
+
272
+ return cache.get_struct(Fields, suffix=suffix)
280
273
 
281
274
 
282
275
  def _gen_value_struct(value_args: Dict[str, type]):
@@ -299,25 +292,34 @@ def _gen_value_struct(value_args: Dict[str, type]):
299
292
  return arg_type_name(arg_type.cls)
300
293
  return getattr(arg_type, "__name__", str(arg_type))
301
294
 
302
- Values.__qualname__ = (
303
- Values.__name__
304
- + "_"
305
- + "_".join([f"{name}_{arg_type_name(arg_type)}" for name, arg_type in annotations.items()])
306
- )
307
-
308
295
  try:
309
296
  Values.__annotations__ = annotations
310
297
  except AttributeError:
311
298
  setattr(Values.__dict__, "__annotations__", annotations)
312
299
 
313
- return cache.get_struct(Values)
300
+ suffix = "_".join([f"{name}_{arg_type_name(arg_type)}" for name, arg_type in annotations.items()])
301
+
302
+ return cache.get_struct(Values, suffix=suffix)
314
303
 
315
304
 
316
305
  def _get_trial_arg():
317
306
  pass
318
307
 
308
+
319
309
  def _get_test_arg():
320
310
  pass
311
+
312
+
313
+ class _FieldWrappers:
314
+ pass
315
+
316
+
317
+ def _register_integrand_field_wrappers(integrand_func: wp.Function, fields: Dict[str, FieldLike]):
318
+ integrand_func._field_wrappers = _FieldWrappers()
319
+ for name, field in fields.items():
320
+ setattr(integrand_func._field_wrappers, name, field.ElementEvalArg)
321
+
322
+
321
323
  class PassFieldArgsToIntegrand(ast.NodeTransformer):
322
324
  def __init__(
323
325
  self,
@@ -333,6 +335,7 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
333
335
  values_var_name: str = "values",
334
336
  domain_var_name: str = "domain_arg",
335
337
  sample_var_name: str = "sample",
338
+ field_wrappers_attr: str = "_field_wrappers",
336
339
  ):
337
340
  self._arg_names = arg_names
338
341
  self._field_args = field_args
@@ -346,6 +349,7 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
346
349
  self._values_var_name = values_var_name
347
350
  self._domain_var_name = domain_var_name
348
351
  self._sample_var_name = sample_var_name
352
+ self._field_wrappers_attr = field_wrappers_attr
349
353
 
350
354
  def visit_Call(self, call: ast.Call):
351
355
  call = self.generic_visit(call)
@@ -366,10 +370,25 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
366
370
  )
367
371
  elif arg in self._field_args:
368
372
  call.args.append(
369
- ast.Attribute(
370
- value=ast.Name(id=self._fields_var_name, ctx=ast.Load()),
371
- attr=arg,
372
- ctx=ast.Load(),
373
+ ast.Call(
374
+ func=ast.Attribute(
375
+ value=ast.Attribute(
376
+ value=ast.Name(id=self._func_name, ctx=ast.Load()),
377
+ attr=self._field_wrappers_attr,
378
+ ctx=ast.Load(),
379
+ ),
380
+ attr=arg,
381
+ ctx=ast.Load(),
382
+ ),
383
+ args=[
384
+ ast.Name(id=self._domain_var_name, ctx=ast.Load()),
385
+ ast.Attribute(
386
+ value=ast.Name(id=self._fields_var_name, ctx=ast.Load()),
387
+ attr=arg,
388
+ ctx=ast.Load(),
389
+ ),
390
+ ],
391
+ keywords=[],
373
392
  )
374
393
  )
375
394
  elif arg in self._value_args:
@@ -401,36 +420,6 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
401
420
  return call
402
421
 
403
422
 
404
- def get_integrate_null_kernel(
405
- integrand_func: wp.Function,
406
- domain: GeometryDomain,
407
- quadrature: Quadrature,
408
- FieldStruct: wp.codegen.Struct,
409
- ValueStruct: wp.codegen.Struct,
410
- ):
411
- def integrate_kernel_fn(
412
- qp_arg: quadrature.Arg,
413
- domain_arg: domain.ElementArg,
414
- domain_index_arg: domain.ElementIndexArg,
415
- fields: FieldStruct,
416
- values: ValueStruct,
417
- ):
418
- element_index = domain.element_index(domain_index_arg, wp.tid())
419
-
420
- test_dof_index = NULL_DOF_INDEX
421
- trial_dof_index = NULL_DOF_INDEX
422
-
423
- qp_point_count = quadrature.point_count(qp_arg, element_index)
424
- for k in range(qp_point_count):
425
- qp_index = quadrature.point_index(qp_arg, element_index, k)
426
- qp_coords = quadrature.point_coords(qp_arg, element_index, k)
427
- qp_weight = quadrature.point_weight(qp_arg, element_index, k)
428
- sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
429
- integrand_func(sample, fields, values)
430
-
431
- return integrate_kernel_fn
432
-
433
-
434
423
  def get_integrate_constant_kernel(
435
424
  integrand_func: wp.Function,
436
425
  domain: GeometryDomain,
@@ -453,14 +442,15 @@ def get_integrate_constant_kernel(
453
442
  test_dof_index = NULL_DOF_INDEX
454
443
  trial_dof_index = NULL_DOF_INDEX
455
444
 
456
- qp_point_count = quadrature.point_count(qp_arg, element_index)
445
+ qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
457
446
  for k in range(qp_point_count):
458
- qp_index = quadrature.point_index(qp_arg, element_index, k)
459
- coords = quadrature.point_coords(qp_arg, element_index, k)
460
- qp_weight = quadrature.point_weight(qp_arg, element_index, k)
461
- vol = domain.element_measure(domain_arg, element_index, coords)
447
+ qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
448
+ coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
449
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
462
450
 
463
451
  sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
452
+ vol = domain.element_measure(domain_arg, sample)
453
+
464
454
  val = integrand_func(sample, fields, values)
465
455
 
466
456
  elem_sum += accumulate_dtype(qp_weight * vol * val)
@@ -476,42 +466,47 @@ def get_integrate_linear_kernel(
476
466
  quadrature: Quadrature,
477
467
  FieldStruct: wp.codegen.Struct,
478
468
  ValueStruct: wp.codegen.Struct,
479
- test_space: SpaceRestriction,
469
+ test: TestField,
470
+ output_dtype,
480
471
  accumulate_dtype,
481
472
  ):
482
473
  def integrate_kernel_fn(
483
474
  qp_arg: quadrature.Arg,
484
475
  domain_arg: domain.ElementArg,
485
476
  domain_index_arg: domain.ElementIndexArg,
486
- test_arg: test_space.NodeArg,
477
+ test_arg: test.space_restriction.NodeArg,
487
478
  fields: FieldStruct,
488
479
  values: ValueStruct,
489
- result: wp.array2d(dtype=accumulate_dtype),
480
+ result: wp.array2d(dtype=output_dtype),
490
481
  ):
491
- local_node_index = wp.tid()
492
- node_index = test_space.node_partition_index(test_arg, local_node_index)
493
- element_count = test_space.node_element_count(test_arg, local_node_index)
482
+ local_node_index, test_dof = wp.tid()
483
+ node_index = test.space_restriction.node_partition_index(test_arg, local_node_index)
484
+ element_count = test.space_restriction.node_element_count(test_arg, local_node_index)
494
485
 
495
486
  trial_dof_index = NULL_DOF_INDEX
496
487
 
488
+ val_sum = accumulate_dtype(0.0)
489
+
497
490
  for n in range(element_count):
498
- node_element_index = test_space.node_element_index(test_arg, local_node_index, n)
491
+ node_element_index = test.space_restriction.node_element_index(test_arg, local_node_index, n)
499
492
  element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
500
493
 
501
- qp_point_count = quadrature.point_count(qp_arg, element_index)
494
+ test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)
495
+
496
+ qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
502
497
  for k in range(qp_point_count):
503
- qp_index = quadrature.point_index(qp_arg, element_index, k)
504
- coords = quadrature.point_coords(qp_arg, element_index, k)
498
+ qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
499
+ qp_coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
500
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
505
501
 
506
- qp_weight = quadrature.point_weight(qp_arg, element_index, k)
507
- vol = domain.element_measure(domain_arg, element_index, coords)
502
+ vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
508
503
 
509
- for i in range(test_space.space.VALUE_DOF_COUNT):
510
- test_dof_index = DofIndex(node_element_index.node_index_in_element, i)
511
- sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
512
- val = integrand_func(sample, fields, values)
504
+ sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
505
+ val = integrand_func(sample, fields, values)
506
+
507
+ val_sum += accumulate_dtype(qp_weight * vol * val)
513
508
 
514
- result[node_index, i] = result[node_index, i] + accumulate_dtype(qp_weight * vol * val)
509
+ result[node_index, test_dof] = output_dtype(val_sum)
515
510
 
516
511
  return integrate_kernel_fn
517
512
 
@@ -522,6 +517,7 @@ def get_integrate_linear_nodal_kernel(
522
517
  FieldStruct: wp.codegen.Struct,
523
518
  ValueStruct: wp.codegen.Struct,
524
519
  test: TestField,
520
+ output_dtype,
525
521
  accumulate_dtype,
526
522
  ):
527
523
  def integrate_kernel_fn(
@@ -530,7 +526,7 @@ def get_integrate_linear_nodal_kernel(
530
526
  test_restriction_arg: test.space_restriction.NodeArg,
531
527
  fields: FieldStruct,
532
528
  values: ValueStruct,
533
- result: wp.array2d(dtype=accumulate_dtype),
529
+ result: wp.array2d(dtype=output_dtype),
534
530
  ):
535
531
  local_node_index, dof = wp.tid()
536
532
 
@@ -546,6 +542,7 @@ def get_integrate_linear_nodal_kernel(
546
542
  element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
547
543
 
548
544
  coords = test.space.node_coords_in_element(
545
+ domain_arg,
549
546
  _get_test_arg(),
550
547
  element_index,
551
548
  node_element_index.node_index_in_element,
@@ -553,12 +550,12 @@ def get_integrate_linear_nodal_kernel(
553
550
 
554
551
  if coords[0] != OUTSIDE:
555
552
  node_weight = test.space.node_quadrature_weight(
553
+ domain_arg,
556
554
  _get_test_arg(),
557
555
  element_index,
558
556
  node_element_index.node_index_in_element,
559
557
  )
560
558
 
561
- vol = domain.element_measure(domain_arg, element_index, coords)
562
559
  test_dof_index = DofIndex(node_element_index.node_index_in_element, dof)
563
560
 
564
561
  sample = Sample(
@@ -569,11 +566,12 @@ def get_integrate_linear_nodal_kernel(
569
566
  test_dof_index,
570
567
  trial_dof_index,
571
568
  )
569
+ vol = domain.element_measure(domain_arg, sample)
572
570
  val = integrand_func(sample, fields, values)
573
571
 
574
572
  val_sum += accumulate_dtype(node_weight * vol * val)
575
573
 
576
- result[node_index, dof] = val_sum
574
+ result[node_index, dof] = output_dtype(val_sum)
577
575
 
578
576
  return integrate_kernel_fn
579
577
 
@@ -584,80 +582,75 @@ def get_integrate_bilinear_kernel(
584
582
  quadrature: Quadrature,
585
583
  FieldStruct: wp.codegen.Struct,
586
584
  ValueStruct: wp.codegen.Struct,
587
- test_space: SpaceRestriction,
585
+ test: TestField,
588
586
  trial: TrialField,
587
+ output_dtype,
589
588
  accumulate_dtype,
590
589
  ):
591
- NODES_PER_ELEMENT = trial.space.NODES_PER_ELEMENT
590
+ NODES_PER_ELEMENT = trial.space.topology.NODES_PER_ELEMENT
592
591
 
593
592
  def integrate_kernel_fn(
594
593
  qp_arg: quadrature.Arg,
595
594
  domain_arg: domain.ElementArg,
596
595
  domain_index_arg: domain.ElementIndexArg,
597
- test_arg: test_space.NodeArg,
596
+ test_arg: test.space_restriction.NodeArg,
598
597
  trial_partition_arg: trial.space_partition.PartitionArg,
598
+ trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
599
599
  fields: FieldStruct,
600
600
  values: ValueStruct,
601
601
  row_offsets: wp.array(dtype=int),
602
602
  triplet_rows: wp.array(dtype=int),
603
603
  triplet_cols: wp.array(dtype=int),
604
- triplet_values: wp.array3d(dtype=accumulate_dtype),
604
+ triplet_values: wp.array3d(dtype=output_dtype),
605
605
  ):
606
- test_local_node_index = wp.tid()
606
+ test_local_node_index, trial_node, test_dof, trial_dof = wp.tid()
607
+
608
+ element_count = test.space_restriction.node_element_count(test_arg, test_local_node_index)
609
+ test_node_index = test.space_restriction.node_partition_index(test_arg, test_local_node_index)
607
610
 
608
- element_count = test_space.node_element_count(test_arg, test_local_node_index)
609
- test_node_index = test_space.node_partition_index(test_arg, test_local_node_index)
611
+ trial_dof_index = DofIndex(trial_node, trial_dof)
610
612
 
611
613
  for element in range(element_count):
612
- test_element_index = test_space.node_element_index(test_arg, test_local_node_index, element)
614
+ test_element_index = test.space_restriction.node_element_index(test_arg, test_local_node_index, element)
613
615
  element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
614
- qp_point_count = quadrature.point_count(qp_arg, element_index)
616
+ qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
615
617
 
616
- start_offset = (row_offsets[test_node_index] + element) * NODES_PER_ELEMENT
618
+ test_dof_index = DofIndex(
619
+ test_element_index.node_index_in_element,
620
+ test_dof,
621
+ )
622
+
623
+ val_sum = accumulate_dtype(0.0)
617
624
 
618
625
  for k in range(qp_point_count):
619
- qp_index = quadrature.point_index(qp_arg, element_index, k)
620
- coords = quadrature.point_coords(qp_arg, element_index, k)
621
-
622
- qp_weight = quadrature.point_weight(qp_arg, element_index, k)
623
- vol = domain.element_measure(domain_arg, element_index, coords)
624
-
625
- offset_cur = start_offset
626
-
627
- for trial_n in range(NODES_PER_ELEMENT):
628
- for i in range(test_space.space.VALUE_DOF_COUNT):
629
- for j in range(trial.space.VALUE_DOF_COUNT):
630
- test_dof_index = DofIndex(
631
- test_element_index.node_index_in_element,
632
- i,
633
- )
634
- trial_dof_index = DofIndex(trial_n, j)
635
- sample = Sample(
636
- element_index,
637
- coords,
638
- qp_index,
639
- qp_weight,
640
- test_dof_index,
641
- trial_dof_index,
642
- )
643
- val = integrand_func(sample, fields, values)
644
- triplet_values[offset_cur, i, j] = triplet_values[offset_cur, i, j] + accumulate_dtype(
645
- qp_weight * vol * val
646
- )
647
-
648
- offset_cur += 1
649
-
650
- # Set column indices
651
- offset_cur = start_offset
652
- for trial_n in range(NODES_PER_ELEMENT):
626
+ qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
627
+ coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
628
+
629
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
630
+ vol = domain.element_measure(domain_arg, make_free_sample(element_index, coords))
631
+
632
+ sample = Sample(
633
+ element_index,
634
+ coords,
635
+ qp_index,
636
+ qp_weight,
637
+ test_dof_index,
638
+ trial_dof_index,
639
+ )
640
+ val = integrand_func(sample, fields, values)
641
+ val_sum += accumulate_dtype(qp_weight * vol * val)
642
+
643
+ block_offset = (row_offsets[test_node_index] + element) * NODES_PER_ELEMENT + trial_node
644
+ triplet_values[block_offset, test_dof, trial_dof] = output_dtype(val_sum)
645
+
646
+ # Set row and column indices
647
+ if test_dof == 0 and trial_dof == 0:
653
648
  trial_node_index = trial.space_partition.partition_node_index(
654
649
  trial_partition_arg,
655
- trial.space.element_node_index(_get_trial_arg(), element_index, trial_n),
650
+ trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
656
651
  )
657
-
658
- triplet_rows[offset_cur] = test_node_index
659
- triplet_cols[offset_cur] = trial_node_index
660
- offset_cur += 1
652
+ triplet_rows[block_offset] = test_node_index
653
+ triplet_cols[block_offset] = trial_node_index
661
654
 
662
655
  return integrate_kernel_fn
663
656
 
@@ -668,6 +661,7 @@ def get_integrate_bilinear_nodal_kernel(
668
661
  FieldStruct: wp.codegen.Struct,
669
662
  ValueStruct: wp.codegen.Struct,
670
663
  test: TestField,
664
+ output_dtype,
671
665
  accumulate_dtype,
672
666
  ):
673
667
  def integrate_kernel_fn(
@@ -678,7 +672,7 @@ def get_integrate_bilinear_nodal_kernel(
678
672
  values: ValueStruct,
679
673
  triplet_rows: wp.array(dtype=int),
680
674
  triplet_cols: wp.array(dtype=int),
681
- triplet_values: wp.array3d(dtype=accumulate_dtype),
675
+ triplet_values: wp.array3d(dtype=output_dtype),
682
676
  ):
683
677
  local_node_index, test_dof, trial_dof = wp.tid()
684
678
 
@@ -692,6 +686,7 @@ def get_integrate_bilinear_nodal_kernel(
692
686
  element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
693
687
 
694
688
  coords = test.space.node_coords_in_element(
689
+ domain_arg,
695
690
  _get_test_arg(),
696
691
  element_index,
697
692
  node_element_index.node_index_in_element,
@@ -699,13 +694,12 @@ def get_integrate_bilinear_nodal_kernel(
699
694
 
700
695
  if coords[0] != OUTSIDE:
701
696
  node_weight = test.space.node_quadrature_weight(
697
+ domain_arg,
702
698
  _get_test_arg(),
703
699
  element_index,
704
700
  node_element_index.node_index_in_element,
705
701
  )
706
702
 
707
- vol = domain.element_measure(domain_arg, element_index, coords)
708
-
709
703
  test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)
710
704
  trial_dof_index = DofIndex(node_element_index.node_index_in_element, trial_dof)
711
705
 
@@ -717,11 +711,12 @@ def get_integrate_bilinear_nodal_kernel(
717
711
  test_dof_index,
718
712
  trial_dof_index,
719
713
  )
714
+ vol = domain.element_measure(domain_arg, sample)
720
715
  val = integrand_func(sample, fields, values)
721
716
 
722
717
  val_sum += accumulate_dtype(node_weight * vol * val)
723
718
 
724
- triplet_values[local_node_index, test_dof, trial_dof] = val_sum
719
+ triplet_values[local_node_index, test_dof, trial_dof] = output_dtype(val_sum)
725
720
  triplet_rows[local_node_index] = node_index
726
721
  triplet_cols[local_node_index] = node_index
727
722
 
@@ -738,8 +733,12 @@ def _generate_integrate_kernel(
738
733
  trial: Optional[TrialField],
739
734
  trial_name: str,
740
735
  fields: Dict[str, FieldLike],
736
+ output_dtype: type,
741
737
  accumulate_dtype: type,
738
+ kernel_options: Dict[str, Any] = {},
742
739
  ) -> wp.Kernel:
740
+ output_dtype = wp.types.type_scalar_type(output_dtype)
741
+
743
742
  # Extract field arguments from integrand
744
743
  field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
745
744
  integrand, fields=fields, domain=domain
@@ -749,7 +748,7 @@ def _generate_integrate_kernel(
749
748
  ValueStruct = _gen_value_struct(value_args)
750
749
 
751
750
  # Check if kernel exist in cache
752
- kernel_suffix = f"_itg_{domain.name}_{FieldStruct.key}"
751
+ kernel_suffix = f"_itg_{wp.types.type_typestr(output_dtype)}{wp.types.type_typestr(accumulate_dtype)}_{domain.name}_{FieldStruct.key}"
753
752
  if nodal:
754
753
  kernel_suffix += "_nodal"
755
754
  else:
@@ -774,6 +773,8 @@ def _generate_integrate_kernel(
774
773
  field_args,
775
774
  )
776
775
 
776
+ _register_integrand_field_wrappers(integrand_func, fields)
777
+
777
778
  if test is None and trial is None:
778
779
  integrate_kernel_fn = get_integrate_constant_kernel(
779
780
  integrand_func,
@@ -791,6 +792,7 @@ def _generate_integrate_kernel(
791
792
  FieldStruct,
792
793
  ValueStruct,
793
794
  test=test,
795
+ output_dtype=output_dtype,
794
796
  accumulate_dtype=accumulate_dtype,
795
797
  )
796
798
  else:
@@ -800,7 +802,8 @@ def _generate_integrate_kernel(
800
802
  quadrature,
801
803
  FieldStruct,
802
804
  ValueStruct,
803
- test_space=test.space_restriction,
805
+ test=test,
806
+ output_dtype=output_dtype,
804
807
  accumulate_dtype=accumulate_dtype,
805
808
  )
806
809
  else:
@@ -811,6 +814,7 @@ def _generate_integrate_kernel(
811
814
  FieldStruct,
812
815
  ValueStruct,
813
816
  test=test,
817
+ output_dtype=output_dtype,
814
818
  accumulate_dtype=accumulate_dtype,
815
819
  )
816
820
  else:
@@ -820,8 +824,9 @@ def _generate_integrate_kernel(
820
824
  quadrature,
821
825
  FieldStruct,
822
826
  ValueStruct,
823
- test_space=test.space_restriction,
827
+ test=test,
824
828
  trial=trial,
829
+ output_dtype=output_dtype,
825
830
  accumulate_dtype=accumulate_dtype,
826
831
  )
827
832
 
@@ -829,6 +834,7 @@ def _generate_integrate_kernel(
829
834
  integrand=integrand,
830
835
  kernel_fn=integrate_kernel_fn,
831
836
  suffix=kernel_suffix,
837
+ kernel_options=kernel_options,
832
838
  code_transformers=[
833
839
  PassFieldArgsToIntegrand(
834
840
  arg_names=integrand.argspec.args,
@@ -837,7 +843,7 @@ def _generate_integrate_kernel(
837
843
  sample_name=sample_name,
838
844
  domain_name=domain_name,
839
845
  test_name=test_name,
840
- trial_name=trial_name
846
+ trial_name=trial_name,
841
847
  )
842
848
  ],
843
849
  )
@@ -846,7 +852,7 @@ def _generate_integrate_kernel(
846
852
 
847
853
 
848
854
  def _launch_integrate_kernel(
849
- kernel: wp.kernel,
855
+ kernel: wp.Kernel,
850
856
  FieldStruct: wp.codegen.Struct,
851
857
  ValueStruct: wp.codegen.Struct,
852
858
  domain: GeometryDomain,
@@ -857,16 +863,11 @@ def _launch_integrate_kernel(
857
863
  fields: Dict[str, FieldLike],
858
864
  values: Dict[str, Any],
859
865
  accumulate_dtype: type,
866
+ temporary_store: Optional[cache.TemporaryStore],
860
867
  output_dtype: type,
861
868
  output: Optional[Union[wp.array, BsrMatrix]],
862
869
  device,
863
- ) -> wp.Kernel:
864
- if output_dtype is None:
865
- if output is not None:
866
- output_dtype = output.dtype
867
- else:
868
- output_dtype = accumulate_dtype
869
-
870
+ ):
870
871
  # Set-up launch arguments
871
872
  domain_elt_arg = domain.element_arg_value(device=device)
872
873
  domain_elt_index_arg = domain.element_index_arg_value(device=device)
@@ -882,14 +883,23 @@ def _launch_integrate_kernel(
882
883
  for k, v in values.items():
883
884
  setattr(value_struct_values, k, v)
884
885
 
885
- # Constant
886
+ # Constant form
886
887
  if test is None and trial is None:
887
- if output is None or output.dtype != accumulate_dtype:
888
- result = wp.zeros(shape=(1), device=device, dtype=output_dtype)
888
+ if output is not None and output.dtype == accumulate_dtype:
889
+ if output.size < 1:
890
+ raise RuntimeError("Output array must be of size at least 1")
891
+ accumulate_array = output
889
892
  else:
890
- result = output
891
- result.zero_()
893
+ accumulate_temporary = cache.borrow_temporary(
894
+ shape=(1),
895
+ device=device,
896
+ dtype=accumulate_dtype,
897
+ temporary_store=temporary_store,
898
+ requires_grad=output is not None and output.requires_grad,
899
+ )
900
+ accumulate_array = accumulate_temporary.array
892
901
 
902
+ accumulate_array.zero_()
893
903
  wp.launch(
894
904
  kernel=kernel,
895
905
  dim=domain.element_count(),
@@ -899,43 +909,77 @@ def _launch_integrate_kernel(
899
909
  domain_elt_index_arg,
900
910
  field_arg_values,
901
911
  value_struct_values,
902
- result,
912
+ accumulate_array,
903
913
  ],
904
914
  device=device,
905
915
  )
906
916
 
907
- if output is None:
908
- return output_dtype(result.numpy()[0])
917
+ if output == accumulate_array:
918
+ return output
919
+ elif output is None:
920
+ return accumulate_array.numpy()[0]
909
921
  else:
910
- if output != result:
911
- array_cast(in_array=result, out_array=output)
922
+ array_cast(in_array=accumulate_array, out_array=output)
912
923
  return output
913
924
 
914
925
  test_arg = test.space_restriction.node_arg(device=device)
915
926
 
916
927
  # Linear form
917
928
  if trial is None:
918
- if test.space.VALUE_DOF_COUNT == 1:
919
- result_dtype = accumulate_dtype
929
+ # If an output array is provided with the correct type, accumulate directly into it
930
+ # Otherwise, grab a temporary array
931
+ if output is None:
932
+ if type_length(output_dtype) == test.space.VALUE_DOF_COUNT:
933
+ output_shape = (test.space_partition.node_count(),)
934
+ elif type_length(output_dtype) == 1:
935
+ output_shape = (test.space_partition.node_count(), test.space.VALUE_DOF_COUNT)
936
+ else:
937
+ raise RuntimeError(
938
+ f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
939
+ )
940
+
941
+ output_temporary = cache.borrow_temporary(
942
+ temporary_store=temporary_store,
943
+ shape=output_shape,
944
+ dtype=output_dtype,
945
+ device=device,
946
+ )
947
+
948
+ output = output_temporary.array
949
+
920
950
  else:
921
- result_dtype = wp.vec(length=test.space.VALUE_DOF_COUNT, dtype=accumulate_dtype)
951
+ output_temporary = None
922
952
 
923
- result_array = wp.zeros(
924
- shape=test.space_partition.node_count(),
925
- dtype=result_dtype,
926
- device=device,
927
- )
953
+ if output.shape[0] < test.space_partition.node_count():
954
+ raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")
955
+
956
+ output_dtype = output.dtype
957
+ if type_length(output_dtype) != test.space.VALUE_DOF_COUNT:
958
+ if type_length(output_dtype) != 1:
959
+ raise RuntimeError(
960
+ f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
961
+ )
962
+ if output.ndim != 2 and output.shape[1] != test.space.VALUE_DOF_COUNT:
963
+ raise RuntimeError(
964
+ f"Incompatible output array shape, last dimension must be of size {test.space.VALUE_DOF_COUNT}"
965
+ )
928
966
 
929
967
  # Launch the integration on the kernel on a 2d scalar view of the actual array
930
- result_2d_view = wp.array(
931
- data=None,
932
- ptr=result_array.ptr,
933
- capacity=result_array.capacity,
934
- owner=False,
935
- device=result_array.device,
936
- shape=(test.space_partition.node_count(), test.space.VALUE_DOF_COUNT),
937
- dtype=accumulate_dtype,
938
- )
968
+ output.zero_()
969
+
970
+ def as_2d_array(array):
971
+ return wp.array(
972
+ data=None,
973
+ ptr=array.ptr,
974
+ capacity=array.capacity,
975
+ owner=False,
976
+ device=array.device,
977
+ shape=(test.space_partition.node_count(), test.space.VALUE_DOF_COUNT),
978
+ dtype=wp.types.type_scalar_type(output_dtype),
979
+ grad=None if array.grad is None else as_2d_array(array.grad),
980
+ )
981
+
982
+ output_view = output if output.ndim == 2 else as_2d_array(output)
939
983
 
940
984
  if nodal:
941
985
  wp.launch(
@@ -947,14 +991,14 @@ def _launch_integrate_kernel(
947
991
  test_arg,
948
992
  field_arg_values,
949
993
  value_struct_values,
950
- result_2d_view,
994
+ output_view,
951
995
  ],
952
996
  device=device,
953
997
  )
954
998
  else:
955
999
  wp.launch(
956
1000
  kernel=kernel,
957
- dim=test.space_restriction.node_count(),
1001
+ dim=(test.space_restriction.node_count(), test.space.VALUE_DOF_COUNT),
958
1002
  inputs=[
959
1003
  qp_arg,
960
1004
  domain_elt_arg,
@@ -962,55 +1006,47 @@ def _launch_integrate_kernel(
962
1006
  test_arg,
963
1007
  field_arg_values,
964
1008
  value_struct_values,
965
- result_2d_view,
1009
+ output_view,
966
1010
  ],
967
1011
  device=device,
968
1012
  )
969
1013
 
970
- if output_dtype == result_array.dtype:
971
- return result_array
972
-
973
- output_type_length = type_length(output_dtype)
974
- if output_type_length == test.space.VALUE_DOF_COUNT:
975
- cast_result = wp.empty(dtype=output_dtype, shape=result_array.shape)
976
- else:
977
- cast_result = wp.empty(dtype=output_dtype, shape=result_2d_view.shape)
1014
+ if output_temporary is not None:
1015
+ return output_temporary.detach()
978
1016
 
979
- array_cast(in_array=result_array, out_array=cast_result)
980
- return cast_result
1017
+ return output
981
1018
 
982
1019
  # Bilinear form
983
1020
 
984
1021
  if test.space.VALUE_DOF_COUNT == 1 and trial.space.VALUE_DOF_COUNT == 1:
985
- block_type = accumulate_dtype
1022
+ block_type = output_dtype
986
1023
  else:
987
- block_type = wp.types.matrix(
988
- shape=(test.space.VALUE_DOF_COUNT, trial.space.VALUE_DOF_COUNT), dtype=accumulate_dtype
1024
+ block_type = cache.cached_mat_type(
1025
+ shape=(test.space.VALUE_DOF_COUNT, trial.space.VALUE_DOF_COUNT), dtype=output_dtype
989
1026
  )
990
1027
 
991
- bsr_matrix = bsr_zeros(
992
- rows_of_blocks=test.space_partition.node_count(),
993
- cols_of_blocks=trial.space_partition.node_count(),
994
- block_type=block_type,
995
- device=device,
996
- )
997
-
998
1028
  if nodal:
999
1029
  nnz = test.space_restriction.node_count()
1000
1030
  else:
1001
- nnz = test.space_restriction.total_node_element_count() * trial.space.NODES_PER_ELEMENT
1031
+ nnz = test.space_restriction.total_node_element_count() * trial.space.topology.NODES_PER_ELEMENT
1002
1032
 
1003
- triplet_rows = wp.empty(n=nnz, dtype=int, device=device)
1004
- triplet_cols = wp.empty(n=nnz, dtype=int, device=device)
1005
- triplet_values = wp.zeros(
1033
+ triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
1034
+ triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
1035
+ triplet_values_temp = cache.borrow_temporary(
1036
+ temporary_store,
1006
1037
  shape=(
1007
1038
  nnz,
1008
1039
  test.space.VALUE_DOF_COUNT,
1009
1040
  trial.space.VALUE_DOF_COUNT,
1010
1041
  ),
1011
- dtype=accumulate_dtype,
1042
+ dtype=output_dtype,
1012
1043
  device=device,
1013
1044
  )
1045
+ triplet_cols = triplet_cols_temp.array
1046
+ triplet_rows = triplet_rows_temp.array
1047
+ triplet_values = triplet_values_temp.array
1048
+
1049
+ triplet_values.zero_()
1014
1050
 
1015
1051
  if nodal:
1016
1052
  wp.launch(
@@ -1033,15 +1069,22 @@ def _launch_integrate_kernel(
1033
1069
  offsets = test.space_restriction.partition_element_offsets()
1034
1070
 
1035
1071
  trial_partition_arg = trial.space_partition.partition_arg_value(device)
1072
+ trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
1036
1073
  wp.launch(
1037
1074
  kernel=kernel,
1038
- dim=test.space_restriction.node_count(),
1075
+ dim=(
1076
+ test.space_restriction.node_count(),
1077
+ trial.space.topology.NODES_PER_ELEMENT,
1078
+ test.space.VALUE_DOF_COUNT,
1079
+ trial.space.VALUE_DOF_COUNT,
1080
+ ),
1039
1081
  inputs=[
1040
1082
  qp_arg,
1041
1083
  domain_elt_arg,
1042
1084
  domain_elt_index_arg,
1043
1085
  test_arg,
1044
1086
  trial_partition_arg,
1087
+ trial_topology_arg,
1045
1088
  field_arg_values,
1046
1089
  value_struct_values,
1047
1090
  offsets,
@@ -1052,38 +1095,63 @@ def _launch_integrate_kernel(
1052
1095
  device=device,
1053
1096
  )
1054
1097
 
1055
- bsr_set_from_triplets(bsr_matrix, triplet_rows, triplet_cols, triplet_values)
1056
- return bsr_matrix if output_dtype == accumulate_dtype else bsr_copy(bsr_matrix, scalar_type=output_dtype)
1098
+ if output is not None:
1099
+ if output.nrow != test.space_partition.node_count() or output.ncol != trial.space_partition.node_count():
1100
+ raise RuntimeError(
1101
+ f"Output matrix must have {test.space_partition.node_count()} rows and {trial.space_partition.node_count()} columns of blocks"
1102
+ )
1103
+
1104
+ else:
1105
+ output = bsr_zeros(
1106
+ rows_of_blocks=test.space_partition.node_count(),
1107
+ cols_of_blocks=trial.space_partition.node_count(),
1108
+ block_type=block_type,
1109
+ device=device,
1110
+ )
1111
+
1112
+ bsr_set_from_triplets(output, triplet_rows, triplet_cols, triplet_values)
1113
+
1114
+ # Do not wait for garbage collection
1115
+ triplet_values_temp.release()
1116
+ triplet_rows_temp.release()
1117
+ triplet_cols_temp.release()
1118
+
1119
+ return output
1057
1120
 
1058
1121
 
1059
1122
  def integrate(
1060
1123
  integrand: Integrand,
1061
- domain: GeometryDomain = None,
1062
- quadrature: Quadrature = None,
1124
+ domain: Optional[GeometryDomain] = None,
1125
+ quadrature: Optional[Quadrature] = None,
1063
1126
  nodal: bool = False,
1064
- fields={},
1065
- values={},
1127
+ fields: Dict[str, FieldLike] = {},
1128
+ values: Dict[str, Any] = {},
1129
+ accumulate_dtype: type = wp.float64,
1130
+ output_dtype: Optional[type] = None,
1131
+ output: Optional[Union[BsrMatrix, wp.array]] = None,
1066
1132
  device=None,
1067
- accumulate_dtype=wp.float64,
1068
- output_dtype=None,
1069
- output=None,
1133
+ temporary_store: Optional[cache.TemporaryStore] = None,
1134
+ kernel_options: Dict[str, Any] = {},
1070
1135
  ):
1071
1136
  """
1072
1137
  Integrates a constant, linear or bilinear form, and returns a scalar, array, or sparse matrix, respectively.
1073
1138
 
1074
1139
  Args:
1075
- integrand: Form to be integrated, must have `wp.integrand` decorator
1140
+ integrand: Form to be integrated, must have :func:`integrand` decorator
1076
1141
  domain: Integration domain. If None, deduced from fields
1077
1142
  quadrature: Quadrature formula. If None, deduced from domain and fields degree.
1078
1143
  nodal: For linear or bilinear form only, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
1079
1144
  fields: Discrete, test, and trial fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names.
1080
- values: Additional variable values to be passed to the integrand, can by of any type accepted by warp kernel launchs. Keys in the dictionary must match integrand parameter names.
1081
- device: Device on which to perform the integration
1145
+ values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
1146
+ temporary_store: shared pool from which to allocate temporary arrays
1082
1147
  accumulate_dtype: Scalar type to be used for accumulating integration samples
1083
- output_dtype: Scalar type for returned results. If None, defaults to accumulate_dtype
1148
+ output: Sparse matrix or warp array into which to store the result of the integration
1149
+ output_dtype: Scalar type for returned results if `output` is not provided. If None, defaults to `accumulate_dtype`
1150
+ device: Device on which to perform the integration
1151
+ kernel_options: Overloaded options to be passed to the kernel builder (e.g., ``{"enable_backward": True}``)
1084
1152
  """
1085
1153
  if not isinstance(integrand, Integrand):
1086
- raise ValueError("integrand must be tagged with @integrand decorator")
1154
+ raise ValueError("integrand must be tagged with @warp.fem.integrand decorator")
1087
1155
 
1088
1156
  test, test_name, trial, trial_name = _get_test_and_trial_fields(fields)
1089
1157
 
@@ -1111,15 +1179,23 @@ def integrate(
1111
1179
  )
1112
1180
  else:
1113
1181
  if quadrature is None:
1114
- order = 0
1115
- if test is not None:
1116
- order += test.space.degree
1117
- if trial is not None:
1118
- order += trial.space.degree
1182
+ order = sum(field.degree for field in fields.values())
1119
1183
  quadrature = RegularQuadrature(domain=domain, order=order)
1120
1184
  elif domain != quadrature.domain:
1121
1185
  raise ValueError("Incompatible integration and quadrature domain")
1122
1186
 
1187
+ # Canonicalize types
1188
+ accumulate_dtype = wp.types.type_to_warp(accumulate_dtype)
1189
+ if output is not None:
1190
+ if isinstance(output, BsrMatrix):
1191
+ output_dtype = output.scalar_type
1192
+ else:
1193
+ output_dtype = output.dtype
1194
+ elif output_dtype is None:
1195
+ output_dtype = accumulate_dtype
1196
+ else:
1197
+ output_dtype = wp.types.type_to_warp(output_dtype)
1198
+
1123
1199
  kernel, FieldStruct, ValueStruct = _generate_integrate_kernel(
1124
1200
  integrand=integrand,
1125
1201
  domain=domain,
@@ -1131,6 +1207,8 @@ def integrate(
1131
1207
  trial_name=trial_name,
1132
1208
  fields=fields,
1133
1209
  accumulate_dtype=accumulate_dtype,
1210
+ output_dtype=output_dtype,
1211
+ kernel_options=kernel_options,
1134
1212
  )
1135
1213
 
1136
1214
  return _launch_integrate_kernel(
@@ -1145,13 +1223,14 @@ def integrate(
1145
1223
  fields=fields,
1146
1224
  values=values,
1147
1225
  accumulate_dtype=accumulate_dtype,
1226
+ temporary_store=temporary_store,
1148
1227
  output_dtype=output_dtype,
1149
1228
  output=output,
1150
1229
  device=device,
1151
1230
  )
1152
1231
 
1153
1232
 
1154
- def get_interpolate_kernel(
1233
+ def get_interpolate_to_field_function(
1155
1234
  integrand_func: wp.Function,
1156
1235
  domain: GeometryDomain,
1157
1236
  FieldStruct: wp.codegen.Struct,
@@ -1160,7 +1239,8 @@ def get_interpolate_kernel(
1160
1239
  ):
1161
1240
  value_type = dest.space.dtype
1162
1241
 
1163
- def interpolate_kernel_fn(
1242
+ def interpolate_to_field_fn(
1243
+ local_node_index: int,
1164
1244
  domain_arg: domain.ElementArg,
1165
1245
  domain_index_arg: domain.ElementIndexArg,
1166
1246
  dest_node_arg: dest.space_restriction.NodeArg,
@@ -1168,19 +1248,15 @@ def get_interpolate_kernel(
1168
1248
  fields: FieldStruct,
1169
1249
  values: ValueStruct,
1170
1250
  ):
1171
- local_node_index = wp.tid()
1172
1251
  node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
1173
-
1174
1252
  element_count = dest.space_restriction.node_element_count(dest_node_arg, local_node_index)
1175
- if element_count == 0:
1176
- return
1177
1253
 
1178
1254
  test_dof_index = NULL_DOF_INDEX
1179
1255
  trial_dof_index = NULL_DOF_INDEX
1180
1256
  node_weight = 1.0
1181
1257
 
1182
- # Volume-weighted average accross elements
1183
- # Superfluous if the function is continuous, but we might as well
1258
+ # Volume-weighted average across elements
1259
+ # Superfluous if the interpolated function is continuous, but helpful for visualizing discontinuous spaces
1184
1260
 
1185
1261
  val_sum = value_type(0.0)
1186
1262
  vol_sum = float(0.0)
@@ -1190,14 +1266,13 @@ def get_interpolate_kernel(
1190
1266
  element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
1191
1267
 
1192
1268
  coords = dest.space.node_coords_in_element(
1269
+ domain_arg,
1193
1270
  dest_eval_arg.space_arg,
1194
1271
  element_index,
1195
1272
  node_element_index.node_index_in_element,
1196
1273
  )
1197
1274
 
1198
1275
  if coords[0] != OUTSIDE:
1199
- vol = domain.element_measure(domain_arg, element_index, coords)
1200
-
1201
1276
  sample = Sample(
1202
1277
  element_index,
1203
1278
  coords,
@@ -1206,20 +1281,118 @@ def get_interpolate_kernel(
1206
1281
  test_dof_index,
1207
1282
  trial_dof_index,
1208
1283
  )
1284
+ vol = domain.element_measure(domain_arg, sample)
1209
1285
  val = integrand_func(sample, fields, values)
1210
1286
 
1211
1287
  vol_sum += vol
1212
1288
  val_sum += vol * val
1213
1289
 
1290
+ return val_sum, vol_sum
1291
+
1292
+ return interpolate_to_field_fn
1293
+
1294
+
1295
+ def get_interpolate_to_field_kernel(
1296
+ interpolate_to_field_fn: wp.Function,
1297
+ domain: GeometryDomain,
1298
+ FieldStruct: wp.codegen.Struct,
1299
+ ValueStruct: wp.codegen.Struct,
1300
+ dest: FieldRestriction,
1301
+ ):
1302
+ def interpolate_to_field_kernel_fn(
1303
+ domain_arg: domain.ElementArg,
1304
+ domain_index_arg: domain.ElementIndexArg,
1305
+ dest_node_arg: dest.space_restriction.NodeArg,
1306
+ dest_eval_arg: dest.field.EvalArg,
1307
+ fields: FieldStruct,
1308
+ values: ValueStruct,
1309
+ ):
1310
+ local_node_index = wp.tid()
1311
+
1312
+ val_sum, vol_sum = interpolate_to_field_fn(
1313
+ local_node_index, domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, fields, values
1314
+ )
1315
+
1214
1316
  if vol_sum > 0.0:
1317
+ node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
1215
1318
  dest.field.set_node_value(dest_eval_arg, node_index, val_sum / vol_sum)
1216
1319
 
1217
- return interpolate_kernel_fn
1320
+ return interpolate_to_field_kernel_fn
1321
+
1322
+
1323
+ def get_interpolate_to_array_kernel(
1324
+ integrand_func: wp.Function,
1325
+ domain: GeometryDomain,
1326
+ quadrature: Quadrature,
1327
+ FieldStruct: wp.codegen.Struct,
1328
+ ValueStruct: wp.codegen.Struct,
1329
+ value_type: type,
1330
+ ):
1331
+ def interpolate_to_array_kernel_fn(
1332
+ qp_arg: quadrature.Arg,
1333
+ domain_arg: quadrature.domain.ElementArg,
1334
+ domain_index_arg: quadrature.domain.ElementIndexArg,
1335
+ fields: FieldStruct,
1336
+ values: ValueStruct,
1337
+ result: wp.array(dtype=value_type),
1338
+ ):
1339
+ element_index = domain.element_index(domain_index_arg, wp.tid())
1340
+
1341
+ test_dof_index = NULL_DOF_INDEX
1342
+ trial_dof_index = NULL_DOF_INDEX
1343
+
1344
+ qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
1345
+ for k in range(qp_point_count):
1346
+ qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
1347
+ coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
1348
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
1349
+
1350
+ sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1351
+
1352
+ result[qp_index] = integrand_func(sample, fields, values)
1353
+
1354
+ return interpolate_to_array_kernel_fn
1355
+
1356
+
1357
+ def get_interpolate_nonvalued_kernel(
1358
+ integrand_func: wp.Function,
1359
+ domain: GeometryDomain,
1360
+ quadrature: Quadrature,
1361
+ FieldStruct: wp.codegen.Struct,
1362
+ ValueStruct: wp.codegen.Struct,
1363
+ ):
1364
+ def interpolate_nonvalued_kernel_fn(
1365
+ qp_arg: quadrature.Arg,
1366
+ domain_arg: quadrature.domain.ElementArg,
1367
+ domain_index_arg: quadrature.domain.ElementIndexArg,
1368
+ fields: FieldStruct,
1369
+ values: ValueStruct,
1370
+ ):
1371
+ element_index = domain.element_index(domain_index_arg, wp.tid())
1372
+
1373
+ test_dof_index = NULL_DOF_INDEX
1374
+ trial_dof_index = NULL_DOF_INDEX
1375
+
1376
+ qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
1377
+ for k in range(qp_point_count):
1378
+ qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
1379
+ coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
1380
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
1381
+
1382
+ sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1383
+ integrand_func(sample, fields, values)
1218
1384
 
1385
+ return interpolate_nonvalued_kernel_fn
1219
1386
 
1220
- def _generate_interpolate_kernel(integrand: Integrand, dest: FieldLike, fields: Dict[str, FieldLike]) -> wp.Kernel:
1221
- domain = dest.domain
1222
1387
 
1388
+ def _generate_interpolate_kernel(
1389
+ integrand: Integrand,
1390
+ domain: GeometryDomain,
1391
+ dest: Optional[Union[FieldLike, wp.array]],
1392
+ quadrature: Optional[Quadrature],
1393
+ fields: Dict[str, FieldLike],
1394
+ kernel_options: Dict[str, Any] = {},
1395
+ ) -> wp.Kernel:
1223
1396
  # Extract field arguments from integrand
1224
1397
  field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
1225
1398
  integrand, fields=fields, domain=domain
@@ -1231,11 +1404,20 @@ def _generate_interpolate_kernel(integrand: Integrand, dest: FieldLike, fields:
1231
1404
  field_args,
1232
1405
  )
1233
1406
 
1407
+ _register_integrand_field_wrappers(integrand_func, fields)
1408
+
1234
1409
  FieldStruct = _gen_field_struct(field_args)
1235
1410
  ValueStruct = _gen_value_struct(value_args)
1236
1411
 
1237
1412
  # Check if kernel exist in cache
1238
- kernel_suffix = f"_itp_{FieldStruct.key}_{dest.domain.name}_{dest.space_restriction.space_partition.name}_{dest.space.name}"
1413
+ if isinstance(dest, FieldRestriction):
1414
+ kernel_suffix = (
1415
+ f"_itp_{FieldStruct.key}_{dest.domain.name}_{dest.space_restriction.space_partition.name}_{dest.space.name}"
1416
+ )
1417
+ elif wp.types.is_array(dest):
1418
+ kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}_{wp.types.type_repr(dest.dtype)}"
1419
+ else:
1420
+ kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}"
1239
1421
 
1240
1422
  kernel = cache.get_integrand_kernel(
1241
1423
  integrand=integrand,
@@ -1245,18 +1427,61 @@ def _generate_interpolate_kernel(integrand: Integrand, dest: FieldLike, fields:
1245
1427
  return kernel, FieldStruct, ValueStruct
1246
1428
 
1247
1429
  # Generate interpolation kernel
1248
- interpolate_kernel_fn = get_interpolate_kernel(
1249
- integrand_func,
1250
- domain,
1251
- dest=dest,
1252
- FieldStruct=FieldStruct,
1253
- ValueStruct=ValueStruct,
1254
- )
1430
+ if isinstance(dest, FieldRestriction):
1431
+ # need to split into kernel + function for differentiability
1432
+ interpolate_fn = get_interpolate_to_field_function(
1433
+ integrand_func,
1434
+ domain,
1435
+ dest=dest,
1436
+ FieldStruct=FieldStruct,
1437
+ ValueStruct=ValueStruct,
1438
+ )
1439
+
1440
+ interpolate_fn = cache.get_integrand_function(
1441
+ integrand=integrand,
1442
+ func=interpolate_fn,
1443
+ suffix=kernel_suffix,
1444
+ code_transformers=[
1445
+ PassFieldArgsToIntegrand(
1446
+ arg_names=integrand.argspec.args,
1447
+ field_args=field_args.keys(),
1448
+ value_args=value_args.keys(),
1449
+ sample_name=sample_name,
1450
+ domain_name=domain_name,
1451
+ )
1452
+ ],
1453
+ )
1454
+
1455
+ interpolate_kernel_fn = get_interpolate_to_field_kernel(
1456
+ interpolate_fn,
1457
+ domain,
1458
+ dest=dest,
1459
+ FieldStruct=FieldStruct,
1460
+ ValueStruct=ValueStruct,
1461
+ )
1462
+ elif wp.types.is_array(dest):
1463
+ interpolate_kernel_fn = get_interpolate_to_array_kernel(
1464
+ integrand_func,
1465
+ domain=domain,
1466
+ quadrature=quadrature,
1467
+ value_type=dest.dtype,
1468
+ FieldStruct=FieldStruct,
1469
+ ValueStruct=ValueStruct,
1470
+ )
1471
+ else:
1472
+ interpolate_kernel_fn = get_interpolate_nonvalued_kernel(
1473
+ integrand_func,
1474
+ domain=domain,
1475
+ quadrature=quadrature,
1476
+ FieldStruct=FieldStruct,
1477
+ ValueStruct=ValueStruct,
1478
+ )
1255
1479
 
1256
1480
  kernel = cache.get_integrand_kernel(
1257
1481
  integrand=integrand,
1258
1482
  kernel_fn=interpolate_kernel_fn,
1259
1483
  suffix=kernel_suffix,
1484
+ kernel_options=kernel_options,
1260
1485
  code_transformers=[
1261
1486
  PassFieldArgsToIntegrand(
1262
1487
  arg_names=integrand.argspec.args,
@@ -1275,16 +1500,16 @@ def _launch_interpolate_kernel(
1275
1500
  kernel: wp.kernel,
1276
1501
  FieldStruct: wp.codegen.Struct,
1277
1502
  ValueStruct: wp.codegen.Struct,
1278
- dest: FieldLike,
1503
+ domain: GeometryDomain,
1504
+ dest: Optional[Union[FieldRestriction, wp.array]],
1505
+ quadrature: Optional[Quadrature],
1279
1506
  fields: Dict[str, FieldLike],
1280
1507
  values: Dict[str, Any],
1281
1508
  device,
1282
1509
  ) -> wp.Kernel:
1283
1510
  # Set-up launch arguments
1284
- elt_arg = dest.domain.element_arg_value(device=device)
1285
- elt_index_arg = dest.domain.element_index_arg_value(device=device)
1286
- dest_node_arg = dest.space_restriction.node_arg(device=device)
1287
- dest_eval_arg = dest.field.eval_arg_value(device=device)
1511
+ elt_arg = domain.element_arg_value(device=device)
1512
+ elt_index_arg = domain.element_index_arg_value(device=device)
1288
1513
 
1289
1514
  field_arg_values = FieldStruct()
1290
1515
  for k, v in fields.items():
@@ -1294,37 +1519,65 @@ def _launch_interpolate_kernel(
1294
1519
  for k, v in values.items():
1295
1520
  setattr(value_struct_values, k, v)
1296
1521
 
1297
- wp.launch(
1298
- kernel=kernel,
1299
- dim=dest.space_restriction.node_count(),
1300
- inputs=[
1301
- elt_arg,
1302
- elt_index_arg,
1303
- dest_node_arg,
1304
- dest_eval_arg,
1305
- field_arg_values,
1306
- value_struct_values,
1307
- ],
1308
- device=device,
1309
- )
1522
+ if isinstance(dest, FieldRestriction):
1523
+ dest_node_arg = dest.space_restriction.node_arg(device=device)
1524
+ dest_eval_arg = dest.field.eval_arg_value(device=device)
1525
+
1526
+ wp.launch(
1527
+ kernel=kernel,
1528
+ dim=dest.space_restriction.node_count(),
1529
+ inputs=[
1530
+ elt_arg,
1531
+ elt_index_arg,
1532
+ dest_node_arg,
1533
+ dest_eval_arg,
1534
+ field_arg_values,
1535
+ value_struct_values,
1536
+ ],
1537
+ device=device,
1538
+ )
1539
+ elif wp.types.is_array(dest):
1540
+ qp_arg = quadrature.arg_value(device)
1541
+ wp.launch(
1542
+ kernel=kernel,
1543
+ dim=domain.element_count(),
1544
+ inputs=[qp_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
1545
+ device=device,
1546
+ )
1547
+ else:
1548
+ qp_arg = quadrature.arg_value(device)
1549
+ wp.launch(
1550
+ kernel=kernel,
1551
+ dim=domain.element_count(),
1552
+ inputs=[qp_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values],
1553
+ device=device,
1554
+ )
1310
1555
 
1311
1556
 
1312
1557
  def interpolate(
1313
1558
  integrand: Integrand,
1314
- dest: Union[DiscreteField, FieldRestriction],
1315
- fields={},
1316
- values={},
1559
+ dest: Optional[Union[DiscreteField, FieldRestriction, wp.array]] = None,
1560
+ quadrature: Optional[Quadrature] = None,
1561
+ fields: Dict[str, FieldLike] = {},
1562
+ values: Dict[str, Any] = {},
1317
1563
  device=None,
1564
+ kernel_options: Dict[str, Any] = {},
1318
1565
  ):
1319
1566
  """
1320
- Interpolates a function and assigns the result to a discrete field.
1567
+ Interpolates a function at a finite set of sample points and optionally assigns the result to a discrete field or a raw warp array.
1321
1568
 
1322
1569
  Args:
1323
- integrand: Function to be interpolated, must have `wp.integrand` decorator
1324
- dest: Discrete field, or restriction of a discrete field to a domain, to which the interpolation result will be assigned
1570
+ integrand: Function to be interpolated, must have :func:`integrand` decorator
1571
+ dest: Where to store the interpolation result. Can be either
1572
+
1573
+ - a :class:`DiscreteField`, or restriction of a discrete field to a domain (from :func:`make_restriction`). In this case, interpolation will be performed at each node.
1574
+ - a normal warp array. In this case, the `quadrature` argument defining the interpolation locations must be provided and the result of the `integrand` at each quadrature point will be assigned to the array.
1575
+ - ``None``. In this case, the `quadrature` argument must also be provided and the `integrand` function is responsible for dealing with the interpolation result.
1576
+ quadrature: Quadrature formula defining the interpolation samples if `dest` is not a discrete field or field restriction.
1325
1577
  fields: Discrete fields to be passed to the integrand. Keys in the dictionary must match integrand parameters names.
1326
- values: Additional variable values to be passed to the integrand, can by of any type accepted by warp kernel launchs. Keys in the dictionary must match integrand parameter names.
1578
+ values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
1327
1579
  device: Device on which to perform the interpolation
1580
+ kernel_options: Overloaded options to be passed to the kernel builder (e.g., ``{"enable_backward": True}``)
1328
1581
  """
1329
1582
  if not isinstance(integrand, Integrand):
1330
1583
  raise ValueError("integrand must be tagged with @integrand decorator")
@@ -1333,20 +1586,33 @@ def interpolate(
1333
1586
  if test is not None or trial is not None:
1334
1587
  raise ValueError("Test or Trial fields should not be used for interpolation")
1335
1588
 
1336
- if not isinstance(dest, FieldRestriction):
1589
+ if isinstance(dest, DiscreteField):
1337
1590
  dest = make_restriction(dest)
1338
1591
 
1592
+ if isinstance(dest, FieldRestriction):
1593
+ domain = dest.domain
1594
+ else:
1595
+ if quadrature is None:
1596
+ raise ValueError("When not interpolating to a field, a quadrature formula must be provided")
1597
+
1598
+ domain = quadrature.domain
1599
+
1339
1600
  kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
1340
1601
  integrand=integrand,
1602
+ domain=domain,
1341
1603
  dest=dest,
1604
+ quadrature=quadrature,
1342
1605
  fields=fields,
1606
+ kernel_options=kernel_options,
1343
1607
  )
1344
1608
 
1345
1609
  return _launch_interpolate_kernel(
1346
1610
  kernel=kernel,
1347
1611
  FieldStruct=FieldStruct,
1348
1612
  ValueStruct=ValueStruct,
1613
+ domain=domain,
1349
1614
  dest=dest,
1615
+ quadrature=quadrature,
1350
1616
  fields=fields,
1351
1617
  values=values,
1352
1618
  device=device,