warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (179)
  1. warp/__init__.py +7 -1
  2. warp/bin/libwarp-clang.dylib +0 -0
  3. warp/bin/libwarp.dylib +0 -0
  4. warp/build.py +410 -0
  5. warp/build_dll.py +6 -14
  6. warp/builtins.py +452 -362
  7. warp/codegen.py +179 -119
  8. warp/config.py +42 -6
  9. warp/context.py +490 -271
  10. warp/dlpack.py +8 -6
  11. warp/examples/assets/nonuniform.usd +0 -0
  12. warp/examples/assets/nvidia_logo.png +0 -0
  13. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  14. warp/examples/core/example_sample_mesh.py +300 -0
  15. warp/examples/fem/example_apic_fluid.py +1 -1
  16. warp/examples/fem/example_burgers.py +2 -2
  17. warp/examples/fem/example_deformed_geometry.py +1 -1
  18. warp/examples/fem/example_distortion_energy.py +1 -1
  19. warp/examples/fem/example_magnetostatics.py +6 -6
  20. warp/examples/fem/utils.py +9 -3
  21. warp/examples/interop/example_jax_callable.py +116 -0
  22. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  23. warp/examples/interop/example_jax_kernel.py +205 -0
  24. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  25. warp/examples/tile/example_tile_matmul.py +2 -4
  26. warp/fem/__init__.py +11 -1
  27. warp/fem/adaptivity.py +4 -4
  28. warp/fem/field/nodal_field.py +22 -68
  29. warp/fem/field/virtual.py +62 -23
  30. warp/fem/geometry/adaptive_nanogrid.py +9 -10
  31. warp/fem/geometry/closest_point.py +1 -1
  32. warp/fem/geometry/deformed_geometry.py +5 -2
  33. warp/fem/geometry/geometry.py +5 -0
  34. warp/fem/geometry/grid_2d.py +12 -12
  35. warp/fem/geometry/grid_3d.py +12 -15
  36. warp/fem/geometry/hexmesh.py +5 -7
  37. warp/fem/geometry/nanogrid.py +9 -11
  38. warp/fem/geometry/quadmesh.py +13 -13
  39. warp/fem/geometry/tetmesh.py +3 -4
  40. warp/fem/geometry/trimesh.py +3 -8
  41. warp/fem/integrate.py +262 -93
  42. warp/fem/linalg.py +5 -5
  43. warp/fem/quadrature/pic_quadrature.py +37 -22
  44. warp/fem/quadrature/quadrature.py +194 -25
  45. warp/fem/space/__init__.py +1 -1
  46. warp/fem/space/basis_function_space.py +4 -2
  47. warp/fem/space/basis_space.py +25 -18
  48. warp/fem/space/hexmesh_function_space.py +2 -2
  49. warp/fem/space/partition.py +6 -2
  50. warp/fem/space/quadmesh_function_space.py +8 -8
  51. warp/fem/space/shape/cube_shape_function.py +23 -23
  52. warp/fem/space/shape/square_shape_function.py +12 -12
  53. warp/fem/space/shape/triangle_shape_function.py +1 -1
  54. warp/fem/space/tetmesh_function_space.py +3 -3
  55. warp/fem/space/trimesh_function_space.py +2 -2
  56. warp/fem/utils.py +12 -6
  57. warp/jax.py +14 -1
  58. warp/jax_experimental/__init__.py +16 -0
  59. warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
  60. warp/jax_experimental/ffi.py +698 -0
  61. warp/jax_experimental/xla_ffi.py +602 -0
  62. warp/math.py +89 -0
  63. warp/native/array.h +13 -0
  64. warp/native/builtin.h +29 -3
  65. warp/native/bvh.cpp +3 -1
  66. warp/native/bvh.cu +42 -14
  67. warp/native/bvh.h +2 -1
  68. warp/native/clang/clang.cpp +30 -3
  69. warp/native/cuda_util.cpp +14 -0
  70. warp/native/cuda_util.h +2 -0
  71. warp/native/exports.h +68 -63
  72. warp/native/intersect.h +26 -26
  73. warp/native/intersect_adj.h +33 -33
  74. warp/native/marching.cu +1 -1
  75. warp/native/mat.h +513 -9
  76. warp/native/mesh.h +10 -10
  77. warp/native/quat.h +99 -11
  78. warp/native/rand.h +6 -0
  79. warp/native/sort.cpp +122 -59
  80. warp/native/sort.cu +152 -15
  81. warp/native/sort.h +8 -1
  82. warp/native/sparse.cpp +43 -22
  83. warp/native/sparse.cu +52 -17
  84. warp/native/svd.h +116 -0
  85. warp/native/tile.h +301 -105
  86. warp/native/tile_reduce.h +46 -3
  87. warp/native/vec.h +68 -7
  88. warp/native/volume.cpp +85 -113
  89. warp/native/volume_builder.cu +25 -10
  90. warp/native/volume_builder.h +6 -0
  91. warp/native/warp.cpp +5 -6
  92. warp/native/warp.cu +99 -10
  93. warp/native/warp.h +19 -10
  94. warp/optim/linear.py +10 -10
  95. warp/sim/articulation.py +4 -4
  96. warp/sim/collide.py +21 -10
  97. warp/sim/import_mjcf.py +449 -155
  98. warp/sim/import_urdf.py +32 -12
  99. warp/sim/integrator_euler.py +5 -5
  100. warp/sim/integrator_featherstone.py +3 -10
  101. warp/sim/integrator_vbd.py +207 -2
  102. warp/sim/integrator_xpbd.py +5 -5
  103. warp/sim/model.py +42 -13
  104. warp/sim/utils.py +2 -2
  105. warp/sparse.py +642 -555
  106. warp/stubs.py +216 -19
  107. warp/tests/__main__.py +0 -15
  108. warp/tests/cuda/__init__.py +0 -0
  109. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  110. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  111. warp/tests/geometry/__init__.py +0 -0
  112. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  113. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  114. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  115. warp/tests/interop/__init__.py +0 -0
  116. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  117. warp/tests/sim/__init__.py +0 -0
  118. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  119. warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
  120. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  121. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  122. warp/tests/sim/test_vbd.py +597 -0
  123. warp/tests/test_bool.py +1 -1
  124. warp/tests/test_examples.py +28 -36
  125. warp/tests/test_fem.py +23 -4
  126. warp/tests/test_linear_solvers.py +0 -11
  127. warp/tests/test_mat.py +233 -79
  128. warp/tests/test_mat_scalar_ops.py +4 -4
  129. warp/tests/test_overwrite.py +0 -60
  130. warp/tests/test_quat.py +67 -46
  131. warp/tests/test_rand.py +44 -37
  132. warp/tests/test_sparse.py +47 -6
  133. warp/tests/test_spatial.py +75 -0
  134. warp/tests/test_static.py +1 -1
  135. warp/tests/test_utils.py +84 -4
  136. warp/tests/test_vec.py +46 -34
  137. warp/tests/tile/__init__.py +0 -0
  138. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  139. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
  140. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  141. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  142. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  143. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  144. warp/tests/unittest_serial.py +1 -0
  145. warp/tests/unittest_suites.py +45 -59
  146. warp/tests/unittest_utils.py +2 -1
  147. warp/thirdparty/unittest_parallel.py +3 -1
  148. warp/types.py +110 -658
  149. warp/utils.py +137 -72
  150. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
  151. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
  152. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  153. warp/examples/optim/example_walker.py +0 -317
  154. warp/native/cutlass_gemm.cpp +0 -43
  155. warp/native/cutlass_gemm.cu +0 -382
  156. warp/tests/test_matmul.py +0 -511
  157. warp/tests/test_matmul_lite.py +0 -411
  158. warp/tests/test_vbd.py +0 -386
  159. warp/tests/unused_test_misc.py +0 -77
  160. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  161. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  162. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  163. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  164. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  165. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  166. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  167. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  168. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  169. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  170. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  171. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  172. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  173. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  174. /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
  175. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  176. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  177. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  178. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
  179. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,7 @@ from typing import Any, Optional, Tuple, Union
18
18
  import warp as wp
19
19
  from warp.fem.cache import TemporaryStore, borrow_temporary, cached_arg_value, dynamic_kernel
20
20
  from warp.fem.domain import GeometryDomain
21
- from warp.fem.types import Coords, ElementIndex, make_free_sample
21
+ from warp.fem.types import NULL_ELEMENT_INDEX, Coords, ElementIndex, make_free_sample
22
22
  from warp.fem.utils import compress_node_indices
23
23
 
24
24
  from .quadrature import Quadrature
@@ -68,10 +68,10 @@ class PicQuadrature(Quadrature):
68
68
  def domain(self, domain: GeometryDomain):
69
69
  # Allow changing the quadrature domain as long as underlying geometry and element kind are the same
70
70
  if self.domain is not None and (
71
- domain.geometry != self.domain.geometry or domain.element_kind != self.domain.element_kind
71
+ domain.element_kind != self.domain.element_kind or domain.geometry.base != self.domain.geometry.base
72
72
  ):
73
73
  raise RuntimeError(
74
- "Cannot change the domain to that of a different Geometry and/or using different element kinds."
74
+ "The new domain must use the same base geometry and kind of elements as the current one."
75
75
  )
76
76
 
77
77
  self._domain = domain
@@ -89,11 +89,11 @@ class PicQuadrature(Quadrature):
89
89
  arg.cell_particle_offsets = self._cell_particle_offsets.array.to(device)
90
90
  arg.cell_particle_indices = self._cell_particle_indices.array.to(device)
91
91
  arg.particle_fraction = self._particle_fraction.to(device)
92
- arg.particle_coords = self._particle_coords.to(device)
92
+ arg.particle_coords = self.particle_coords.to(device)
93
93
  return arg
94
94
 
95
95
  def total_point_count(self):
96
- return self._particle_coords.shape[0]
96
+ return self.particle_coords.shape[0]
97
97
 
98
98
  def active_cell_count(self):
99
99
  """Number of cells containing at least one particle"""
@@ -136,6 +136,12 @@ class PicQuadrature(Quadrature):
136
136
  particle_index = qp_arg.cell_particle_indices[qp_arg.cell_particle_offsets[element_index] + index]
137
137
  return particle_index
138
138
 
139
+ @wp.func
140
+ def point_evaluation_index(
141
+ elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, index: int
142
+ ):
143
+ return qp_arg.cell_particle_offsets[element_index] + index
144
+
139
145
  def fill_element_mask(self, mask: "wp.array(dtype=int)"):
140
146
  """Fills a mask array such that all non-empty elements are set to 1, all empty elements to zero.
141
147
 
@@ -156,7 +162,7 @@ class PicQuadrature(Quadrature):
156
162
  element_mask: wp.array(dtype=int),
157
163
  ):
158
164
  i = wp.tid()
159
- element_mask[i] = wp.select(element_particle_offsets[i] == element_particle_offsets[i + 1], 1, 0)
165
+ element_mask[i] = wp.where(element_particle_offsets[i] == element_particle_offsets[i + 1], 0, 1)
160
166
 
161
167
  @wp.kernel
162
168
  def _compute_uniform_fraction(
@@ -167,9 +173,11 @@ class PicQuadrature(Quadrature):
167
173
  p = wp.tid()
168
174
 
169
175
  cell = cell_index[p]
170
- cell_particle_count = cell_particle_offsets[cell + 1] - cell_particle_offsets[cell]
171
-
172
- cell_fraction[p] = 1.0 / float(cell_particle_count)
176
+ if cell == NULL_ELEMENT_INDEX:
177
+ cell_fraction[p] = 0.0
178
+ else:
179
+ cell_particle_count = cell_particle_offsets[cell + 1] - cell_particle_offsets[cell]
180
+ cell_fraction[p] = 1.0 / float(cell_particle_count)
173
181
 
174
182
  def _bin_particles(self, positions, measures, temporary_store: TemporaryStore):
175
183
  if wp.types.is_array(positions):
@@ -189,13 +197,13 @@ class PicQuadrature(Quadrature):
189
197
 
190
198
  device = positions.device
191
199
 
192
- cell_index_temp = borrow_temporary(temporary_store, shape=positions.shape, dtype=int, device=device)
193
- cell_index = cell_index_temp.array
200
+ self._cell_index_temp = borrow_temporary(temporary_store, shape=positions.shape, dtype=int, device=device)
201
+ self.cell_indices = self._cell_index_temp.array
194
202
 
195
203
  self._particle_coords_temp = borrow_temporary(
196
204
  temporary_store, shape=positions.shape, dtype=Coords, device=device, requires_grad=self._requires_grad
197
205
  )
198
- self._particle_coords = self._particle_coords_temp.array
206
+ self.particle_coords = self._particle_coords_temp.array
199
207
 
200
208
  wp.launch(
201
209
  dim=positions.shape[0],
@@ -203,25 +211,28 @@ class PicQuadrature(Quadrature):
203
211
  inputs=[
204
212
  self.domain.element_arg_value(device),
205
213
  positions,
206
- cell_index,
207
- self._particle_coords,
214
+ self.cell_indices,
215
+ self.particle_coords,
208
216
  ],
209
217
  device=device,
210
218
  )
211
219
 
212
220
  else:
213
- cell_index, self._particle_coords = positions
214
- if cell_index.shape != self._particle_coords.shape:
221
+ self.cell_indices, self.particle_coords = positions
222
+ if self.cell_indices.shape != self.particle_coords.shape:
215
223
  raise ValueError("Cell index and coordinates arrays must have the same shape")
216
224
 
217
- cell_index_temp = None
225
+ self._cell_index_temp = None
218
226
  self._particle_coords_temp = None
219
227
 
220
228
  self._cell_particle_offsets, self._cell_particle_indices, self._cell_count, _ = compress_node_indices(
221
- self.domain.geometry_element_count(), cell_index, return_unique_nodes=True, temporary_store=temporary_store
229
+ self.domain.geometry_element_count(),
230
+ self.cell_indices,
231
+ return_unique_nodes=True,
232
+ temporary_store=temporary_store,
222
233
  )
223
234
 
224
- self._compute_fraction(cell_index, measures, temporary_store)
235
+ self._compute_fraction(self.cell_indices, measures, temporary_store)
225
236
 
226
237
  def _compute_fraction(self, cell_index, measures, temporary_store: TemporaryStore):
227
238
  device = cell_index.device
@@ -260,9 +271,13 @@ class PicQuadrature(Quadrature):
260
271
  cell_fraction: wp.array(dtype=float),
261
272
  ):
262
273
  p = wp.tid()
263
- sample = make_free_sample(cell_index[p], cell_coords[p])
264
274
 
265
- cell_fraction[p] = measures[p] / self.domain.element_measure(cell_arg_value, sample)
275
+ cell = cell_index[p]
276
+ if cell == NULL_ELEMENT_INDEX:
277
+ cell_fraction[p] = 0.0
278
+ else:
279
+ sample = make_free_sample(cell_index[p], cell_coords[p])
280
+ cell_fraction[p] = measures[p] / self.domain.element_measure(cell_arg_value, sample)
266
281
 
267
282
  wp.launch(
268
283
  dim=measures.shape[0],
@@ -271,7 +286,7 @@ class PicQuadrature(Quadrature):
271
286
  self.domain.element_arg_value(device),
272
287
  measures,
273
288
  cell_index,
274
- self._particle_coords,
289
+ self.particle_coords,
275
290
  self._particle_fraction,
276
291
  ],
277
292
  device=device,
@@ -13,17 +13,24 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from typing import Any
16
+ from typing import Any, Optional
17
17
 
18
18
  import warp as wp
19
- from warp.fem import cache, domain
19
+ from warp.fem import cache
20
+ from warp.fem.domain import GeometryDomain
20
21
  from warp.fem.geometry import Element
21
22
  from warp.fem.space import FunctionSpace
22
- from warp.fem.types import Coords, ElementIndex
23
+ from warp.fem.types import NULL_ELEMENT_INDEX, Coords, ElementIndex, QuadraturePointIndex
23
24
 
24
25
  from ..polynomial import Polynomial
25
26
 
26
27
 
28
+ @wp.struct
29
+ class QuadraturePointElementIndex:
30
+ domain_element_index: ElementIndex
31
+ qp_index_in_element: int
32
+
33
+
27
34
  class Quadrature:
28
35
  """Interface class for quadrature rules"""
29
36
 
@@ -33,7 +40,7 @@ class Quadrature:
33
40
 
34
41
  pass
35
42
 
36
- def __init__(self, domain: domain.GeometryDomain):
43
+ def __init__(self, domain: GeometryDomain):
37
44
  self._domain = domain
38
45
 
39
46
  @property
@@ -45,52 +52,197 @@ class Quadrature:
45
52
  """
46
53
  Value of the argument to be passed to device
47
54
  """
48
- arg = RegularQuadrature.Arg()
55
+ arg = Quadrature.Arg()
49
56
  return arg
50
57
 
51
58
  def total_point_count(self):
52
- """Total number of quadrature points over the domain"""
59
+ """Number of unique quadrature points that can be indexed by this rule.
60
+ Returns a number such that `point_index()` is always smaller than this number.
61
+ """
53
62
  raise NotImplementedError()
54
63
 
64
+ def evaluation_point_count(self):
65
+ """Number of quadrature points that needs to be evaluated, mostly for internal purposes.
66
+ If the indexing scheme is sparse, or if a quadrature point is shared among multiple elements
67
+ (e.g, nodal quadrature), `evaluation_point_count` may be different than `total_point_count()`.
68
+ Returns a number such that `evaluation_point_index()` is always smaller than this number.
69
+ """
70
+ return self.total_point_count()
71
+
55
72
  def max_points_per_element(self):
56
73
  """Maximum number of points per element if known, or ``None`` otherwise"""
57
74
  return None
58
75
 
59
76
  @staticmethod
60
- def point_count(elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex):
77
+ def point_count(
78
+ elt_arg: "GeometryDomain.ElementArg",
79
+ qp_arg: Arg,
80
+ domain_element_index: ElementIndex,
81
+ geo_element_index: ElementIndex,
82
+ ):
61
83
  """Number of quadrature points for a given element"""
62
84
  raise NotImplementedError()
63
85
 
64
86
  @staticmethod
65
87
  def point_coords(
66
- elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex, qp_index: int
88
+ elt_arg: "GeometryDomain.ElementArg",
89
+ qp_arg: Arg,
90
+ domain_element_index: ElementIndex,
91
+ geo_element_index: ElementIndex,
92
+ element_qp_index: int,
67
93
  ):
68
94
  """Coordinates in element of the element's qp_index'th quadrature point"""
69
95
  raise NotImplementedError()
70
96
 
71
97
  @staticmethod
72
98
  def point_weight(
73
- elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex, qp_index: int
99
+ elt_arg: "GeometryDomain.ElementArg",
100
+ qp_arg: Arg,
101
+ domain_element_index: ElementIndex,
102
+ geo_element_index: ElementIndex,
103
+ element_qp_index: int,
74
104
  ):
75
105
  """Weight of the element's qp_index'th quadrature point"""
76
106
  raise NotImplementedError()
77
107
 
78
108
  @staticmethod
79
109
  def point_index(
80
- elt_arg: "domain.GeometryDomain.ElementArg",
110
+ elt_arg: "GeometryDomain.ElementArg",
111
+ qp_arg: Arg,
112
+ domain_element_index: ElementIndex,
113
+ geo_element_index: ElementIndex,
114
+ element_qp_index: int,
115
+ ):
116
+ """
117
+ Global index of the element's qp_index'th quadrature point.
118
+ May be shared among elements.
119
+ This is what determines `qp_index` in integrands' `Sample` arguments.
120
+ """
121
+ raise NotImplementedError()
122
+
123
+ @staticmethod
124
+ def point_evaluation_index(
125
+ elt_arg: "GeometryDomain.ElementArg",
81
126
  qp_arg: Arg,
82
127
  domain_element_index: ElementIndex,
83
128
  geo_element_index: ElementIndex,
84
129
  element_qp_index: int,
85
130
  ):
86
- """Global index of the element's qp_index'th quadrature point"""
131
+ """Quadrature point index according to evaluation order.
132
+ Quadrature points for distinct elements must have different evaluation indices.
133
+ Mostly for internal/parallelization purposes.
134
+ """
87
135
  raise NotImplementedError()
88
136
 
89
137
  def __str__(self) -> str:
90
138
  return self.name
91
139
 
140
+ # By default cache the mapping from evaluation point indices to domain elements
141
+
142
+ ElementIndexArg = wp.array(dtype=QuadraturePointElementIndex)
143
+
144
+ @cache.cached_arg_value
145
+ def element_index_arg_value(self, device):
146
+ """Builds a map from quadrature point evaluation indices to their index in the element to which they belong"""
147
+
148
+ @cache.dynamic_kernel(f"{self.name}{self.domain.name}")
149
+ def quadrature_point_element_indices(
150
+ qp_arg: self.Arg,
151
+ domain_arg: self.domain.ElementArg,
152
+ domain_index_arg: self.domain.ElementIndexArg,
153
+ result: wp.array(dtype=QuadraturePointElementIndex),
154
+ ):
155
+ domain_element_index = wp.tid()
156
+ element_index = self.domain.element_index(domain_index_arg, domain_element_index)
157
+
158
+ qp_point_count = self.point_count(domain_arg, qp_arg, domain_element_index, element_index)
159
+ for k in range(qp_point_count):
160
+ qp_eval_index = self.point_evaluation_index(domain_arg, qp_arg, domain_element_index, element_index, k)
161
+ result[qp_eval_index] = QuadraturePointElementIndex(domain_element_index, k)
162
+
163
+ null_qp_index = QuadraturePointElementIndex()
164
+ null_qp_index.domain_element_index = NULL_ELEMENT_INDEX
165
+ result = wp.full(
166
+ value=null_qp_index,
167
+ shape=(self.evaluation_point_count()),
168
+ dtype=QuadraturePointElementIndex,
169
+ device=device,
170
+ )
171
+ wp.launch(
172
+ quadrature_point_element_indices,
173
+ device=result.device,
174
+ dim=self.domain.element_count(),
175
+ inputs=[
176
+ self.arg_value(result.device),
177
+ self.domain.element_arg_value(result.device),
178
+ self.domain.element_index_arg_value(result.device),
179
+ result,
180
+ ],
181
+ )
182
+
183
+ return result
184
+
185
+ @wp.func
186
+ def evaluation_point_element_index(
187
+ element_index_arg: wp.array(dtype=QuadraturePointElementIndex),
188
+ qp_eval_index: QuadraturePointIndex,
189
+ ):
190
+ """Maps from quadrature point evaluation indices to their index in the element to which they belong
191
+ If the quadrature point does not exist, should return NULL_ELEMENT_INDEX as the domain element index
192
+ """
193
+
194
+ element_index = element_index_arg[qp_eval_index]
195
+ return element_index.domain_element_index, element_index.qp_index_in_element
196
+
197
+
198
+ class _QuadratureWithRegularEvaluationPoints(Quadrature):
199
+ """Helper subclass for quadrature formulas which use a uniform number of
200
+ evaluations points per element. Avoids building explicit mapping"""
201
+
202
+ def __init__(self, domain: GeometryDomain, N: int):
203
+ super().__init__(domain)
204
+ self._EVALUATION_POINTS_PER_ELEMENT = N
205
+
206
+ self.point_evaluation_index = self._make_regular_point_evaluation_index()
207
+ self.evaluation_point_element_index = self._make_regular_evaluation_point_element_index()
92
208
 
93
- class RegularQuadrature(Quadrature):
209
+ ElementIndexArg = Quadrature.Arg
210
+ element_index_arg_value = Quadrature.arg_value
211
+
212
+ def evaluation_point_count(self):
213
+ return self.domain.element_count() * self._EVALUATION_POINTS_PER_ELEMENT
214
+
215
+ def _make_regular_point_evaluation_index(self):
216
+ N = self._EVALUATION_POINTS_PER_ELEMENT
217
+
218
+ @cache.dynamic_func(suffix=f"{self.name}")
219
+ def evaluation_point_index(
220
+ elt_arg: self.domain.ElementArg,
221
+ qp_arg: self.Arg,
222
+ domain_element_index: ElementIndex,
223
+ element_index: ElementIndex,
224
+ qp_index: int,
225
+ ):
226
+ return N * domain_element_index + qp_index
227
+
228
+ return evaluation_point_index
229
+
230
+ def _make_regular_evaluation_point_element_index(self):
231
+ N = self._EVALUATION_POINTS_PER_ELEMENT
232
+
233
+ @cache.dynamic_func(suffix=f"{N}")
234
+ def quadrature_evaluation_point_element_index(
235
+ qp_arg: Quadrature.Arg,
236
+ qp_index: QuadraturePointIndex,
237
+ ):
238
+ domain_element_index = qp_index // N
239
+ index_in_element = qp_index - domain_element_index * N
240
+ return domain_element_index, index_in_element
241
+
242
+ return quadrature_evaluation_point_element_index
243
+
244
+
245
+ class RegularQuadrature(_QuadratureWithRegularEvaluationPoints):
94
246
  """Regular quadrature formula, using a constant set of quadrature points per element"""
95
247
 
96
248
  @wp.struct
@@ -127,16 +279,15 @@ class RegularQuadrature(Quadrature):
127
279
 
128
280
  def __init__(
129
281
  self,
130
- domain: domain.GeometryDomain,
282
+ domain: GeometryDomain,
131
283
  order: int,
132
284
  family: Polynomial = None,
133
285
  ):
134
- super().__init__(domain)
135
-
286
+ self._formula = RegularQuadrature.CachedFormula.get(domain.reference_element(), order, family)
136
287
  self.family = family
137
288
  self.order = order
138
289
 
139
- self._formula = RegularQuadrature.CachedFormula.get(domain.reference_element(), order, family)
290
+ super().__init__(domain, self._formula.count)
140
291
 
141
292
  self.point_count = self._make_point_count()
142
293
  self.point_index = self._make_point_index()
@@ -227,17 +378,18 @@ class NodalQuadrature(Quadrature):
227
378
  any assumption about orthogonality of shape functions, and is thus safe to use for arbitrary integrands.
228
379
  """
229
380
 
230
- def __init__(self, domain: domain.GeometryDomain, space: FunctionSpace):
231
- super().__init__(domain)
232
-
381
+ def __init__(self, domain: Optional[GeometryDomain], space: FunctionSpace):
233
382
  self._space = space
234
383
 
384
+ super().__init__(domain)
385
+
235
386
  self.Arg = self._make_arg()
236
387
 
237
388
  self.point_count = self._make_point_count()
238
389
  self.point_index = self._make_point_index()
239
390
  self.point_coords = self._make_point_coords()
240
391
  self.point_weight = self._make_point_weight()
392
+ self.point_evaluation_index = self._make_point_evaluation_index()
241
393
 
242
394
  @property
243
395
  def name(self):
@@ -315,8 +467,26 @@ class NodalQuadrature(Quadrature):
315
467
 
316
468
  return point_index
317
469
 
470
+ def evaluation_point_count(self):
471
+ return self.domain.element_count() * self._space.topology.MAX_NODES_PER_ELEMENT
318
472
 
319
- class ExplicitQuadrature(Quadrature):
473
+ def _make_point_evaluation_index(self):
474
+ N = self._space.topology.MAX_NODES_PER_ELEMENT
475
+
476
+ @cache.dynamic_func(suffix=self.name)
477
+ def evaluation_point_index(
478
+ elt_arg: self.domain.ElementArg,
479
+ qp_arg: self.Arg,
480
+ domain_element_index: ElementIndex,
481
+ element_index: ElementIndex,
482
+ qp_index: int,
483
+ ):
484
+ return N * domain_element_index + qp_index
485
+
486
+ return evaluation_point_index
487
+
488
+
489
+ class ExplicitQuadrature(_QuadratureWithRegularEvaluationPoints):
320
490
  """Quadrature using explicit per-cell points and weights.
321
491
 
322
492
  The number of quadrature points per cell is assumed to be constant and deduced from the shape of the points and weights arrays.
@@ -336,11 +506,7 @@ class ExplicitQuadrature(Quadrature):
336
506
  points: wp.array2d(dtype=Coords)
337
507
  weights: wp.array2d(dtype=float)
338
508
 
339
- def __init__(
340
- self, domain: domain.GeometryDomain, points: "wp.array2d(dtype=Coords)", weights: "wp.array2d(dtype=float)"
341
- ):
342
- super().__init__(domain)
343
-
509
+ def __init__(self, domain: GeometryDomain, points: "wp.array2d(dtype=Coords)", weights: "wp.array2d(dtype=float)"):
344
510
  if points.shape != weights.shape:
345
511
  raise ValueError("Points and weights arrays must have the same shape")
346
512
 
@@ -358,7 +524,10 @@ class ExplicitQuadrature(Quadrature):
358
524
  )
359
525
 
360
526
  self._points_per_cell = points.shape[1]
527
+
361
528
  self._whole_geo = points.shape[0] == domain.geometry_element_count()
529
+
530
+ super().__init__(domain, self._points_per_cell)
362
531
  self._points = points
363
532
  self._weights = weights
364
533
 
@@ -112,7 +112,7 @@ def make_polynomial_basis_space(
112
112
  the constructed basis space
113
113
  """
114
114
 
115
- base_geo = geo.base if isinstance(geo, _geometry.DeformedGeometry) else geo
115
+ base_geo = geo.base
116
116
 
117
117
  if element_basis is None:
118
118
  element_basis = ElementBasis.LAGRANGE
@@ -19,7 +19,7 @@ import warp as wp
19
19
  from warp.fem import cache
20
20
  from warp.fem.geometry import Geometry
21
21
  from warp.fem.linalg import basis_element, generalized_inner, generalized_outer
22
- from warp.fem.types import Coords, ElementIndex, make_free_sample
22
+ from warp.fem.types import NULL_QP_INDEX, Coords, ElementIndex, make_free_sample
23
23
 
24
24
  from .basis_space import BasisSpace
25
25
  from .dof_mapper import DofMapper, IdentityMapper
@@ -305,7 +305,9 @@ class VectorValuedFunctionSpace(FunctionSpace):
305
305
  space_value: self.dtype,
306
306
  ):
307
307
  coords = self.node_coords_in_element(elt_arg, space_arg, element_index, node_index_in_elt)
308
- weight = self.element_inner_weight(elt_arg, space_arg, element_index, coords, node_index_in_elt)
308
+ weight = self.element_inner_weight(
309
+ elt_arg, space_arg, element_index, coords, node_index_in_elt, NULL_QP_INDEX
310
+ )
309
311
  local_value_map = self.local_value_map_inner(elt_arg, element_index, coords)
310
312
 
311
313
  unit_value = local_value_map * weight
@@ -21,7 +21,14 @@ import warp as wp
21
21
  from warp.fem import cache
22
22
  from warp.fem.geometry import Geometry
23
23
  from warp.fem.quadrature import Quadrature
24
- from warp.fem.types import NULL_ELEMENT_INDEX, Coords, ElementIndex, make_free_sample
24
+ from warp.fem.types import (
25
+ NULL_ELEMENT_INDEX,
26
+ NULL_QP_INDEX,
27
+ Coords,
28
+ ElementIndex,
29
+ QuadraturePointIndex,
30
+ make_free_sample,
31
+ )
25
32
 
26
33
  from .shape import ShapeFunction
27
34
  from .topology import RegularDiscontinuousSpaceTopology, SpaceTopology
@@ -235,6 +242,7 @@ class ShapeBasisSpace(BasisSpace):
235
242
  element_index: ElementIndex,
236
243
  coords: Coords,
237
244
  node_index_in_elt: int,
245
+ qp_index: QuadraturePointIndex,
238
246
  ):
239
247
  if wp.static(self.value == ShapeFunction.Value.Scalar):
240
248
  return shape_element_inner_weight(coords, node_index_in_elt)
@@ -254,6 +262,7 @@ class ShapeBasisSpace(BasisSpace):
254
262
  element_index: ElementIndex,
255
263
  coords: Coords,
256
264
  node_index_in_elt: int,
265
+ qp_index: QuadraturePointIndex,
257
266
  ):
258
267
  if wp.static(self.value == ShapeFunction.Value.Scalar):
259
268
  return shape_element_inner_weight_gradient(coords, node_index_in_elt)
@@ -373,6 +382,7 @@ class TraceBasisSpace(BasisSpace):
373
382
  element_index: ElementIndex,
374
383
  coords: Coords,
375
384
  node_index_in_elt: int,
385
+ qp_index: QuadraturePointIndex,
376
386
  ):
377
387
  cell_index, index_in_cell = self.topology.inner_cell_index(geo_side_arg, element_index, node_index_in_elt)
378
388
  if cell_index == NULL_ELEMENT_INDEX:
@@ -381,13 +391,7 @@ class TraceBasisSpace(BasisSpace):
381
391
  cell_coords = self.geometry.side_inner_cell_coords(geo_side_arg, element_index, coords)
382
392
 
383
393
  geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
384
- return cell_inner_weight(
385
- geo_cell_arg,
386
- basis_arg,
387
- cell_index,
388
- cell_coords,
389
- index_in_cell,
390
- )
394
+ return cell_inner_weight(geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell, NULL_QP_INDEX)
391
395
 
392
396
  return trace_element_inner_weight
393
397
 
@@ -401,6 +405,7 @@ class TraceBasisSpace(BasisSpace):
401
405
  element_index: ElementIndex,
402
406
  coords: Coords,
403
407
  node_index_in_elt: int,
408
+ qp_index: QuadraturePointIndex,
404
409
  ):
405
410
  cell_index, index_in_cell = self.topology.outer_cell_index(geo_side_arg, element_index, node_index_in_elt)
406
411
  if cell_index == NULL_ELEMENT_INDEX:
@@ -409,13 +414,7 @@ class TraceBasisSpace(BasisSpace):
409
414
  cell_coords = self.geometry.side_outer_cell_coords(geo_side_arg, element_index, coords)
410
415
 
411
416
  geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
412
- return cell_outer_weight(
413
- geo_cell_arg,
414
- basis_arg,
415
- cell_index,
416
- cell_coords,
417
- index_in_cell,
418
- )
417
+ return cell_outer_weight(geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell, NULL_QP_INDEX)
419
418
 
420
419
  return trace_element_outer_weight
421
420
 
@@ -429,6 +428,7 @@ class TraceBasisSpace(BasisSpace):
429
428
  element_index: ElementIndex,
430
429
  coords: Coords,
431
430
  node_index_in_elt: int,
431
+ qp_index: QuadraturePointIndex,
432
432
  ):
433
433
  cell_index, index_in_cell = self.topology.inner_cell_index(geo_side_arg, element_index, node_index_in_elt)
434
434
  if cell_index == NULL_ELEMENT_INDEX:
@@ -436,7 +436,9 @@ class TraceBasisSpace(BasisSpace):
436
436
 
437
437
  cell_coords = self.geometry.side_inner_cell_coords(geo_side_arg, element_index, coords)
438
438
  geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
439
- return cell_inner_weight_gradient(geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell)
439
+ return cell_inner_weight_gradient(
440
+ geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell, NULL_QP_INDEX
441
+ )
440
442
 
441
443
  return trace_element_inner_weight_gradient
442
444
 
@@ -450,6 +452,7 @@ class TraceBasisSpace(BasisSpace):
450
452
  element_index: ElementIndex,
451
453
  coords: Coords,
452
454
  node_index_in_elt: int,
455
+ qp_index: QuadraturePointIndex,
453
456
  ):
454
457
  cell_index, index_in_cell = self.topology.outer_cell_index(geo_side_arg, element_index, node_index_in_elt)
455
458
  if cell_index == NULL_ELEMENT_INDEX:
@@ -457,7 +460,9 @@ class TraceBasisSpace(BasisSpace):
457
460
 
458
461
  cell_coords = self.geometry.side_outer_cell_coords(geo_side_arg, element_index, coords)
459
462
  geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
460
- return cell_outer_weight_gradient(geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell)
463
+ return cell_outer_weight_gradient(
464
+ geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell, NULL_QP_INDEX
465
+ )
461
466
 
462
467
  return trace_element_outer_weight_gradient
463
468
 
@@ -624,11 +629,12 @@ class PointBasisSpace(BasisSpace):
624
629
  element_index: ElementIndex,
625
630
  coords: Coords,
626
631
  node_index_in_elt: int,
632
+ qp_index: QuadraturePointIndex,
627
633
  ):
628
634
  qp_coord = self._quadrature.point_coords(
629
635
  elt_arg, basis_arg, element_index, element_index, node_index_in_elt
630
636
  )
631
- return wp.select(wp.length_sq(coords - qp_coord) < _DIRAC_INTEGRATION_RADIUS, 0.0, 1.0)
637
+ return wp.where(wp.length_sq(coords - qp_coord) < _DIRAC_INTEGRATION_RADIUS, 1.0, 0.0)
632
638
 
633
639
  return element_inner_weight
634
640
 
@@ -642,6 +648,7 @@ class PointBasisSpace(BasisSpace):
642
648
  element_index: ElementIndex,
643
649
  coords: Coords,
644
650
  node_index_in_elt: int,
651
+ qp_index: QuadraturePointIndex,
645
652
  ):
646
653
  return gradient_vec(0.0)
647
654
 
@@ -237,13 +237,13 @@ class HexmeshSpaceTopology(SpaceTopology):
237
237
  hex_edge = _CUBE_TO_HEX_EDGE[type_instance]
238
238
  v0 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 0]]
239
239
  v1 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 1]]
240
- return wp.select(v0 > v1, 1.0, -1.0)
240
+ return wp.where(v0 > v1, -1.0, 1.0)
241
241
 
242
242
  if wp.static(FACE_NODE_COUNT > 0):
243
243
  if node_type == CubeShapeFunction.FACE:
244
244
  face_index_and_ori = topo_arg.hex_face_indices[element_index, type_instance]
245
245
  flip = face_index_and_ori[1] & 1
246
- return wp.select(flip == 0, -1.0, 1.0)
246
+ return wp.where(flip == 0, 1.0, -1.0)
247
247
 
248
248
  return 1.0
249
249