warp-lang 1.2.2__py3-none-win_amd64.whl → 1.3.1__py3-none-win_amd64.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (194)
  1. warp/__init__.py +8 -6
  2. warp/autograd.py +823 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +6 -2
  6. warp/builtins.py +1412 -888
  7. warp/codegen.py +503 -166
  8. warp/config.py +48 -18
  9. warp/context.py +400 -198
  10. warp/dlpack.py +8 -0
  11. warp/examples/assets/bunny.usd +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
  13. warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
  14. warp/examples/benchmarks/benchmark_launches.py +1 -1
  15. warp/examples/core/example_cupy.py +78 -0
  16. warp/examples/fem/example_apic_fluid.py +17 -36
  17. warp/examples/fem/example_burgers.py +9 -18
  18. warp/examples/fem/example_convection_diffusion.py +7 -17
  19. warp/examples/fem/example_convection_diffusion_dg.py +27 -47
  20. warp/examples/fem/example_deformed_geometry.py +11 -22
  21. warp/examples/fem/example_diffusion.py +7 -18
  22. warp/examples/fem/example_diffusion_3d.py +24 -28
  23. warp/examples/fem/example_diffusion_mgpu.py +7 -14
  24. warp/examples/fem/example_magnetostatics.py +190 -0
  25. warp/examples/fem/example_mixed_elasticity.py +111 -80
  26. warp/examples/fem/example_navier_stokes.py +30 -34
  27. warp/examples/fem/example_nonconforming_contact.py +290 -0
  28. warp/examples/fem/example_stokes.py +17 -32
  29. warp/examples/fem/example_stokes_transfer.py +12 -21
  30. warp/examples/fem/example_streamlines.py +350 -0
  31. warp/examples/fem/utils.py +936 -0
  32. warp/fabric.py +5 -2
  33. warp/fem/__init__.py +13 -3
  34. warp/fem/cache.py +161 -11
  35. warp/fem/dirichlet.py +37 -28
  36. warp/fem/domain.py +105 -14
  37. warp/fem/field/__init__.py +14 -3
  38. warp/fem/field/field.py +454 -11
  39. warp/fem/field/nodal_field.py +33 -18
  40. warp/fem/geometry/deformed_geometry.py +50 -15
  41. warp/fem/geometry/hexmesh.py +12 -24
  42. warp/fem/geometry/nanogrid.py +106 -31
  43. warp/fem/geometry/quadmesh_2d.py +6 -11
  44. warp/fem/geometry/tetmesh.py +103 -61
  45. warp/fem/geometry/trimesh_2d.py +98 -47
  46. warp/fem/integrate.py +231 -186
  47. warp/fem/operator.py +14 -9
  48. warp/fem/quadrature/pic_quadrature.py +35 -9
  49. warp/fem/quadrature/quadrature.py +119 -32
  50. warp/fem/space/basis_space.py +98 -22
  51. warp/fem/space/collocated_function_space.py +3 -1
  52. warp/fem/space/function_space.py +7 -2
  53. warp/fem/space/grid_2d_function_space.py +3 -3
  54. warp/fem/space/grid_3d_function_space.py +4 -4
  55. warp/fem/space/hexmesh_function_space.py +3 -2
  56. warp/fem/space/nanogrid_function_space.py +12 -14
  57. warp/fem/space/partition.py +45 -47
  58. warp/fem/space/restriction.py +19 -16
  59. warp/fem/space/shape/cube_shape_function.py +91 -3
  60. warp/fem/space/shape/shape_function.py +7 -0
  61. warp/fem/space/shape/square_shape_function.py +32 -0
  62. warp/fem/space/shape/tet_shape_function.py +11 -7
  63. warp/fem/space/shape/triangle_shape_function.py +10 -1
  64. warp/fem/space/topology.py +116 -42
  65. warp/fem/types.py +8 -1
  66. warp/fem/utils.py +301 -83
  67. warp/native/array.h +16 -0
  68. warp/native/builtin.h +0 -15
  69. warp/native/cuda_util.cpp +14 -6
  70. warp/native/exports.h +1348 -1308
  71. warp/native/quat.h +79 -0
  72. warp/native/rand.h +27 -4
  73. warp/native/sparse.cpp +83 -81
  74. warp/native/sparse.cu +381 -453
  75. warp/native/vec.h +64 -0
  76. warp/native/volume.cpp +40 -49
  77. warp/native/volume_builder.cu +2 -3
  78. warp/native/volume_builder.h +12 -17
  79. warp/native/warp.cu +3 -3
  80. warp/native/warp.h +69 -59
  81. warp/render/render_opengl.py +17 -9
  82. warp/sim/articulation.py +117 -17
  83. warp/sim/collide.py +35 -29
  84. warp/sim/model.py +123 -18
  85. warp/sim/render.py +3 -1
  86. warp/sparse.py +867 -203
  87. warp/stubs.py +312 -541
  88. warp/tape.py +29 -1
  89. warp/tests/disabled_kinematics.py +1 -1
  90. warp/tests/test_adam.py +1 -1
  91. warp/tests/test_arithmetic.py +1 -1
  92. warp/tests/test_array.py +58 -1
  93. warp/tests/test_array_reduce.py +1 -1
  94. warp/tests/test_async.py +1 -1
  95. warp/tests/test_atomic.py +1 -1
  96. warp/tests/test_bool.py +1 -1
  97. warp/tests/test_builtins_resolution.py +1 -1
  98. warp/tests/test_bvh.py +6 -1
  99. warp/tests/test_closest_point_edge_edge.py +1 -1
  100. warp/tests/test_codegen.py +91 -1
  101. warp/tests/test_compile_consts.py +1 -1
  102. warp/tests/test_conditional.py +1 -1
  103. warp/tests/test_copy.py +1 -1
  104. warp/tests/test_ctypes.py +1 -1
  105. warp/tests/test_dense.py +1 -1
  106. warp/tests/test_devices.py +1 -1
  107. warp/tests/test_dlpack.py +1 -1
  108. warp/tests/test_examples.py +33 -4
  109. warp/tests/test_fabricarray.py +5 -2
  110. warp/tests/test_fast_math.py +1 -1
  111. warp/tests/test_fem.py +213 -6
  112. warp/tests/test_fp16.py +1 -1
  113. warp/tests/test_func.py +1 -1
  114. warp/tests/test_future_annotations.py +90 -0
  115. warp/tests/test_generics.py +1 -1
  116. warp/tests/test_grad.py +1 -1
  117. warp/tests/test_grad_customs.py +1 -1
  118. warp/tests/test_grad_debug.py +247 -0
  119. warp/tests/test_hash_grid.py +6 -1
  120. warp/tests/test_implicit_init.py +354 -0
  121. warp/tests/test_import.py +1 -1
  122. warp/tests/test_indexedarray.py +1 -1
  123. warp/tests/test_intersect.py +1 -1
  124. warp/tests/test_jax.py +1 -1
  125. warp/tests/test_large.py +1 -1
  126. warp/tests/test_launch.py +1 -1
  127. warp/tests/test_lerp.py +1 -1
  128. warp/tests/test_linear_solvers.py +1 -1
  129. warp/tests/test_lvalue.py +1 -1
  130. warp/tests/test_marching_cubes.py +5 -2
  131. warp/tests/test_mat.py +34 -35
  132. warp/tests/test_mat_lite.py +2 -1
  133. warp/tests/test_mat_scalar_ops.py +1 -1
  134. warp/tests/test_math.py +1 -1
  135. warp/tests/test_matmul.py +20 -16
  136. warp/tests/test_matmul_lite.py +1 -1
  137. warp/tests/test_mempool.py +1 -1
  138. warp/tests/test_mesh.py +5 -2
  139. warp/tests/test_mesh_query_aabb.py +1 -1
  140. warp/tests/test_mesh_query_point.py +1 -1
  141. warp/tests/test_mesh_query_ray.py +1 -1
  142. warp/tests/test_mlp.py +1 -1
  143. warp/tests/test_model.py +1 -1
  144. warp/tests/test_module_hashing.py +77 -1
  145. warp/tests/test_modules_lite.py +1 -1
  146. warp/tests/test_multigpu.py +1 -1
  147. warp/tests/test_noise.py +1 -1
  148. warp/tests/test_operators.py +1 -1
  149. warp/tests/test_options.py +1 -1
  150. warp/tests/test_overwrite.py +542 -0
  151. warp/tests/test_peer.py +1 -1
  152. warp/tests/test_pinned.py +1 -1
  153. warp/tests/test_print.py +1 -1
  154. warp/tests/test_quat.py +15 -1
  155. warp/tests/test_rand.py +1 -1
  156. warp/tests/test_reload.py +1 -1
  157. warp/tests/test_rounding.py +1 -1
  158. warp/tests/test_runlength_encode.py +1 -1
  159. warp/tests/test_scalar_ops.py +95 -0
  160. warp/tests/test_sim_grad.py +1 -1
  161. warp/tests/test_sim_kinematics.py +1 -1
  162. warp/tests/test_smoothstep.py +1 -1
  163. warp/tests/test_sparse.py +82 -15
  164. warp/tests/test_spatial.py +1 -1
  165. warp/tests/test_special_values.py +2 -11
  166. warp/tests/test_streams.py +11 -1
  167. warp/tests/test_struct.py +1 -1
  168. warp/tests/test_tape.py +1 -1
  169. warp/tests/test_torch.py +194 -1
  170. warp/tests/test_transient_module.py +1 -1
  171. warp/tests/test_types.py +1 -1
  172. warp/tests/test_utils.py +1 -1
  173. warp/tests/test_vec.py +15 -63
  174. warp/tests/test_vec_lite.py +2 -1
  175. warp/tests/test_vec_scalar_ops.py +65 -1
  176. warp/tests/test_verify_fp.py +1 -1
  177. warp/tests/test_volume.py +28 -2
  178. warp/tests/test_volume_write.py +1 -1
  179. warp/tests/unittest_serial.py +1 -1
  180. warp/tests/unittest_suites.py +9 -1
  181. warp/tests/walkthrough_debug.py +1 -1
  182. warp/thirdparty/unittest_parallel.py +2 -5
  183. warp/torch.py +103 -41
  184. warp/types.py +341 -224
  185. warp/utils.py +11 -2
  186. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/METADATA +99 -46
  187. warp_lang-1.3.1.dist-info/RECORD +368 -0
  188. warp/examples/fem/bsr_utils.py +0 -378
  189. warp/examples/fem/mesh_utils.py +0 -133
  190. warp/examples/fem/plot_utils.py +0 -292
  191. warp_lang-1.2.2.dist-info/RECORD +0 -359
  192. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/LICENSE.md +0 -0
  193. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/WHEEL +0 -0
  194. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py CHANGED
@@ -9,14 +9,25 @@ from warp.fem.field import (
9
9
  DiscreteField,
10
10
  FieldLike,
11
11
  FieldRestriction,
12
- SpaceField,
12
+ GeometryField,
13
13
  TestField,
14
14
  TrialField,
15
15
  make_restriction,
16
16
  )
17
- from warp.fem.operator import Integrand, Operator
17
+ from warp.fem.operator import Integrand, Operator, integrand
18
18
  from warp.fem.quadrature import Quadrature, RegularQuadrature
19
- from warp.fem.types import NULL_DOF_INDEX, OUTSIDE, DofIndex, Domain, Field, Sample, make_free_sample
19
+ from warp.fem.types import (
20
+ NULL_DOF_INDEX,
21
+ NULL_ELEMENT_INDEX,
22
+ NULL_NODE_INDEX,
23
+ OUTSIDE,
24
+ Coords,
25
+ DofIndex,
26
+ Domain,
27
+ Field,
28
+ Sample,
29
+ make_free_sample,
30
+ )
20
31
  from warp.sparse import BsrMatrix, bsr_set_from_triplets, bsr_zeros
21
32
  from warp.types import type_length
22
33
  from warp.utils import array_cast
@@ -58,24 +69,11 @@ def _resolve_path(func, node):
58
69
  return None, path
59
70
 
60
71
 
61
- def _path_to_ast_attribute(name: str) -> ast.Attribute:
62
- path = name.split(".")
63
- path.reverse()
64
-
65
- node = ast.Name(id=path.pop(), ctx=ast.Load())
66
- while len(path):
67
- node = ast.Attribute(
68
- value=node,
69
- attr=path.pop(),
70
- ctx=ast.Load(),
71
- )
72
- return node
73
-
74
-
75
72
  class IntegrandTransformer(ast.NodeTransformer):
76
- def __init__(self, integrand: Integrand, field_args: Dict[str, FieldLike]):
73
+ def __init__(self, integrand: Integrand, field_args: Dict[str, FieldLike], annotations: Dict[str, Any]):
77
74
  self._integrand = integrand
78
75
  self._field_args = field_args
76
+ self._annotations = annotations
79
77
 
80
78
  def visit_Call(self, call: ast.Call):
81
79
  call = self.generic_visit(call)
@@ -85,18 +83,15 @@ class IntegrandTransformer(ast.NodeTransformer):
85
83
  # Shortcut for evaluating fields as f(x...)
86
84
  field = self._field_args[callee]
87
85
 
88
- arg_type = self._integrand.argspec.annotations[callee]
89
- operator = arg_type.call_operator
86
+ # Replace with default call operator
87
+ abstract_arg_type = self._integrand.argspec.annotations[callee]
88
+ default_operator = abstract_arg_type.call_operator
89
+ concrete_arg_type = self._annotations[callee]
90
+ self._replace_call_func(call, concrete_arg_type, default_operator, field)
90
91
 
91
- call.func = ast.Attribute(
92
- value=_path_to_ast_attribute(f"{arg_type.__module__}.{arg_type.__qualname__}"),
93
- attr="call_operator",
94
- ctx=ast.Load(),
95
- )
92
+ # insert callee as first argument
96
93
  call.args = [ast.Name(id=callee, ctx=ast.Load())] + call.args
97
94
 
98
- self._replace_call_func(call, operator, field)
99
-
100
95
  return call
101
96
 
102
97
  func, _ = _resolve_path(self._integrand.func, call.func)
@@ -106,7 +101,7 @@ class IntegrandTransformer(ast.NodeTransformer):
106
101
  callee = getattr(call.args[0], "id", None)
107
102
  if callee in self._field_args:
108
103
  field = self._field_args[callee]
109
- self._replace_call_func(call, func, field)
104
+ self._replace_call_func(call, func, func, field)
110
105
 
111
106
  if isinstance(func, Integrand):
112
107
  key = self._translate_callee(func, call.args)
@@ -120,12 +115,18 @@ class IntegrandTransformer(ast.NodeTransformer):
120
115
 
121
116
  return call
122
117
 
123
- def _replace_call_func(self, call: ast.Call, operator: Operator, field: FieldLike):
118
+ def _replace_call_func(self, call: ast.Call, callee: Union[type, Operator], operator: Operator, field: FieldLike):
124
119
  try:
120
+ # Retrieve the function pointer corresponding to the operator implementation for the field type
125
121
  pointer = operator.resolver(field)
126
- setattr(operator, pointer.key, pointer)
127
- except AttributeError as e:
122
+ if pointer is None:
123
+ raise NotImplementedError(operator.resolver.__name__)
124
+
125
+ except (AttributeError, NotImplementedError) as e:
128
126
  raise ValueError(f"Operator {operator.func.__name__} is not defined for field {field.name}") from e
127
+ # Save the pointer as an attribute that can be accessed from the callee scope
128
+ setattr(callee, pointer.key, pointer)
129
+ # Update the ast Call node to use the new function pointer
129
130
  call.func = ast.Attribute(value=call.func, attr=pointer.key, ctx=ast.Load())
130
131
 
131
132
  def _translate_callee(self, callee: Integrand, args: List[ast.AST]):
@@ -162,7 +163,7 @@ def _translate_integrand(integrand: Integrand, field_args: Dict[str, FieldLike])
162
163
  annotations[arg] = arg_type
163
164
 
164
165
  # Transform field evaluation calls
165
- transformer = IntegrandTransformer(integrand, field_args)
166
+ transformer = IntegrandTransformer(integrand, field_args, annotations)
166
167
 
167
168
  suffix = "_".join([f.name for f in field_args.values()])
168
169
 
@@ -215,46 +216,22 @@ def _check_field_compat(
215
216
  field_args: Dict[str, FieldLike],
216
217
  domain: GeometryDomain = None,
217
218
  ):
218
- # Check field compatilibity
219
+ # Check field compatibility
219
220
  for name, field in fields.items():
220
221
  if name not in field_args:
221
222
  raise ValueError(
222
223
  f"Passed field argument '{name}' does not match any parameter of integrand '{integrand.name}'"
223
224
  )
224
225
 
225
- if isinstance(field, SpaceField) and domain is not None:
226
- space = field.space
227
- if space.geometry != domain.geometry:
226
+ if isinstance(field, GeometryField) and domain is not None:
227
+ if field.geometry != domain.geometry:
228
228
  raise ValueError(f"Field '{name}' must be defined on the same geometry as the integration domain")
229
- if space.dimension != domain.dimension:
229
+ if field.element_kind != domain.element_kind:
230
230
  raise ValueError(
231
- f"Field '{name}' dimension ({space.dimension}) does not match that of the integration domain ({domain.dimension}). Maybe a forgotten `.trace()`?"
231
+ f"Field '{name}' is not defined on the same kind of elements (cells or sides) as the integration domain. Maybe a forgotten `.trace()`?"
232
232
  )
233
233
 
234
234
 
235
- def _populate_value_struct(ValueStruct: wp.codegen.Struct, values: Dict[str, Any], integrand_name: str):
236
- value_struct_values = ValueStruct()
237
- for k, v in values.items():
238
- try:
239
- setattr(value_struct_values, k, v)
240
- except Exception as err:
241
- if k not in ValueStruct.vars:
242
- raise ValueError(
243
- f"Passed value argument '{k}' does not match any of the integrand '{integrand_name}' parameters"
244
- ) from err
245
- raise ValueError(
246
- f"Passed value argument '{k}' of type '{wp.types.type_repr(v)}' is incompatible with the integrand '{integrand_name}' parameter of type '{wp.types.type_repr(ValueStruct.vars[k].type)}'"
247
- ) from err
248
-
249
- missing_values = ValueStruct.vars.keys() - values.keys()
250
- if missing_values:
251
- wp.utils.warn(
252
- f"Missing values for parameter(s) '{', '.join(missing_values)}' of the integrand '{integrand_name}', will be zero-initialized"
253
- )
254
-
255
- return value_struct_values
256
-
257
-
258
235
  def _get_test_and_trial_fields(
259
236
  fields: Dict[str, FieldLike],
260
237
  ):
@@ -310,36 +287,6 @@ def _gen_field_struct(field_args: Dict[str, FieldLike]):
310
287
  return cache.get_struct(Fields, suffix=suffix)
311
288
 
312
289
 
313
- def _gen_value_struct(value_args: Dict[str, type]):
314
- class Values:
315
- pass
316
-
317
- annotations = get_annotations(Values)
318
-
319
- for name, arg_type in value_args.items():
320
- setattr(Values, name, None)
321
- annotations[name] = arg_type
322
-
323
- def arg_type_name(arg_type):
324
- if isinstance(arg_type, wp.codegen.Struct):
325
- return arg_type_name(arg_type.cls)
326
- return getattr(arg_type, "__name__", str(arg_type))
327
-
328
- def arg_type_name(arg_type):
329
- if isinstance(arg_type, wp.codegen.Struct):
330
- return arg_type_name(arg_type.cls)
331
- return getattr(arg_type, "__name__", str(arg_type))
332
-
333
- try:
334
- Values.__annotations__ = annotations
335
- except AttributeError:
336
- Values.__dict__.__annotations__ = annotations
337
-
338
- suffix = "_".join([f"{name}_{arg_type_name(arg_type)}" for name, arg_type in annotations.items()])
339
-
340
- return cache.get_struct(Values, suffix=suffix)
341
-
342
-
343
290
  def _get_trial_arg():
344
291
  pass
345
292
 
@@ -474,17 +421,18 @@ def get_integrate_constant_kernel(
474
421
  values: ValueStruct,
475
422
  result: wp.array(dtype=accumulate_dtype),
476
423
  ):
477
- element_index = domain.element_index(domain_index_arg, wp.tid())
424
+ domain_element_index = wp.tid()
425
+ element_index = domain.element_index(domain_index_arg, domain_element_index)
478
426
  elem_sum = accumulate_dtype(0.0)
479
427
 
480
428
  test_dof_index = NULL_DOF_INDEX
481
429
  trial_dof_index = NULL_DOF_INDEX
482
430
 
483
- qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
431
+ qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
484
432
  for k in range(qp_point_count):
485
- qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
486
- coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
487
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
433
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
434
+ coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
435
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
488
436
 
489
437
  sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
490
438
  vol = domain.element_measure(domain_arg, sample)
@@ -519,23 +467,31 @@ def get_integrate_linear_kernel(
519
467
  ):
520
468
  local_node_index, test_dof = wp.tid()
521
469
  node_index = test.space_restriction.node_partition_index(test_arg, local_node_index)
522
- element_count = test.space_restriction.node_element_count(test_arg, local_node_index)
470
+ element_beg, element_end = test.space_restriction.node_element_range(test_arg, node_index)
523
471
 
524
472
  trial_dof_index = NULL_DOF_INDEX
525
473
 
526
474
  val_sum = accumulate_dtype(0.0)
527
475
 
528
- for n in range(element_count):
529
- node_element_index = test.space_restriction.node_element_index(test_arg, local_node_index, n)
476
+ for n in range(element_beg, element_end):
477
+ node_element_index = test.space_restriction.node_element_index(test_arg, n)
530
478
  element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
531
479
 
532
480
  test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)
533
481
 
534
- qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
482
+ qp_point_count = quadrature.point_count(
483
+ domain_arg, qp_arg, node_element_index.domain_element_index, element_index
484
+ )
535
485
  for k in range(qp_point_count):
536
- qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
537
- qp_coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
538
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
486
+ qp_index = quadrature.point_index(
487
+ domain_arg, qp_arg, node_element_index.domain_element_index, element_index, k
488
+ )
489
+ qp_coords = quadrature.point_coords(
490
+ domain_arg, qp_arg, node_element_index.domain_element_index, element_index, k
491
+ )
492
+ qp_weight = quadrature.point_weight(
493
+ domain_arg, qp_arg, node_element_index.domain_element_index, element_index, k
494
+ )
539
495
 
540
496
  vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
541
497
 
@@ -562,23 +518,29 @@ def get_integrate_linear_nodal_kernel(
562
518
  domain_arg: domain.ElementArg,
563
519
  domain_index_arg: domain.ElementIndexArg,
564
520
  test_restriction_arg: test.space_restriction.NodeArg,
521
+ test_topo_arg: test.space.topology.TopologyArg,
565
522
  fields: FieldStruct,
566
523
  values: ValueStruct,
567
524
  result: wp.array2d(dtype=output_dtype),
568
525
  ):
569
526
  local_node_index, dof = wp.tid()
570
527
 
571
- node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
572
- element_count = test.space_restriction.node_element_count(test_restriction_arg, local_node_index)
528
+ partition_node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
529
+ element_beg, element_end = test.space_restriction.node_element_range(test_restriction_arg, partition_node_index)
573
530
 
574
531
  trial_dof_index = NULL_DOF_INDEX
575
532
 
576
533
  val_sum = accumulate_dtype(0.0)
577
534
 
578
- for n in range(element_count):
579
- node_element_index = test.space_restriction.node_element_index(test_restriction_arg, local_node_index, n)
535
+ for n in range(element_beg, element_end):
536
+ node_element_index = test.space_restriction.node_element_index(test_restriction_arg, n)
580
537
  element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
581
538
 
539
+ if n == element_beg:
540
+ node_index = test.space.topology.element_node_index(
541
+ domain_arg, test_topo_arg, element_index, node_element_index.node_index_in_element
542
+ )
543
+
582
544
  coords = test.space.node_coords_in_element(
583
545
  domain_arg,
584
546
  _get_test_arg(),
@@ -609,7 +571,7 @@ def get_integrate_linear_nodal_kernel(
609
571
 
610
572
  val_sum += accumulate_dtype(node_weight * vol * val)
611
573
 
612
- result[node_index, dof] = output_dtype(val_sum)
574
+ result[partition_node_index, dof] = output_dtype(val_sum)
613
575
 
614
576
  return integrate_kernel_fn
615
577
 
@@ -625,7 +587,7 @@ def get_integrate_bilinear_kernel(
625
587
  output_dtype,
626
588
  accumulate_dtype,
627
589
  ):
628
- NODES_PER_ELEMENT = trial.space.topology.NODES_PER_ELEMENT
590
+ MAX_NODES_PER_ELEMENT = trial.space.topology.MAX_NODES_PER_ELEMENT
629
591
 
630
592
  def integrate_kernel_fn(
631
593
  qp_arg: quadrature.Arg,
@@ -636,22 +598,29 @@ def get_integrate_bilinear_kernel(
636
598
  trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
637
599
  fields: FieldStruct,
638
600
  values: ValueStruct,
639
- row_offsets: wp.array(dtype=int),
640
601
  triplet_rows: wp.array(dtype=int),
641
602
  triplet_cols: wp.array(dtype=int),
642
603
  triplet_values: wp.array3d(dtype=output_dtype),
643
604
  ):
644
605
  test_local_node_index, trial_node, test_dof, trial_dof = wp.tid()
645
606
 
646
- element_count = test.space_restriction.node_element_count(test_arg, test_local_node_index)
647
607
  test_node_index = test.space_restriction.node_partition_index(test_arg, test_local_node_index)
608
+ element_beg, element_end = test.space_restriction.node_element_range(test_arg, test_node_index)
648
609
 
649
610
  trial_dof_index = DofIndex(trial_node, trial_dof)
650
611
 
651
- for element in range(element_count):
652
- test_element_index = test.space_restriction.node_element_index(test_arg, test_local_node_index, element)
612
+ for element in range(element_beg, element_end):
613
+ test_element_index = test.space_restriction.node_element_index(test_arg, element)
653
614
  element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
654
- qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
615
+
616
+ element_trial_node_count = trial.space.topology.element_node_count(
617
+ domain_arg, trial_topology_arg, element_index
618
+ )
619
+ qp_point_count = wp.select(
620
+ trial_node < element_trial_node_count,
621
+ 0,
622
+ quadrature.point_count(domain_arg, qp_arg, test_element_index.domain_element_index, element_index),
623
+ )
655
624
 
656
625
  test_dof_index = DofIndex(
657
626
  test_element_index.node_index_in_element,
@@ -661,10 +630,16 @@ def get_integrate_bilinear_kernel(
661
630
  val_sum = accumulate_dtype(0.0)
662
631
 
663
632
  for k in range(qp_point_count):
664
- qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
665
- coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
633
+ qp_index = quadrature.point_index(
634
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
635
+ )
636
+ coords = quadrature.point_coords(
637
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
638
+ )
666
639
 
667
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
640
+ qp_weight = quadrature.point_weight(
641
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
642
+ )
668
643
  vol = domain.element_measure(domain_arg, make_free_sample(element_index, coords))
669
644
 
670
645
  sample = Sample(
@@ -678,15 +653,20 @@ def get_integrate_bilinear_kernel(
678
653
  val = integrand_func(sample, fields, values)
679
654
  val_sum += accumulate_dtype(qp_weight * vol * val)
680
655
 
681
- block_offset = (row_offsets[test_node_index] + element) * NODES_PER_ELEMENT + trial_node
656
+ block_offset = element * MAX_NODES_PER_ELEMENT + trial_node
682
657
  triplet_values[block_offset, test_dof, trial_dof] = output_dtype(val_sum)
683
658
 
684
659
  # Set row and column indices
685
660
  if test_dof == 0 and trial_dof == 0:
686
- trial_node_index = trial.space_partition.partition_node_index(
687
- trial_partition_arg,
688
- trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
689
- )
661
+ if trial_node < element_trial_node_count:
662
+ trial_node_index = trial.space_partition.partition_node_index(
663
+ trial_partition_arg,
664
+ trial.space.topology.element_node_index(
665
+ domain_arg, trial_topology_arg, element_index, trial_node
666
+ ),
667
+ )
668
+ else:
669
+ trial_node_index = NULL_NODE_INDEX # will get ignored when converting to bsr
690
670
  triplet_rows[block_offset] = test_node_index
691
671
  triplet_cols[block_offset] = trial_node_index
692
672
 
@@ -706,6 +686,7 @@ def get_integrate_bilinear_nodal_kernel(
706
686
  domain_arg: domain.ElementArg,
707
687
  domain_index_arg: domain.ElementIndexArg,
708
688
  test_restriction_arg: test.space_restriction.NodeArg,
689
+ test_topo_arg: test.space.topology.TopologyArg,
709
690
  fields: FieldStruct,
710
691
  values: ValueStruct,
711
692
  triplet_rows: wp.array(dtype=int),
@@ -714,15 +695,20 @@ def get_integrate_bilinear_nodal_kernel(
714
695
  ):
715
696
  local_node_index, test_dof, trial_dof = wp.tid()
716
697
 
717
- element_count = test.space_restriction.node_element_count(test_restriction_arg, local_node_index)
718
- node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
698
+ partition_node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
699
+ element_beg, element_end = test.space_restriction.node_element_range(test_restriction_arg, partition_node_index)
719
700
 
720
701
  val_sum = accumulate_dtype(0.0)
721
702
 
722
- for n in range(element_count):
723
- node_element_index = test.space_restriction.node_element_index(test_restriction_arg, local_node_index, n)
703
+ for n in range(element_beg, element_end):
704
+ node_element_index = test.space_restriction.node_element_index(test_restriction_arg, n)
724
705
  element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
725
706
 
707
+ if n == element_beg:
708
+ node_index = test.space.topology.element_node_index(
709
+ domain_arg, test_topo_arg, element_index, node_element_index.node_index_in_element
710
+ )
711
+
726
712
  coords = test.space.node_coords_in_element(
727
713
  domain_arg,
728
714
  _get_test_arg(),
@@ -755,8 +741,8 @@ def get_integrate_bilinear_nodal_kernel(
755
741
  val_sum += accumulate_dtype(node_weight * vol * val)
756
742
 
757
743
  triplet_values[local_node_index, test_dof, trial_dof] = output_dtype(val_sum)
758
- triplet_rows[local_node_index] = node_index
759
- triplet_cols[local_node_index] = node_index
744
+ triplet_rows[local_node_index] = partition_node_index
745
+ triplet_cols[local_node_index] = partition_node_index
760
746
 
761
747
  return integrate_kernel_fn
762
748
 
@@ -786,7 +772,7 @@ def _generate_integrate_kernel(
786
772
  )
787
773
 
788
774
  FieldStruct = _gen_field_struct(field_args)
789
- ValueStruct = _gen_value_struct(value_args)
775
+ ValueStruct = cache.get_argument_struct(value_args)
790
776
 
791
777
  # Check if kernel exist in cache
792
778
  kernel_suffix = f"_itg_{wp.types.type_typestr(output_dtype)}{wp.types.type_typestr(accumulate_dtype)}_{domain.name}_{FieldStruct.key}"
@@ -923,7 +909,7 @@ def _launch_integrate_kernel(
923
909
  for k, v in fields.items():
924
910
  setattr(field_arg_values, k, v.eval_arg_value(device=device))
925
911
 
926
- value_struct_values = _populate_value_struct(ValueStruct, values, integrand_name=integrand.name)
912
+ value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
927
913
 
928
914
  # Constant form
929
915
  if test is None and trial is None:
@@ -1030,6 +1016,7 @@ def _launch_integrate_kernel(
1030
1016
  domain_elt_arg,
1031
1017
  domain_elt_index_arg,
1032
1018
  test_arg,
1019
+ test.space.topology.topo_arg_value(device),
1033
1020
  field_arg_values,
1034
1021
  value_struct_values,
1035
1022
  output_view,
@@ -1069,7 +1056,7 @@ def _launch_integrate_kernel(
1069
1056
  if nodal:
1070
1057
  nnz = test.space_restriction.node_count()
1071
1058
  else:
1072
- nnz = test.space_restriction.total_node_element_count() * trial.space.topology.NODES_PER_ELEMENT
1059
+ nnz = test.space_restriction.total_node_element_count() * trial.space.topology.MAX_NODES_PER_ELEMENT
1073
1060
 
1074
1061
  triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
1075
1062
  triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
@@ -1097,6 +1084,7 @@ def _launch_integrate_kernel(
1097
1084
  domain_elt_arg,
1098
1085
  domain_elt_index_arg,
1099
1086
  test_arg,
1087
+ test.space.topology.topo_arg_value(device),
1100
1088
  field_arg_values,
1101
1089
  value_struct_values,
1102
1090
  triplet_rows,
@@ -1107,15 +1095,13 @@ def _launch_integrate_kernel(
1107
1095
  )
1108
1096
 
1109
1097
  else:
1110
- offsets = test.space_restriction.partition_element_offsets()
1111
-
1112
1098
  trial_partition_arg = trial.space_partition.partition_arg_value(device)
1113
1099
  trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
1114
1100
  wp.launch(
1115
1101
  kernel=kernel,
1116
1102
  dim=(
1117
1103
  test.space_restriction.node_count(),
1118
- trial.space.topology.NODES_PER_ELEMENT,
1104
+ trial.space.topology.MAX_NODES_PER_ELEMENT,
1119
1105
  test.space.VALUE_DOF_COUNT,
1120
1106
  trial.space.VALUE_DOF_COUNT,
1121
1107
  ),
@@ -1128,7 +1114,6 @@ def _launch_integrate_kernel(
1128
1114
  trial_topology_arg,
1129
1115
  field_arg_values,
1130
1116
  value_struct_values,
1131
- offsets,
1132
1117
  triplet_rows,
1133
1118
  triplet_cols,
1134
1119
  triplet_values,
@@ -1299,8 +1284,8 @@ def get_interpolate_to_field_function(
1299
1284
  fields: FieldStruct,
1300
1285
  values: ValueStruct,
1301
1286
  ):
1302
- node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
1303
- element_count = dest.space_restriction.node_element_count(dest_node_arg, local_node_index)
1287
+ partition_node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
1288
+ element_beg, element_end = dest.space_restriction.node_element_range(dest_node_arg, partition_node_index)
1304
1289
 
1305
1290
  test_dof_index = NULL_DOF_INDEX
1306
1291
  trial_dof_index = NULL_DOF_INDEX
@@ -1312,10 +1297,15 @@ def get_interpolate_to_field_function(
1312
1297
  val_sum = value_type(0.0)
1313
1298
  vol_sum = float(0.0)
1314
1299
 
1315
- for n in range(element_count):
1316
- node_element_index = dest.space_restriction.node_element_index(dest_node_arg, local_node_index, n)
1300
+ for n in range(element_beg, element_end):
1301
+ node_element_index = dest.space_restriction.node_element_index(dest_node_arg, n)
1317
1302
  element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
1318
1303
 
1304
+ if n == element_beg:
1305
+ node_index = dest.space.topology.element_node_index(
1306
+ domain_arg, dest_eval_arg.topology_arg, element_index, node_element_index.node_index_in_element
1307
+ )
1308
+
1319
1309
  coords = dest.space.node_coords_in_element(
1320
1310
  domain_arg,
1321
1311
  dest_eval_arg.space_arg,
@@ -1371,7 +1361,7 @@ def get_interpolate_to_field_kernel(
1371
1361
  return interpolate_to_field_kernel_fn
1372
1362
 
1373
1363
 
1374
- def get_interpolate_to_array_kernel(
1364
+ def get_interpolate_at_quadrature_kernel(
1375
1365
  integrand_func: wp.Function,
1376
1366
  domain: GeometryDomain,
1377
1367
  quadrature: Quadrature,
@@ -1379,61 +1369,100 @@ def get_interpolate_to_array_kernel(
1379
1369
  ValueStruct: wp.codegen.Struct,
1380
1370
  value_type: type,
1381
1371
  ):
1382
- def interpolate_to_array_kernel_fn(
1372
+ def interpolate_at_quadrature_nonvalued_kernel_fn(
1383
1373
  qp_arg: quadrature.Arg,
1384
1374
  domain_arg: quadrature.domain.ElementArg,
1385
1375
  domain_index_arg: quadrature.domain.ElementIndexArg,
1386
1376
  fields: FieldStruct,
1387
1377
  values: ValueStruct,
1388
- result: wp.array(dtype=value_type),
1378
+ result: wp.array(dtype=float),
1389
1379
  ):
1390
- element_index = domain.element_index(domain_index_arg, wp.tid())
1380
+ domain_element_index = wp.tid()
1381
+ element_index = domain.element_index(domain_index_arg, domain_element_index)
1391
1382
 
1392
1383
  test_dof_index = NULL_DOF_INDEX
1393
1384
  trial_dof_index = NULL_DOF_INDEX
1394
1385
 
1395
- qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
1386
+ qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
1396
1387
  for k in range(qp_point_count):
1397
- qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
1398
- coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
1399
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
1388
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
1389
+ coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
1390
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
1400
1391
 
1401
1392
  sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1393
+ integrand_func(sample, fields, values)
1394
+
1395
+ def interpolate_at_quadrature_kernel_fn(
1396
+ qp_arg: quadrature.Arg,
1397
+ domain_arg: quadrature.domain.ElementArg,
1398
+ domain_index_arg: quadrature.domain.ElementIndexArg,
1399
+ fields: FieldStruct,
1400
+ values: ValueStruct,
1401
+ result: wp.array(dtype=value_type),
1402
+ ):
1403
+ domain_element_index = wp.tid()
1404
+ element_index = domain.element_index(domain_index_arg, domain_element_index)
1405
+
1406
+ test_dof_index = NULL_DOF_INDEX
1407
+ trial_dof_index = NULL_DOF_INDEX
1402
1408
 
1409
+ qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
1410
+ for k in range(qp_point_count):
1411
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
1412
+ coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
1413
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
1414
+
1415
+ sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1403
1416
  result[qp_index] = integrand_func(sample, fields, values)
1404
1417
 
1405
- return interpolate_to_array_kernel_fn
1418
+ return interpolate_at_quadrature_nonvalued_kernel_fn if value_type is None else interpolate_at_quadrature_kernel_fn
1406
1419
 
1407
1420
 
1408
- def get_interpolate_nonvalued_kernel(
1421
+ def get_interpolate_free_kernel(
1409
1422
  integrand_func: wp.Function,
1410
1423
  domain: GeometryDomain,
1411
- quadrature: Quadrature,
1412
1424
  FieldStruct: wp.codegen.Struct,
1413
1425
  ValueStruct: wp.codegen.Struct,
1426
+ value_type: type,
1414
1427
  ):
1415
- def interpolate_nonvalued_kernel_fn(
1416
- qp_arg: quadrature.Arg,
1417
- domain_arg: quadrature.domain.ElementArg,
1418
- domain_index_arg: quadrature.domain.ElementIndexArg,
1428
+ def interpolate_free_nonvalued_kernel_fn(
1429
+ dim: int,
1430
+ domain_arg: domain.ElementArg,
1419
1431
  fields: FieldStruct,
1420
1432
  values: ValueStruct,
1433
+ result: wp.array(dtype=float),
1421
1434
  ):
1422
- element_index = domain.element_index(domain_index_arg, wp.tid())
1435
+ qp_index = wp.tid()
1436
+ qp_weight = 1.0 / float(dim)
1437
+ element_index = NULL_ELEMENT_INDEX
1438
+ coords = Coords(OUTSIDE)
1423
1439
 
1424
1440
  test_dof_index = NULL_DOF_INDEX
1425
1441
  trial_dof_index = NULL_DOF_INDEX
1426
1442
 
1427
- qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
1428
- for k in range(qp_point_count):
1429
- qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
1430
- coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
1431
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
1443
+ sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1444
+ integrand_func(sample, fields, values)
1432
1445
 
1433
- sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1434
- integrand_func(sample, fields, values)
1446
+ def interpolate_free_kernel_fn(
1447
+ dim: int,
1448
+ domain_arg: domain.ElementArg,
1449
+ fields: FieldStruct,
1450
+ values: ValueStruct,
1451
+ result: wp.array(dtype=value_type),
1452
+ ):
1453
+ qp_index = wp.tid()
1454
+ qp_weight = 1.0 / float(dim)
1455
+ element_index = NULL_ELEMENT_INDEX
1456
+ coords = Coords(OUTSIDE)
1457
+
1458
+ test_dof_index = NULL_DOF_INDEX
1459
+ trial_dof_index = NULL_DOF_INDEX
1460
+
1461
+ sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1435
1462
 
1436
- return interpolate_nonvalued_kernel_fn
1463
+ result[qp_index] = integrand_func(sample, fields, values)
1464
+
1465
+ return interpolate_free_nonvalued_kernel_fn if value_type is None else interpolate_free_kernel_fn
1437
1466
 
1438
1467
 
1439
1468
  def _generate_interpolate_kernel(
@@ -1461,17 +1490,20 @@ def _generate_interpolate_kernel(
1461
1490
  _register_integrand_field_wrappers(integrand_func, fields)
1462
1491
 
1463
1492
  FieldStruct = _gen_field_struct(field_args)
1464
- ValueStruct = _gen_value_struct(value_args)
1493
+ ValueStruct = cache.get_argument_struct(value_args)
1465
1494
 
1466
1495
  # Check if kernel exist in cache
1467
1496
  if isinstance(dest, FieldRestriction):
1468
1497
  kernel_suffix = (
1469
1498
  f"_itp_{FieldStruct.key}_{dest.domain.name}_{dest.space_restriction.space_partition.name}_{dest.space.name}"
1470
1499
  )
1471
- elif wp.types.is_array(dest):
1472
- kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}_{wp.types.type_repr(dest.dtype)}"
1473
1500
  else:
1474
- kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}"
1501
+ dest_dtype = dest.dtype if dest else None
1502
+ type_str = wp.types.get_type_code(dest_dtype) if dest_dtype else ""
1503
+ if quadrature is None:
1504
+ kernel_suffix = f"_itp_{FieldStruct.key}_{type_str}"
1505
+ else:
1506
+ kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}_{type_str}"
1475
1507
 
1476
1508
  kernel = cache.get_integrand_kernel(
1477
1509
  integrand=integrand,
@@ -1515,20 +1547,20 @@ def _generate_interpolate_kernel(
1515
1547
  FieldStruct=FieldStruct,
1516
1548
  ValueStruct=ValueStruct,
1517
1549
  )
1518
- elif wp.types.is_array(dest):
1519
- interpolate_kernel_fn = get_interpolate_to_array_kernel(
1550
+ elif quadrature is not None:
1551
+ interpolate_kernel_fn = get_interpolate_at_quadrature_kernel(
1520
1552
  integrand_func,
1521
1553
  domain=domain,
1522
1554
  quadrature=quadrature,
1523
- value_type=dest.dtype,
1555
+ value_type=dest_dtype,
1524
1556
  FieldStruct=FieldStruct,
1525
1557
  ValueStruct=ValueStruct,
1526
1558
  )
1527
1559
  else:
1528
- interpolate_kernel_fn = get_interpolate_nonvalued_kernel(
1560
+ interpolate_kernel_fn = get_interpolate_free_kernel(
1529
1561
  integrand_func,
1530
1562
  domain=domain,
1531
- quadrature=quadrature,
1563
+ value_type=dest_dtype,
1532
1564
  FieldStruct=FieldStruct,
1533
1565
  ValueStruct=ValueStruct,
1534
1566
  )
@@ -1560,6 +1592,7 @@ def _launch_interpolate_kernel(
1560
1592
  domain: GeometryDomain,
1561
1593
  dest: Optional[Union[FieldRestriction, wp.array]],
1562
1594
  quadrature: Optional[Quadrature],
1595
+ dim: int,
1563
1596
  fields: Dict[str, FieldLike],
1564
1597
  values: Dict[str, Any],
1565
1598
  device,
@@ -1572,7 +1605,7 @@ def _launch_interpolate_kernel(
1572
1605
  for k, v in fields.items():
1573
1606
  setattr(field_arg_values, k, v.eval_arg_value(device=device))
1574
1607
 
1575
- value_struct_values = _populate_value_struct(ValueStruct, values, integrand_name=integrand.name)
1608
+ value_struct_values = cache.populate_argument_struct(ValueStruct, values, func_name=integrand.name)
1576
1609
 
1577
1610
  if isinstance(dest, FieldRestriction):
1578
1611
  dest_node_arg = dest.space_restriction.node_arg(device=device)
@@ -1591,7 +1624,7 @@ def _launch_interpolate_kernel(
1591
1624
  ],
1592
1625
  device=device,
1593
1626
  )
1594
- elif wp.types.is_array(dest):
1627
+ elif quadrature is not None:
1595
1628
  qp_arg = quadrature.arg_value(device)
1596
1629
  wp.launch(
1597
1630
  kernel=kernel,
@@ -1600,19 +1633,25 @@ def _launch_interpolate_kernel(
1600
1633
  device=device,
1601
1634
  )
1602
1635
  else:
1603
- qp_arg = quadrature.arg_value(device)
1604
1636
  wp.launch(
1605
1637
  kernel=kernel,
1606
- dim=domain.element_count(),
1607
- inputs=[qp_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values],
1638
+ dim=dim,
1639
+ inputs=[dim, elt_arg, field_arg_values, value_struct_values, dest],
1608
1640
  device=device,
1609
1641
  )
1610
1642
 
1611
1643
 
1644
+ @integrand
1645
+ def _identity_field(field: Field, s: Sample):
1646
+ return field(s)
1647
+
1648
+
1612
1649
  def interpolate(
1613
- integrand: Integrand,
1650
+ integrand: Union[Integrand, FieldLike],
1614
1651
  dest: Optional[Union[DiscreteField, FieldRestriction, wp.array]] = None,
1615
1652
  quadrature: Optional[Quadrature] = None,
1653
+ dim: int = 0,
1654
+ domain: Optional[Domain] = None,
1616
1655
  fields: Optional[Dict[str, FieldLike]] = None,
1617
1656
  values: Optional[Dict[str, Any]] = None,
1618
1657
  device=None,
@@ -1622,18 +1661,26 @@ def interpolate(
1622
1661
  Interpolates a function at a finite set of sample points and optionally assigns the result to a discrete field or a raw warp array.
1623
1662
 
1624
1663
  Args:
1625
- integrand: Function to be interpolated, must have :func:`integrand` decorator
1664
+ integrand: Function to be interpolated: either a function with :func:`warp.fem.integrand` decorator or a field
1626
1665
  dest: Where to store the interpolation result. Can be either
1627
1666
 
1628
1667
  - a :class:`DiscreteField`, or restriction of a discrete field to a domain (from :func:`make_restriction`). In this case, interpolation will be performed at each node.
1629
- - a normal warp array. In this case, the `quadrature` argument defining the interpolation locations must be provided and the result of the `integrand` at each quadrature point will be assigned to the array.
1630
- - ``None``. In this case, the `quadrature` argument must also be provided and the `integrand` function is responsible for dealing with the interpolation result.
1668
+ - a normal warp ``array``, or ``None``. In this case, the interpolation samples will determined by the `quadrature` or `dim` arguments, in that order.
1631
1669
  quadrature: Quadrature formula defining the interpolation samples if `dest` is not a discrete field or field restriction.
1670
+ dim: Number of interpolation samples if `dest` is not a discrete field or restriction and `quadrature` is ``None``.
1671
+ In this case, the ``Sample`` passed to the `integrand` will be invalid, but the sample point index ``s.qp_index`` can be used to define custom interpolation logic.
1672
+ domain: Interpolation domain, only used if `dest` is not a field restriction and `quadrature` is ``None``
1632
1673
  fields: Discrete fields to be passed to the integrand. Keys in the dictionary must match integrand parameters names.
1633
1674
  values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
1634
1675
  device: Device on which to perform the interpolation
1635
1676
  kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)
1636
1677
  """
1678
+
1679
+ if isinstance(integrand, FieldLike):
1680
+ fields = {"field": integrand}
1681
+ values = {}
1682
+ integrand = _identity_field
1683
+
1637
1684
  if fields is None:
1638
1685
  fields = {}
1639
1686
 
@@ -1651,14 +1698,11 @@ def interpolate(
1651
1698
  raise ValueError("Test or Trial fields should not be used for interpolation")
1652
1699
 
1653
1700
  if isinstance(dest, DiscreteField):
1654
- dest = make_restriction(dest)
1701
+ dest = make_restriction(dest, domain=domain)
1655
1702
 
1656
1703
  if isinstance(dest, FieldRestriction):
1657
1704
  domain = dest.domain
1658
- else:
1659
- if quadrature is None:
1660
- raise ValueError("When not interpolating to a field, a quadrature formula must be provided")
1661
-
1705
+ elif quadrature is not None:
1662
1706
  domain = quadrature.domain
1663
1707
 
1664
1708
  kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
@@ -1678,6 +1722,7 @@ def interpolate(
1678
1722
  domain=domain,
1679
1723
  dest=dest,
1680
1724
  quadrature=quadrature,
1725
+ dim=dim,
1681
1726
  fields=fields,
1682
1727
  values=values,
1683
1728
  device=device,