warp-lang 1.2.1__py3-none-macosx_10_13_universal2.whl → 1.3.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (194)
  1. warp/__init__.py +8 -6
  2. warp/autograd.py +823 -0
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +6 -2
  6. warp/builtins.py +1410 -886
  7. warp/codegen.py +503 -166
  8. warp/config.py +48 -18
  9. warp/context.py +401 -199
  10. warp/dlpack.py +8 -0
  11. warp/examples/assets/bunny.usd +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
  13. warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
  14. warp/examples/benchmarks/benchmark_launches.py +1 -1
  15. warp/examples/core/example_cupy.py +78 -0
  16. warp/examples/fem/example_apic_fluid.py +17 -36
  17. warp/examples/fem/example_burgers.py +9 -18
  18. warp/examples/fem/example_convection_diffusion.py +7 -17
  19. warp/examples/fem/example_convection_diffusion_dg.py +27 -47
  20. warp/examples/fem/example_deformed_geometry.py +11 -22
  21. warp/examples/fem/example_diffusion.py +7 -18
  22. warp/examples/fem/example_diffusion_3d.py +24 -28
  23. warp/examples/fem/example_diffusion_mgpu.py +7 -14
  24. warp/examples/fem/example_magnetostatics.py +190 -0
  25. warp/examples/fem/example_mixed_elasticity.py +111 -80
  26. warp/examples/fem/example_navier_stokes.py +30 -34
  27. warp/examples/fem/example_nonconforming_contact.py +290 -0
  28. warp/examples/fem/example_stokes.py +17 -32
  29. warp/examples/fem/example_stokes_transfer.py +12 -21
  30. warp/examples/fem/example_streamlines.py +350 -0
  31. warp/examples/fem/utils.py +936 -0
  32. warp/fabric.py +5 -2
  33. warp/fem/__init__.py +13 -3
  34. warp/fem/cache.py +161 -11
  35. warp/fem/dirichlet.py +37 -28
  36. warp/fem/domain.py +105 -14
  37. warp/fem/field/__init__.py +14 -3
  38. warp/fem/field/field.py +454 -11
  39. warp/fem/field/nodal_field.py +33 -18
  40. warp/fem/geometry/deformed_geometry.py +50 -15
  41. warp/fem/geometry/hexmesh.py +12 -24
  42. warp/fem/geometry/nanogrid.py +106 -31
  43. warp/fem/geometry/quadmesh_2d.py +6 -11
  44. warp/fem/geometry/tetmesh.py +103 -61
  45. warp/fem/geometry/trimesh_2d.py +98 -47
  46. warp/fem/integrate.py +231 -186
  47. warp/fem/operator.py +14 -9
  48. warp/fem/quadrature/pic_quadrature.py +35 -9
  49. warp/fem/quadrature/quadrature.py +119 -32
  50. warp/fem/space/basis_space.py +98 -22
  51. warp/fem/space/collocated_function_space.py +3 -1
  52. warp/fem/space/function_space.py +7 -2
  53. warp/fem/space/grid_2d_function_space.py +3 -3
  54. warp/fem/space/grid_3d_function_space.py +4 -4
  55. warp/fem/space/hexmesh_function_space.py +3 -2
  56. warp/fem/space/nanogrid_function_space.py +12 -14
  57. warp/fem/space/partition.py +45 -47
  58. warp/fem/space/restriction.py +19 -16
  59. warp/fem/space/shape/cube_shape_function.py +91 -3
  60. warp/fem/space/shape/shape_function.py +7 -0
  61. warp/fem/space/shape/square_shape_function.py +32 -0
  62. warp/fem/space/shape/tet_shape_function.py +11 -7
  63. warp/fem/space/shape/triangle_shape_function.py +10 -1
  64. warp/fem/space/topology.py +116 -42
  65. warp/fem/types.py +8 -1
  66. warp/fem/utils.py +301 -83
  67. warp/native/array.h +16 -0
  68. warp/native/builtin.h +0 -15
  69. warp/native/cuda_util.cpp +14 -6
  70. warp/native/exports.h +1348 -1308
  71. warp/native/quat.h +79 -0
  72. warp/native/rand.h +27 -4
  73. warp/native/sparse.cpp +83 -81
  74. warp/native/sparse.cu +381 -453
  75. warp/native/vec.h +64 -0
  76. warp/native/volume.cpp +40 -49
  77. warp/native/volume_builder.cu +2 -3
  78. warp/native/volume_builder.h +12 -17
  79. warp/native/warp.cu +3 -3
  80. warp/native/warp.h +69 -59
  81. warp/render/render_opengl.py +17 -9
  82. warp/sim/articulation.py +117 -17
  83. warp/sim/collide.py +35 -29
  84. warp/sim/model.py +123 -18
  85. warp/sim/render.py +3 -1
  86. warp/sparse.py +867 -203
  87. warp/stubs.py +312 -541
  88. warp/tape.py +29 -1
  89. warp/tests/disabled_kinematics.py +1 -1
  90. warp/tests/test_adam.py +1 -1
  91. warp/tests/test_arithmetic.py +1 -1
  92. warp/tests/test_array.py +58 -1
  93. warp/tests/test_array_reduce.py +1 -1
  94. warp/tests/test_async.py +1 -1
  95. warp/tests/test_atomic.py +1 -1
  96. warp/tests/test_bool.py +1 -1
  97. warp/tests/test_builtins_resolution.py +1 -1
  98. warp/tests/test_bvh.py +6 -1
  99. warp/tests/test_closest_point_edge_edge.py +1 -1
  100. warp/tests/test_codegen.py +66 -1
  101. warp/tests/test_compile_consts.py +1 -1
  102. warp/tests/test_conditional.py +1 -1
  103. warp/tests/test_copy.py +1 -1
  104. warp/tests/test_ctypes.py +1 -1
  105. warp/tests/test_dense.py +1 -1
  106. warp/tests/test_devices.py +1 -1
  107. warp/tests/test_dlpack.py +1 -1
  108. warp/tests/test_examples.py +33 -4
  109. warp/tests/test_fabricarray.py +5 -2
  110. warp/tests/test_fast_math.py +1 -1
  111. warp/tests/test_fem.py +213 -6
  112. warp/tests/test_fp16.py +1 -1
  113. warp/tests/test_func.py +1 -1
  114. warp/tests/test_future_annotations.py +90 -0
  115. warp/tests/test_generics.py +1 -1
  116. warp/tests/test_grad.py +1 -1
  117. warp/tests/test_grad_customs.py +1 -1
  118. warp/tests/test_grad_debug.py +247 -0
  119. warp/tests/test_hash_grid.py +6 -1
  120. warp/tests/test_implicit_init.py +354 -0
  121. warp/tests/test_import.py +1 -1
  122. warp/tests/test_indexedarray.py +1 -1
  123. warp/tests/test_intersect.py +1 -1
  124. warp/tests/test_jax.py +1 -1
  125. warp/tests/test_large.py +1 -1
  126. warp/tests/test_launch.py +1 -1
  127. warp/tests/test_lerp.py +1 -1
  128. warp/tests/test_linear_solvers.py +1 -1
  129. warp/tests/test_lvalue.py +1 -1
  130. warp/tests/test_marching_cubes.py +5 -2
  131. warp/tests/test_mat.py +34 -35
  132. warp/tests/test_mat_lite.py +2 -1
  133. warp/tests/test_mat_scalar_ops.py +1 -1
  134. warp/tests/test_math.py +1 -1
  135. warp/tests/test_matmul.py +20 -16
  136. warp/tests/test_matmul_lite.py +1 -1
  137. warp/tests/test_mempool.py +1 -1
  138. warp/tests/test_mesh.py +5 -2
  139. warp/tests/test_mesh_query_aabb.py +1 -1
  140. warp/tests/test_mesh_query_point.py +1 -1
  141. warp/tests/test_mesh_query_ray.py +1 -1
  142. warp/tests/test_mlp.py +1 -1
  143. warp/tests/test_model.py +1 -1
  144. warp/tests/test_module_hashing.py +77 -1
  145. warp/tests/test_modules_lite.py +1 -1
  146. warp/tests/test_multigpu.py +1 -1
  147. warp/tests/test_noise.py +1 -1
  148. warp/tests/test_operators.py +1 -1
  149. warp/tests/test_options.py +1 -1
  150. warp/tests/test_overwrite.py +542 -0
  151. warp/tests/test_peer.py +1 -1
  152. warp/tests/test_pinned.py +1 -1
  153. warp/tests/test_print.py +1 -1
  154. warp/tests/test_quat.py +15 -1
  155. warp/tests/test_rand.py +1 -1
  156. warp/tests/test_reload.py +1 -1
  157. warp/tests/test_rounding.py +1 -1
  158. warp/tests/test_runlength_encode.py +1 -1
  159. warp/tests/test_scalar_ops.py +95 -0
  160. warp/tests/test_sim_grad.py +1 -1
  161. warp/tests/test_sim_kinematics.py +1 -1
  162. warp/tests/test_smoothstep.py +1 -1
  163. warp/tests/test_sparse.py +82 -15
  164. warp/tests/test_spatial.py +1 -1
  165. warp/tests/test_special_values.py +2 -11
  166. warp/tests/test_streams.py +11 -1
  167. warp/tests/test_struct.py +1 -1
  168. warp/tests/test_tape.py +1 -1
  169. warp/tests/test_torch.py +194 -1
  170. warp/tests/test_transient_module.py +1 -1
  171. warp/tests/test_types.py +1 -1
  172. warp/tests/test_utils.py +1 -1
  173. warp/tests/test_vec.py +15 -63
  174. warp/tests/test_vec_lite.py +2 -1
  175. warp/tests/test_vec_scalar_ops.py +122 -39
  176. warp/tests/test_verify_fp.py +1 -1
  177. warp/tests/test_volume.py +28 -2
  178. warp/tests/test_volume_write.py +1 -1
  179. warp/tests/unittest_serial.py +1 -1
  180. warp/tests/unittest_suites.py +9 -1
  181. warp/tests/walkthrough_debug.py +1 -1
  182. warp/thirdparty/unittest_parallel.py +2 -5
  183. warp/torch.py +103 -41
  184. warp/types.py +344 -227
  185. warp/utils.py +11 -2
  186. {warp_lang-1.2.1.dist-info → warp_lang-1.3.0.dist-info}/METADATA +99 -46
  187. warp_lang-1.3.0.dist-info/RECORD +368 -0
  188. warp/examples/fem/bsr_utils.py +0 -378
  189. warp/examples/fem/mesh_utils.py +0 -133
  190. warp/examples/fem/plot_utils.py +0 -292
  191. warp_lang-1.2.1.dist-info/RECORD +0 -359
  192. {warp_lang-1.2.1.dist-info → warp_lang-1.3.0.dist-info}/LICENSE.md +0 -0
  193. {warp_lang-1.2.1.dist-info → warp_lang-1.3.0.dist-info}/WHEEL +0 -0
  194. {warp_lang-1.2.1.dist-info → warp_lang-1.3.0.dist-info}/top_level.txt +0 -0
warp/fem/operator.py CHANGED
@@ -1,9 +1,8 @@
1
- import inspect
2
1
  from typing import Any, Callable
3
2
 
4
3
  import warp as wp
5
4
  from warp.fem import utils
6
- from warp.fem.types import Domain, Field, Sample
5
+ from warp.fem.types import Domain, Field, NodeIndex, Sample
7
6
 
8
7
 
9
8
  class Integrand:
@@ -15,7 +14,7 @@ class Integrand:
15
14
  self.func = func
16
15
  self.name = wp.codegen.make_full_qualified_name(self.func)
17
16
  self.module = wp.get_module(self.func.__module__)
18
- self.argspec = inspect.getfullargspec(self.func)
17
+ self.argspec = wp.codegen.get_full_arg_spec(self.func)
19
18
 
20
19
 
21
20
  class Operator:
@@ -55,7 +54,7 @@ def position(domain: Domain, s: Sample):
55
54
  pass
56
55
 
57
56
 
58
- @operator(resolver=lambda dmn: dmn.eval_normal)
57
+ @operator(resolver=lambda dmn: dmn.element_normal)
59
58
  def normal(domain: Domain, s: Sample):
60
59
  """Evaluates the element normal at the sample point `s`. Null for interior points."""
61
60
  pass
@@ -71,13 +70,12 @@ def deformation_gradient(domain: Domain, s: Sample):
71
70
  def lookup(domain: Domain, x: Any) -> Sample:
72
71
  """Looks-up the sample point corresponding to a world position `x`, projecting to the closest point on the domain.
73
72
 
74
- Arg:
73
+ Args:
75
74
  x: world position of the point to look-up in the geometry
76
75
  guess: (optional) :class:`Sample` initial guess, may help perform the query
77
76
 
78
- Notes:
79
- Currently this operator is only fully supported for :class:`Grid2D` and :class:`Grid3D` geometries.
80
- For :class:`TriangleMesh2D` and :class:`Tetmesh` geometries, the operator requires providing `guess`.
77
+ Note:
78
+ Currently this operator is unsupported for :class:`Hexmesh`, :class:`Quadmesh2D` and deformed geometries.
81
79
  """
82
80
  pass
83
81
 
@@ -142,7 +140,14 @@ def degree(f: Field):
142
140
 
143
141
  @operator(resolver=lambda f: f.at_node)
144
142
  def at_node(f: Field, s: Sample):
145
- """For a Test or Trial field, returns a copy of the Sample `s` moved to the coordinates of the node being evaluated"""
143
+ """For a Test or Trial field `f`, returns a copy of the Sample `s` moved to the coordinates of the node being evaluated"""
144
+ pass
145
+
146
+
147
+ @operator(resolver=lambda f: f.node_partition_index)
148
+ def node_partition_index(f: Field, node_index: NodeIndex):
149
+ """For a NodalField `f`, returns the index of a given node in the fields's space partition,
150
+ or ``NULL_NODE_INDEX`` if it does not exists"""
146
151
  pass
147
152
 
148
153
 
@@ -8,8 +8,6 @@ from warp.fem.utils import compress_node_indices
8
8
 
9
9
  from .quadrature import Quadrature
10
10
 
11
- wp.set_module_options({"enable_backward": False})
12
-
13
11
 
14
12
  class PicQuadrature(Quadrature):
15
13
  """Particle-based quadrature formula, using a global set of points unevenly spread out over geometry elements.
@@ -23,6 +21,7 @@ class PicQuadrature(Quadrature):
23
21
  define a global :meth:`Geometry.cell_lookup` method; currently this is only available for :class:`Grid2D` and :class:`Grid3D`.
24
22
  measures: Array containing the measure (area/volume) of each particle, used to defined the integration weights.
25
23
  If ``None``, defaults to the cell measure divided by the number of particles in the cell.
24
+ requires_grad: Whether gradients should be allocated for the computed quantities
26
25
  temporary_store: shared pool from which to allocate temporary arrays
27
26
  """
28
27
 
@@ -37,11 +36,14 @@ class PicQuadrature(Quadrature):
37
36
  ],
38
37
  ],
39
38
  measures: Optional["wp.array(dtype=float)"] = None,
39
+ requires_grad: bool = False,
40
40
  temporary_store: TemporaryStore = None,
41
41
  ):
42
42
  super().__init__(domain)
43
43
 
44
+ self._requires_grad = requires_grad
44
45
  self._bin_particles(positions, measures, temporary_store)
46
+ self._max_particles_per_cell: int = None
45
47
 
46
48
  @property
47
49
  def name(self):
@@ -82,22 +84,40 @@ class PicQuadrature(Quadrature):
82
84
  """Number of cells containing at least one particle"""
83
85
  return self._cell_count
84
86
 
87
+ def max_points_per_element(self):
88
+ if self._max_particles_per_cell is None:
89
+ max_ppc = wp.zeros(shape=(1,), dtype=int, device=self._cell_particle_offsets.array.device)
90
+ wp.launch(
91
+ PicQuadrature._max_particles_per_cell_kernel,
92
+ self._cell_particle_offsets.array.shape[0] - 1,
93
+ device=max_ppc.device,
94
+ inputs=[self._cell_particle_offsets.array, max_ppc],
95
+ )
96
+ self._max_particles_per_cell = int(max_ppc.numpy()[0])
97
+ return self._max_particles_per_cell
98
+
85
99
  @wp.func
86
- def point_count(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex):
100
+ def point_count(elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex):
87
101
  return qp_arg.cell_particle_offsets[element_index + 1] - qp_arg.cell_particle_offsets[element_index]
88
102
 
89
103
  @wp.func
90
- def point_coords(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, index: int):
104
+ def point_coords(
105
+ elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, index: int
106
+ ):
91
107
  particle_index = qp_arg.cell_particle_indices[qp_arg.cell_particle_offsets[element_index] + index]
92
108
  return qp_arg.particle_coords[particle_index]
93
109
 
94
110
  @wp.func
95
- def point_weight(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, index: int):
111
+ def point_weight(
112
+ elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, index: int
113
+ ):
96
114
  particle_index = qp_arg.cell_particle_indices[qp_arg.cell_particle_offsets[element_index] + index]
97
115
  return qp_arg.particle_fraction[particle_index]
98
116
 
99
117
  @wp.func
100
- def point_index(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, index: int):
118
+ def point_index(
119
+ elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, index: int
120
+ ):
101
121
  particle_index = qp_arg.cell_particle_indices[qp_arg.cell_particle_offsets[element_index] + index]
102
122
  return particle_index
103
123
 
@@ -158,7 +178,7 @@ class PicQuadrature(Quadrature):
158
178
  cell_index = cell_index_temp.array
159
179
 
160
180
  self._particle_coords_temp = borrow_temporary(
161
- temporary_store, shape=positions.shape, dtype=Coords, device=device
181
+ temporary_store, shape=positions.shape, dtype=Coords, device=device, requires_grad=self._requires_grad
162
182
  )
163
183
  self._particle_coords = self._particle_coords_temp.array
164
184
 
@@ -183,7 +203,7 @@ class PicQuadrature(Quadrature):
183
203
  self._particle_coords_temp = None
184
204
 
185
205
  self._cell_particle_offsets, self._cell_particle_indices, self._cell_count, _ = compress_node_indices(
186
- self.domain.geometry_element_count(), cell_index
206
+ self.domain.geometry_element_count(), cell_index, return_unique_nodes=True, temporary_store=temporary_store
187
207
  )
188
208
 
189
209
  self._compute_fraction(cell_index, measures, temporary_store)
@@ -192,7 +212,7 @@ class PicQuadrature(Quadrature):
192
212
  device = cell_index.device
193
213
 
194
214
  self._particle_fraction_temp = borrow_temporary(
195
- temporary_store, shape=cell_index.shape, dtype=float, device=device
215
+ temporary_store, shape=cell_index.shape, dtype=float, device=device, requires_grad=self._requires_grad
196
216
  )
197
217
  self._particle_fraction = self._particle_fraction_temp.array
198
218
 
@@ -241,3 +261,9 @@ class PicQuadrature(Quadrature):
241
261
  ],
242
262
  device=device,
243
263
  )
264
+
265
+ @wp.kernel
266
+ def _max_particles_per_cell_kernel(offsets: wp.array(dtype=int), max_count: wp.array(dtype=int)):
267
+ cell = wp.tid()
268
+ particle_count = offsets[cell + 1] - offsets[cell]
269
+ wp.atomic_max(max_count, 0, particle_count)
@@ -36,8 +36,8 @@ class Quadrature:
36
36
  """Total number of quadrature points over the domain"""
37
37
  raise NotImplementedError()
38
38
 
39
- def points_per_element(self):
40
- """Number of points per element if constant, or ``None`` if varying"""
39
+ def max_points_per_element(self):
40
+ """Maximum number of points per element if known, or ``None`` otherwise"""
41
41
  return None
42
42
 
43
43
  @staticmethod
@@ -61,7 +61,11 @@ class Quadrature:
61
61
 
62
62
  @staticmethod
63
63
  def point_index(
64
- elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex, qp_index: int
64
+ elt_arg: "domain.GeometryDomain.ElementArg",
65
+ qp_arg: Arg,
66
+ domain_element_index: ElementIndex,
67
+ geo_element_index: ElementIndex,
68
+ element_qp_index: int,
65
69
  ):
66
70
  """Global index of the element's qp_index'th quadrature point"""
67
71
  raise NotImplementedError()
@@ -106,7 +110,7 @@ class RegularQuadrature(Quadrature):
106
110
  def total_point_count(self):
107
111
  return len(self.points) * self.domain.geometry_element_count()
108
112
 
109
- def points_per_element(self):
113
+ def max_points_per_element(self):
110
114
  return self._N
111
115
 
112
116
  @property
@@ -121,7 +125,12 @@ class RegularQuadrature(Quadrature):
121
125
  N = self._N
122
126
 
123
127
  @cache.dynamic_func(suffix=self.name)
124
- def point_count(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex):
128
+ def point_count(
129
+ elt_arg: self.domain.ElementArg,
130
+ qp_arg: self.Arg,
131
+ domain_element_index: ElementIndex,
132
+ element_index: ElementIndex,
133
+ ):
125
134
  return N
126
135
 
127
136
  return point_count
@@ -130,7 +139,13 @@ class RegularQuadrature(Quadrature):
130
139
  POINTS = self._POINTS
131
140
 
132
141
  @cache.dynamic_func(suffix=self.name)
133
- def point_coords(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex, qp_index: int):
142
+ def point_coords(
143
+ elt_arg: self.domain.ElementArg,
144
+ qp_arg: self.Arg,
145
+ domain_element_index: ElementIndex,
146
+ element_index: ElementIndex,
147
+ qp_index: int,
148
+ ):
134
149
  return Coords(POINTS[qp_index, 0], POINTS[qp_index, 1], POINTS[qp_index, 2])
135
150
 
136
151
  return point_coords
@@ -139,7 +154,13 @@ class RegularQuadrature(Quadrature):
139
154
  WEIGHTS = self._WEIGHTS
140
155
 
141
156
  @cache.dynamic_func(suffix=self.name)
142
- def point_weight(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex, qp_index: int):
157
+ def point_weight(
158
+ elt_arg: self.domain.ElementArg,
159
+ qp_arg: self.Arg,
160
+ domain_element_index: ElementIndex,
161
+ element_index: ElementIndex,
162
+ qp_index: int,
163
+ ):
143
164
  return WEIGHTS[qp_index]
144
165
 
145
166
  return point_weight
@@ -148,8 +169,14 @@ class RegularQuadrature(Quadrature):
148
169
  N = self._N
149
170
 
150
171
  @cache.dynamic_func(suffix=self.name)
151
- def point_index(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex, qp_index: int):
152
- return N * element_index + qp_index
172
+ def point_index(
173
+ elt_arg: self.domain.ElementArg,
174
+ qp_arg: self.Arg,
175
+ domain_element_index: ElementIndex,
176
+ element_index: ElementIndex,
177
+ qp_index: int,
178
+ ):
179
+ return N * domain_element_index + qp_index
153
180
 
154
181
  return point_index
155
182
 
@@ -157,8 +184,8 @@ class RegularQuadrature(Quadrature):
157
184
  class NodalQuadrature(Quadrature):
158
185
  """Quadrature using space node points as quadrature points
159
186
 
160
- Note that in contrast to the `nodal=True` flag for :func:`integrate`, this quadrature odes not make any assumption
161
- about orthogonality of shape functions, and is thus safe to use for arbitrary integrands.
187
+ Note that in contrast to the `nodal=True` flag for :func:`integrate`, using this quadrature does not imply
188
+ any assumption about orthogonality of shape functions, and is thus safe to use for arbitrary integrands.
162
189
  """
163
190
 
164
191
  def __init__(self, domain: domain.GeometryDomain, space: FunctionSpace):
@@ -180,8 +207,8 @@ class NodalQuadrature(Quadrature):
180
207
  def total_point_count(self):
181
208
  return self._space.node_count()
182
209
 
183
- def points_per_element(self):
184
- return self._space.topology.NODES_PER_ELEMENT
210
+ def max_points_per_element(self):
211
+ return self._space.topology.MAX_NODES_PER_ELEMENT
185
212
 
186
213
  def _make_arg(self):
187
214
  @cache.dynamic_struct(suffix=self.name)
@@ -199,44 +226,67 @@ class NodalQuadrature(Quadrature):
199
226
  return arg
200
227
 
201
228
  def _make_point_count(self):
202
- N = self._space.topology.NODES_PER_ELEMENT
203
-
204
229
  @cache.dynamic_func(suffix=self.name)
205
- def point_count(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex):
206
- return N
230
+ def point_count(
231
+ elt_arg: self.domain.ElementArg,
232
+ qp_arg: self.Arg,
233
+ domain_element_index: ElementIndex,
234
+ element_index: ElementIndex,
235
+ ):
236
+ return self._space.topology.element_node_count(elt_arg, qp_arg.topo_arg, element_index)
207
237
 
208
238
  return point_count
209
239
 
210
240
  def _make_point_coords(self):
211
241
  @cache.dynamic_func(suffix=self.name)
212
- def point_coords(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex, qp_index: int):
242
+ def point_coords(
243
+ elt_arg: self.domain.ElementArg,
244
+ qp_arg: self.Arg,
245
+ domain_element_index: ElementIndex,
246
+ element_index: ElementIndex,
247
+ qp_index: int,
248
+ ):
213
249
  return self._space.node_coords_in_element(elt_arg, qp_arg.space_arg, element_index, qp_index)
214
250
 
215
251
  return point_coords
216
252
 
217
253
  def _make_point_weight(self):
218
254
  @cache.dynamic_func(suffix=self.name)
219
- def point_weight(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex, qp_index: int):
255
+ def point_weight(
256
+ elt_arg: self.domain.ElementArg,
257
+ qp_arg: self.Arg,
258
+ domain_element_index: ElementIndex,
259
+ element_index: ElementIndex,
260
+ qp_index: int,
261
+ ):
220
262
  return self._space.node_quadrature_weight(elt_arg, qp_arg.space_arg, element_index, qp_index)
221
263
 
222
264
  return point_weight
223
265
 
224
266
  def _make_point_index(self):
225
267
  @cache.dynamic_func(suffix=self.name)
226
- def point_index(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex, qp_index: int):
268
+ def point_index(
269
+ elt_arg: self.domain.ElementArg,
270
+ qp_arg: self.Arg,
271
+ domain_element_index: ElementIndex,
272
+ element_index: ElementIndex,
273
+ qp_index: int,
274
+ ):
227
275
  return self._space.topology.element_node_index(elt_arg, qp_arg.topo_arg, element_index, qp_index)
228
276
 
229
277
  return point_index
230
278
 
231
279
 
232
280
  class ExplicitQuadrature(Quadrature):
233
- """Quadrature using explicit per-cell points and weights. The number of quadrature points per cell is assumed
234
- to be constant and deduced from the shape of the points and weights arrays.
281
+ """Quadrature using explicit per-cell points and weights.
282
+
283
+ The number of quadrature points per cell is assumed to be constant and deduced from the shape of the points and weights arrays.
284
+ Quadrature points may be provided for either the whole geometry or just the domain's elements.
235
285
 
236
286
  Args:
237
287
  domain: Domain of definition of the quadrature formula
238
- points: 2d array of shape ``(domain.geometry_element-count(), points_per_cell)`` containing the coordinates of each quadrature point.
239
- weights: 2d array of shape ``(domain.geometry_element-count(), points_per_cell)`` containing the weight for each quadrature point.
288
+ points: 2d array of shape ``(domain.element_count(), points_per_cell)`` or ``(domain.geometry_element_count(), points_per_cell)`` containing the coordinates of each quadrature point.
289
+ weights: 2d array of shape ``(domain.element_count(), points_per_cell)`` or ``(domain.geometry_element_count(), points_per_cell)`` containing the weight for each quadrature point.
240
290
 
241
291
  See also: :class:`PicQuadrature`
242
292
  """
@@ -255,41 +305,78 @@ class ExplicitQuadrature(Quadrature):
255
305
  if points.shape != weights.shape:
256
306
  raise ValueError("Points and weights arrays must have the same shape")
257
307
 
308
+ if points.shape[0] == domain.geometry_element_count():
309
+ self.point_index = ExplicitQuadrature._point_index_geo
310
+ self.point_coords = ExplicitQuadrature._point_coords_geo
311
+ self.point_weight = ExplicitQuadrature._point_weight_geo
312
+ elif points.shape[0] == domain.element_count():
313
+ self.point_index = ExplicitQuadrature._point_index_domain
314
+ self.point_coords = ExplicitQuadrature._point_coords_domain
315
+ self.point_weight = ExplicitQuadrature._point_weight_domain
316
+ else:
317
+ raise NotImplementedError(
318
+ "The number of rows of points and weights must match the element count of either the domain or the geometry"
319
+ )
320
+
258
321
  self._points_per_cell = points.shape[1]
322
+ self._whole_geo = points.shape[0] == domain.geometry_element_count()
259
323
  self._points = points
260
324
  self._weights = weights
261
325
 
262
326
  @property
263
327
  def name(self):
264
- return f"{self.__class__.__name__}"
328
+ return f"{self.__class__.__name__}_{self._whole_geo}"
265
329
 
266
330
  def total_point_count(self):
267
331
  return self._weights.size
268
332
 
269
- def points_per_element(self):
333
+ def max_points_per_element(self):
270
334
  return self._points_per_cell
271
335
 
272
336
  @cache.cached_arg_value
273
337
  def arg_value(self, device):
274
338
  arg = self.Arg()
275
- arg.points_per_cell = self._points_per_cell
276
339
  arg.points = self._points.to(device)
277
340
  arg.weights = self._weights.to(device)
278
341
 
279
342
  return arg
280
343
 
281
344
  @wp.func
282
- def point_count(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex):
283
- return qp_arg.points_per_cell
345
+ def point_count(elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex):
346
+ return qp_arg.points.shape[1]
284
347
 
285
348
  @wp.func
286
- def point_coords(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, qp_index: int):
349
+ def _point_coords_domain(
350
+ elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
351
+ ):
352
+ return qp_arg.points[domain_element_index, qp_index]
353
+
354
+ @wp.func
355
+ def _point_weight_domain(
356
+ elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
357
+ ):
358
+ return qp_arg.weights[domain_element_index, qp_index]
359
+
360
+ @wp.func
361
+ def _point_index_domain(
362
+ elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
363
+ ):
364
+ return qp_arg.points_per_cell * domain_element_index + qp_index
365
+
366
+ @wp.func
367
+ def _point_coords_geo(
368
+ elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
369
+ ):
287
370
  return qp_arg.points[element_index, qp_index]
288
371
 
289
372
  @wp.func
290
- def point_weight(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, qp_index: int):
373
+ def _point_weight_geo(
374
+ elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
375
+ ):
291
376
  return qp_arg.weights[element_index, qp_index]
292
377
 
293
378
  @wp.func
294
- def point_index(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, qp_index: int):
379
+ def _point_index_geo(
380
+ elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
381
+ ):
295
382
  return qp_arg.points_per_cell * element_index + qp_index