warp-lang 1.0.0b5__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187) hide show
  1. docs/conf.py +3 -4
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/example_dem.py +28 -26
  6. examples/example_diffray.py +37 -30
  7. examples/example_fluid.py +7 -3
  8. examples/example_jacobian_ik.py +1 -1
  9. examples/example_mesh_intersect.py +10 -7
  10. examples/example_nvdb.py +3 -3
  11. examples/example_render_opengl.py +19 -10
  12. examples/example_sim_cartpole.py +9 -5
  13. examples/example_sim_cloth.py +29 -25
  14. examples/example_sim_fk_grad.py +2 -2
  15. examples/example_sim_fk_grad_torch.py +3 -3
  16. examples/example_sim_grad_bounce.py +11 -8
  17. examples/example_sim_grad_cloth.py +12 -9
  18. examples/example_sim_granular.py +2 -2
  19. examples/example_sim_granular_collision_sdf.py +13 -13
  20. examples/example_sim_neo_hookean.py +3 -3
  21. examples/example_sim_particle_chain.py +2 -2
  22. examples/example_sim_quadruped.py +8 -5
  23. examples/example_sim_rigid_chain.py +8 -5
  24. examples/example_sim_rigid_contact.py +13 -10
  25. examples/example_sim_rigid_fem.py +2 -2
  26. examples/example_sim_rigid_gyroscopic.py +2 -2
  27. examples/example_sim_rigid_kinematics.py +1 -1
  28. examples/example_sim_trajopt.py +3 -2
  29. examples/fem/example_apic_fluid.py +5 -7
  30. examples/fem/example_diffusion_mgpu.py +18 -16
  31. warp/__init__.py +3 -2
  32. warp/bin/warp.so +0 -0
  33. warp/build_dll.py +29 -9
  34. warp/builtins.py +206 -7
  35. warp/codegen.py +58 -38
  36. warp/config.py +3 -1
  37. warp/context.py +234 -128
  38. warp/fem/__init__.py +2 -2
  39. warp/fem/cache.py +2 -1
  40. warp/fem/field/nodal_field.py +18 -17
  41. warp/fem/geometry/hexmesh.py +11 -6
  42. warp/fem/geometry/quadmesh_2d.py +16 -12
  43. warp/fem/geometry/tetmesh.py +19 -8
  44. warp/fem/geometry/trimesh_2d.py +18 -7
  45. warp/fem/integrate.py +341 -196
  46. warp/fem/quadrature/__init__.py +1 -1
  47. warp/fem/quadrature/pic_quadrature.py +138 -53
  48. warp/fem/quadrature/quadrature.py +81 -9
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_space.py +169 -51
  51. warp/fem/space/grid_2d_function_space.py +2 -2
  52. warp/fem/space/grid_3d_function_space.py +2 -2
  53. warp/fem/space/hexmesh_function_space.py +2 -2
  54. warp/fem/space/partition.py +9 -6
  55. warp/fem/space/quadmesh_2d_function_space.py +2 -2
  56. warp/fem/space/shape/cube_shape_function.py +27 -15
  57. warp/fem/space/shape/square_shape_function.py +29 -18
  58. warp/fem/space/tetmesh_function_space.py +2 -2
  59. warp/fem/space/topology.py +10 -0
  60. warp/fem/space/trimesh_2d_function_space.py +2 -2
  61. warp/fem/utils.py +10 -5
  62. warp/native/array.h +49 -8
  63. warp/native/builtin.h +31 -14
  64. warp/native/cuda_util.cpp +8 -3
  65. warp/native/cuda_util.h +1 -0
  66. warp/native/exports.h +1177 -1108
  67. warp/native/intersect.h +4 -4
  68. warp/native/intersect_adj.h +8 -8
  69. warp/native/mat.h +65 -6
  70. warp/native/mesh.h +126 -5
  71. warp/native/quat.h +28 -4
  72. warp/native/vec.h +76 -14
  73. warp/native/warp.cu +1 -6
  74. warp/render/render_opengl.py +261 -109
  75. warp/sim/import_mjcf.py +13 -7
  76. warp/sim/import_urdf.py +14 -14
  77. warp/sim/inertia.py +17 -18
  78. warp/sim/model.py +67 -67
  79. warp/sim/render.py +1 -1
  80. warp/sparse.py +6 -6
  81. warp/stubs.py +19 -81
  82. warp/tape.py +1 -1
  83. warp/tests/__main__.py +3 -6
  84. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  85. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  86. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  87. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  88. warp/tests/aux_test_unresolved_func.py +14 -0
  89. warp/tests/aux_test_unresolved_symbol.py +14 -0
  90. warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
  91. warp/tests/run_coverage_serial.py +31 -0
  92. warp/tests/test_adam.py +102 -106
  93. warp/tests/test_arithmetic.py +39 -40
  94. warp/tests/test_array.py +46 -48
  95. warp/tests/test_array_reduce.py +25 -19
  96. warp/tests/test_atomic.py +62 -26
  97. warp/tests/test_bool.py +16 -11
  98. warp/tests/test_builtins_resolution.py +1292 -0
  99. warp/tests/test_bvh.py +9 -12
  100. warp/tests/test_closest_point_edge_edge.py +53 -57
  101. warp/tests/test_codegen.py +164 -134
  102. warp/tests/test_compile_consts.py +13 -19
  103. warp/tests/test_conditional.py +30 -32
  104. warp/tests/test_copy.py +9 -12
  105. warp/tests/test_ctypes.py +90 -98
  106. warp/tests/test_dense.py +20 -14
  107. warp/tests/test_devices.py +34 -35
  108. warp/tests/test_dlpack.py +74 -75
  109. warp/tests/test_examples.py +215 -97
  110. warp/tests/test_fabricarray.py +15 -21
  111. warp/tests/test_fast_math.py +14 -11
  112. warp/tests/test_fem.py +280 -97
  113. warp/tests/test_fp16.py +19 -15
  114. warp/tests/test_func.py +177 -194
  115. warp/tests/test_generics.py +71 -77
  116. warp/tests/test_grad.py +83 -32
  117. warp/tests/test_grad_customs.py +7 -9
  118. warp/tests/test_hash_grid.py +6 -10
  119. warp/tests/test_import.py +9 -23
  120. warp/tests/test_indexedarray.py +19 -21
  121. warp/tests/test_intersect.py +15 -9
  122. warp/tests/test_large.py +17 -19
  123. warp/tests/test_launch.py +14 -17
  124. warp/tests/test_lerp.py +63 -63
  125. warp/tests/test_lvalue.py +84 -35
  126. warp/tests/test_marching_cubes.py +9 -13
  127. warp/tests/test_mat.py +388 -3004
  128. warp/tests/test_mat_lite.py +9 -12
  129. warp/tests/test_mat_scalar_ops.py +2889 -0
  130. warp/tests/test_math.py +10 -11
  131. warp/tests/test_matmul.py +104 -100
  132. warp/tests/test_matmul_lite.py +72 -98
  133. warp/tests/test_mesh.py +35 -32
  134. warp/tests/test_mesh_query_aabb.py +18 -25
  135. warp/tests/test_mesh_query_point.py +39 -23
  136. warp/tests/test_mesh_query_ray.py +9 -21
  137. warp/tests/test_mlp.py +8 -9
  138. warp/tests/test_model.py +89 -93
  139. warp/tests/test_modules_lite.py +15 -25
  140. warp/tests/test_multigpu.py +87 -114
  141. warp/tests/test_noise.py +10 -12
  142. warp/tests/test_operators.py +14 -21
  143. warp/tests/test_options.py +10 -11
  144. warp/tests/test_pinned.py +16 -18
  145. warp/tests/test_print.py +16 -20
  146. warp/tests/test_quat.py +121 -88
  147. warp/tests/test_rand.py +12 -13
  148. warp/tests/test_reload.py +27 -32
  149. warp/tests/test_rounding.py +7 -10
  150. warp/tests/test_runlength_encode.py +105 -106
  151. warp/tests/test_smoothstep.py +8 -9
  152. warp/tests/test_snippet.py +13 -22
  153. warp/tests/test_sparse.py +30 -29
  154. warp/tests/test_spatial.py +179 -174
  155. warp/tests/test_streams.py +100 -107
  156. warp/tests/test_struct.py +98 -67
  157. warp/tests/test_tape.py +11 -17
  158. warp/tests/test_torch.py +89 -86
  159. warp/tests/test_transient_module.py +9 -12
  160. warp/tests/test_types.py +328 -50
  161. warp/tests/test_utils.py +217 -218
  162. warp/tests/test_vec.py +133 -2133
  163. warp/tests/test_vec_lite.py +8 -11
  164. warp/tests/test_vec_scalar_ops.py +2099 -0
  165. warp/tests/test_volume.py +391 -382
  166. warp/tests/test_volume_write.py +122 -135
  167. warp/tests/unittest_serial.py +35 -0
  168. warp/tests/unittest_suites.py +291 -0
  169. warp/tests/{test_base.py → unittest_utils.py} +138 -25
  170. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  171. warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
  172. warp/thirdparty/unittest_parallel.py +257 -54
  173. warp/types.py +119 -98
  174. warp/utils.py +14 -0
  175. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
  176. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
  177. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  178. warp/tests/test_all.py +0 -239
  179. warp/tests/test_conditional_unequal_types_kernels.py +0 -14
  180. warp/tests/test_coverage.py +0 -38
  181. warp/tests/test_unresolved_func.py +0 -7
  182. warp/tests/test_unresolved_symbol.py +0 -7
  183. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  184. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  185. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  186. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  187. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -1,2 +1,2 @@
1
- from .quadrature import Quadrature, RegularQuadrature, NodalQuadrature
1
+ from .quadrature import Quadrature, RegularQuadrature, NodalQuadrature, ExplicitQuadrature
2
2
  from .pic_quadrature import PicQuadrature
@@ -1,11 +1,11 @@
1
- from typing import Any
1
+ from typing import Union, Tuple, Any, Optional
2
2
 
3
3
  import warp as wp
4
4
 
5
5
  from warp.fem.domain import GeometryDomain
6
- from warp.fem.types import ElementIndex, Coords
6
+ from warp.fem.types import ElementIndex, Coords, make_free_sample
7
7
  from warp.fem.utils import compress_node_indices
8
- from warp.fem.cache import cached_arg_value, TemporaryStore, borrow_temporary
8
+ from warp.fem.cache import cached_arg_value, TemporaryStore, borrow_temporary, dynamic_kernel
9
9
 
10
10
  from .quadrature import Quadrature
11
11
 
@@ -14,24 +14,36 @@ wp.set_module_options({"enable_backward": False})
14
14
 
15
15
 
16
16
  class PicQuadrature(Quadrature):
17
- """Particle-based quadrature formula, using a global set of points irregularely spread out over geometry elements.
17
+ """Particle-based quadrature formula, using a global set of points unevenly spread out over geometry elements.
18
18
 
19
19
  Useful for Particle-In-Cell and derived methods.
20
+
21
+ Args:
22
+ domain: Underlying domain for the quadrature
23
+ positions: Either an array containing the world positions of all particles, or a tuple of arrays containing
24
+ the cell indices and coordinates for each particle. Note that the former requires the underlying geometry to
25
+ define a global :meth:`Geometry.cell_lookup` method; currently this is only available for :class:`Grid2D` and :class:`Grid3D`.
26
+ measures: Array containing the measure (area/volume) of each particle, used to define the integration weights.
27
+ If ``None``, defaults to the cell measure divided by the number of particles in the cell.
28
+ temporary_store: shared pool from which to allocate temporary arrays
20
29
  """
21
30
 
22
31
  def __init__(
23
32
  self,
24
33
  domain: GeometryDomain,
25
- positions: "wp.array()",
26
- measures: "wp.array(dtype=float)",
34
+ positions: Union[
35
+ "wp.array(dtype=wp.vecXd)",
36
+ Tuple[
37
+ "wp.array(dtype=ElementIndex)",
38
+ "wp.array(dtype=Coords)",
39
+ ],
40
+ ],
41
+ measures: Optional["wp.array(dtype=float)"] = None,
27
42
  temporary_store: TemporaryStore = None,
28
43
  ):
29
44
  super().__init__(domain)
30
45
 
31
- self.positions = positions
32
- self.measures = measures
33
-
34
- self._bin_particles(temporary_store)
46
+ self._bin_particles(positions, measures, temporary_store)
35
47
 
36
48
  @property
37
49
  def name(self):
@@ -61,12 +73,16 @@ class PicQuadrature(Quadrature):
61
73
  arg = PicQuadrature.Arg()
62
74
  arg.cell_particle_offsets = self._cell_particle_offsets.array.to(device)
63
75
  arg.cell_particle_indices = self._cell_particle_indices.array.to(device)
64
- arg.particle_fraction = self._particle_fraction.array.to(device)
65
- arg.particle_coords = self._particle_coords.array.to(device)
76
+ arg.particle_fraction = self._particle_fraction.to(device)
77
+ arg.particle_coords = self._particle_coords.to(device)
66
78
  return arg
67
79
 
68
80
  def total_point_count(self):
69
- return self.positions.shape[0]
81
+ return self._particle_coords.shape[0]
82
+
83
+ def active_cell_count(self):
84
+ """Number of cells containing at least one particle"""
85
+ return self._cell_count
70
86
 
71
87
  @wp.func
72
88
  def point_count(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex):
@@ -109,52 +125,121 @@ class PicQuadrature(Quadrature):
109
125
  i = wp.tid()
110
126
  element_mask[i] = wp.select(element_particle_offsets[i] == element_particle_offsets[i + 1], 1, 0)
111
127
 
112
- def _bin_particles(self, temporary_store: TemporaryStore):
113
- from warp.fem import cache
114
-
115
- @cache.dynamic_kernel(suffix=f"{self.domain.name}")
116
- def bin_particles(
117
- cell_arg_value: self.domain.ElementArg,
118
- positions: wp.array(dtype=self.positions.dtype),
119
- measures: wp.array(dtype=float),
120
- cell_index: wp.array(dtype=ElementIndex),
121
- cell_coords: wp.array(dtype=Coords),
122
- cell_fraction: wp.array(dtype=float),
123
- ):
124
- p = wp.tid()
125
- sample = self.domain.element_lookup(cell_arg_value, positions[p])
128
@wp.kernel
def _compute_uniform_fraction(
    cell_index: wp.array(dtype=ElementIndex),
    cell_particle_offsets: wp.array(dtype=int),
    cell_fraction: wp.array(dtype=float),
):
    # Give each particle an equal share of its cell: 1 / (number of particles binned in that cell).
    # NOTE(review): assumes cell_index[p] is a valid (non-negative) cell index and that the
    # offsets come from binning these same particles, so the count is at least 1 — confirm.
    p = wp.tid()
    c = cell_index[p]

    first = cell_particle_offsets[c]
    last = cell_particle_offsets[c + 1]

    cell_fraction[p] = 1.0 / float(last - first)
131
140
 
132
- device = self.positions.device
141
def _bin_particles(self, positions, measures, temporary_store: TemporaryStore):
    """Assign each particle to a geometry cell and build the per-cell particle lists.

    Args:
        positions: Either an array of world positions, in which case cells are found with
            ``self.domain.element_lookup``, or a ``(cell_index, coords)`` tuple of already
            binned particles.
        measures: Optional per-particle measures, forwarded to ``_compute_fraction``.
        temporary_store: Pool from which temporary arrays are borrowed.

    Raises:
        ValueError: If, in the tuple form, the cell index and coordinates arrays have
            different shapes.
    """
    if wp.types.is_array(positions):
        # Initialize from positions
        @dynamic_kernel(suffix=f"{self.domain.name}")
        def bin_particles(
            cell_arg_value: self.domain.ElementArg,
            positions: wp.array(dtype=positions.dtype),
            cell_index: wp.array(dtype=ElementIndex),
            cell_coords: wp.array(dtype=Coords),
        ):
            p = wp.tid()
            sample = self.domain.element_lookup(cell_arg_value, positions[p])

            cell_index[p] = sample.element_index
            cell_coords[p] = sample.element_coords

        device = positions.device

        cell_index_temp = borrow_temporary(temporary_store, shape=positions.shape, dtype=int, device=device)
        cell_index = cell_index_temp.array

        # Particle coordinates outlive this method (they are read by arg_value), so the
        # temporary is kept on the instance rather than released here
        self._particle_coords_temp = borrow_temporary(
            temporary_store, shape=positions.shape, dtype=Coords, device=device
        )
        self._particle_coords = self._particle_coords_temp.array

        wp.launch(
            dim=positions.shape[0],
            kernel=bin_particles,
            inputs=[
                self.domain.element_arg_value(device),
                positions,
                cell_index,
                self._particle_coords,
            ],
            device=device,
        )

    else:
        # Pre-binned particles: (cell indices, cell coordinates)
        cell_index, self._particle_coords = positions
        if cell_index.shape != self._particle_coords.shape:
            raise ValueError("Cell index and coordinates arrays must have the same shape")

        cell_index_temp = None
        self._particle_coords_temp = None

    self._cell_particle_offsets, self._cell_particle_indices, self._cell_count, _ = compress_node_indices(
        self.domain.geometry_element_count(), cell_index
    )

    self._compute_fraction(cell_index, measures, temporary_store)

    # The per-particle cell index array is only needed to build offsets and fractions;
    # explicitly release the borrowed temporary (when we borrowed one) now that we are done with it.
    if cell_index_temp is not None:
        cell_index_temp.release()
192
+
193
def _compute_fraction(self, cell_index, measures, temporary_store: TemporaryStore):
    """Compute, for each particle, the fraction used to derive its integration weight.

    The result is stored in ``self._particle_fraction``. Its storage is borrowed from
    ``temporary_store`` and kept alive through ``self._particle_fraction_temp``.

    Args:
        cell_index: Per-particle cell indices.
        measures: Per-particle measures, or ``None`` to split each cell uniformly among
            the particles it contains.
        temporary_store: Pool from which the fraction array is borrowed.

    Raises:
        ValueError: If ``measures`` is given and does not have the same shape as ``cell_index``.
    """
    # Validate before borrowing, so that bad input does not take a temporary from the pool
    if measures is not None and measures.shape != cell_index.shape:
        raise ValueError("Measures should be a 1d array of length equal to particle count")

    device = cell_index.device

    self._particle_fraction_temp = borrow_temporary(
        temporary_store, shape=cell_index.shape, dtype=float, device=device
    )
    self._particle_fraction = self._particle_fraction_temp.array

    if measures is None:
        # Split fraction uniformly over all particles in cell
        wp.launch(
            dim=cell_index.shape,
            kernel=PicQuadrature._compute_uniform_fraction,
            inputs=[
                cell_index,
                self._cell_particle_offsets.array,
                self._particle_fraction,
            ],
            device=device,
        )

    else:
        # Fraction from particle measure: particle measure divided by
        # self.domain.element_measure evaluated at the particle's sample
        @dynamic_kernel(suffix=f"{self.domain.name}")
        def compute_fraction(
            cell_arg_value: self.domain.ElementArg,
            measures: wp.array(dtype=float),
            cell_index: wp.array(dtype=ElementIndex),
            cell_coords: wp.array(dtype=Coords),
            cell_fraction: wp.array(dtype=float),
        ):
            p = wp.tid()
            sample = make_free_sample(cell_index[p], cell_coords[p])

            cell_fraction[p] = measures[p] / self.domain.element_measure(cell_arg_value, sample)

        wp.launch(
            dim=measures.shape[0],
            kernel=compute_fraction,
            inputs=[
                self.domain.element_arg_value(device),
                measures,
                cell_index,
                self._particle_coords,
                self._particle_fraction,
            ],
            device=device,
        )
@@ -35,33 +35,37 @@ class Quadrature:
35
35
 
36
36
  def total_point_count(self):
37
37
  """Total number of quadrature points over the domain"""
38
- pass
38
+ raise NotImplementedError()
39
+
40
+ def points_per_element(self):
41
+ """Number of points per element if constant, or ``None`` if varying"""
42
+ return None
39
43
 
40
44
  @staticmethod
41
45
  def point_count(elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex):
42
46
  """Number of quadrature points for a given element"""
43
- pass
47
+ raise NotImplementedError()
44
48
 
45
49
  @staticmethod
46
50
  def point_coords(
47
51
  elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex, qp_index: int
48
52
  ):
49
53
  """Coordinates in element of the element's qp_index'th quadrature point"""
50
- pass
54
+ raise NotImplementedError()
51
55
 
52
56
  @staticmethod
53
57
  def point_weight(
54
58
  elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex, qp_index: int
55
59
  ):
56
60
  """Weight of the element's qp_index'th quadrature point"""
57
- pass
61
+ raise NotImplementedError()
58
62
 
59
63
  @staticmethod
60
64
  def point_index(
61
65
  elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex, qp_index: int
62
66
  ):
63
67
  """Global index of the element's qp_index'th quadrature point"""
64
- pass
68
+ raise NotImplementedError()
65
69
 
66
70
  def __str__(self) -> str:
67
71
  return self.name
@@ -98,13 +102,14 @@ class RegularQuadrature(Quadrature):
98
102
 
99
103
  @property
100
104
  def name(self):
101
- return (
102
- f"{self.__class__.__name__}_{self.domain.name}_{self.family}_{self.order}"
103
- )
105
+ return f"{self.__class__.__name__}_{self.domain.name}_{self.family}_{self.order}"
104
106
 
105
107
  def total_point_count(self):
106
108
  return len(self.points) * self.domain.geometry_element_count()
107
109
 
110
+ def points_per_element(self):
111
+ return self._N
112
+
108
113
  @property
109
114
  def points(self):
110
115
  return self._element_quadrature[0]
@@ -153,7 +158,7 @@ class RegularQuadrature(Quadrature):
153
158
  class NodalQuadrature(Quadrature):
154
159
  """Quadrature using space node points as quadrature points
155
160
 
156
- Note that in constant to the `nodal=True` flag for :func:`integrate`, this quadrature odes not make any assumption
161
+ Note that in contrast to the `nodal=True` flag for :func:`integrate`, this quadrature does not make any assumption
157
162
  about orthogonality of shape functions, and is thus safe to use for arbitrary integrands.
158
163
  """
159
164
 
@@ -176,6 +181,9 @@ class NodalQuadrature(Quadrature):
176
181
  def total_point_count(self):
177
182
  return self._space.node_count()
178
183
 
184
+ def points_per_element(self):
185
+ return self._space.topology.NODES_PER_ELEMENT
186
+
179
187
  def _make_arg(self):
180
188
  @cache.dynamic_struct(suffix=self.name)
181
189
  class Arg:
@@ -220,3 +228,67 @@ class NodalQuadrature(Quadrature):
220
228
  return self._space.topology.element_node_index(elt_arg, qp_arg.topo_arg, element_index, qp_index)
221
229
 
222
230
  return point_index
231
+
232
+
233
class ExplicitQuadrature(Quadrature):
    """Quadrature using explicit per-cell points and weights. The number of quadrature points per cell is assumed
    to be constant and deduced from the shape of the points and weights arrays.

    Args:
        domain: Domain of definition of the quadrature formula
        points: 2d array of shape ``(domain.geometry_element_count(), points_per_cell)`` containing the coordinates of each quadrature point.
        weights: 2d array of shape ``(domain.geometry_element_count(), points_per_cell)`` containing the weight for each quadrature point.

    Raises:
        ValueError: If ``points`` and ``weights`` do not share the same 2d shape, or if their number of rows
            differs from ``domain.geometry_element_count()``.

    See also: :class:`PicQuadrature`
    """

    @wp.struct
    class Arg:
        # Number of quadrature points in every cell (second dimension of the arrays below)
        points_per_cell: int
        # Quadrature point coordinates, indexed as [element_index, qp_index]
        points: wp.array2d(dtype=Coords)
        # Quadrature weights, indexed as [element_index, qp_index]
        weights: wp.array2d(dtype=float)

    def __init__(
        self,
        domain: domain.GeometryDomain,
        points: "wp.array2d(dtype=Coords)",
        weights: "wp.array2d(dtype=float)",
    ):
        super().__init__(domain)

        if points.shape != weights.shape:
            raise ValueError("Points and weights arrays must have the same shape")

        # The Arg struct stores 2d arrays indexed as [element_index, qp_index], and point_index
        # derives global indices from the row width, so enforce the documented layout up front
        # (a 1d input would otherwise fail with an opaque IndexError on shape[1]).
        if points.ndim != 2:
            raise ValueError("Points and weights arrays must be 2d, of shape (element count, points per cell)")

        element_count = domain.geometry_element_count()
        if points.shape[0] != element_count:
            raise ValueError(
                f"Points and weights arrays must have one row per geometry element ({element_count}), "
                f"got {points.shape[0]}"
            )

        self._points_per_cell = points.shape[1]
        self._points = points
        self._weights = weights

    @property
    def name(self):
        return self.__class__.__name__

    def total_point_count(self):
        # One entry in the weights array per quadrature point
        return self._weights.size

    def points_per_element(self):
        return self._points_per_cell

    @cache.cached_arg_value
    def arg_value(self, device):
        """Build the kernel argument struct, with point and weight arrays passed through ``.to(device)``"""
        arg = self.Arg()
        arg.points_per_cell = self._points_per_cell
        arg.points = self._points.to(device)
        arg.weights = self._weights.to(device)

        return arg

    @wp.func
    def point_count(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex):
        # Same number of points in every cell
        return qp_arg.points_per_cell

    @wp.func
    def point_coords(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, qp_index: int):
        return qp_arg.points[element_index, qp_index]

    @wp.func
    def point_weight(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, qp_index: int):
        return qp_arg.weights[element_index, qp_index]

    @wp.func
    def point_index(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, qp_index: int):
        # Row-major flattening of the (element, point) layout
        return qp_arg.points_per_cell * element_index + qp_index
@@ -7,7 +7,7 @@ import warp.fem.polynomial as _polynomial
7
7
 
8
8
  from .function_space import FunctionSpace
9
9
  from .topology import SpaceTopology
10
- from .basis_space import BasisSpace
10
+ from .basis_space import BasisSpace, PointBasisSpace
11
11
  from .collocated_function_space import CollocatedFunctionSpace
12
12
 
13
13
  from .grid_2d_function_space import (