warp-lang 1.7.2__py3-none-manylinux_2_34_aarch64.whl → 1.8.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (180) hide show
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +125 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +257 -101
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +657 -223
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  35. warp/examples/optim/example_drone.py +1 -1
  36. warp/examples/sim/example_cloth.py +1 -1
  37. warp/examples/sim/example_cloth_self_contact.py +48 -54
  38. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  39. warp/examples/tile/example_tile_cholesky.py +2 -1
  40. warp/examples/tile/example_tile_convolution.py +1 -1
  41. warp/examples/tile/example_tile_filtering.py +1 -1
  42. warp/examples/tile/example_tile_matmul.py +1 -1
  43. warp/examples/tile/example_tile_mlp.py +2 -0
  44. warp/fabric.py +7 -7
  45. warp/fem/__init__.py +5 -0
  46. warp/fem/adaptivity.py +1 -1
  47. warp/fem/cache.py +152 -63
  48. warp/fem/dirichlet.py +2 -2
  49. warp/fem/domain.py +136 -6
  50. warp/fem/field/field.py +141 -99
  51. warp/fem/field/nodal_field.py +85 -39
  52. warp/fem/field/virtual.py +97 -52
  53. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  54. warp/fem/geometry/closest_point.py +13 -0
  55. warp/fem/geometry/deformed_geometry.py +102 -40
  56. warp/fem/geometry/element.py +56 -2
  57. warp/fem/geometry/geometry.py +323 -22
  58. warp/fem/geometry/grid_2d.py +157 -62
  59. warp/fem/geometry/grid_3d.py +116 -20
  60. warp/fem/geometry/hexmesh.py +86 -20
  61. warp/fem/geometry/nanogrid.py +166 -86
  62. warp/fem/geometry/partition.py +59 -25
  63. warp/fem/geometry/quadmesh.py +86 -135
  64. warp/fem/geometry/tetmesh.py +47 -119
  65. warp/fem/geometry/trimesh.py +77 -270
  66. warp/fem/integrate.py +107 -52
  67. warp/fem/linalg.py +25 -58
  68. warp/fem/operator.py +124 -27
  69. warp/fem/quadrature/pic_quadrature.py +36 -14
  70. warp/fem/quadrature/quadrature.py +40 -16
  71. warp/fem/space/__init__.py +1 -1
  72. warp/fem/space/basis_function_space.py +66 -46
  73. warp/fem/space/basis_space.py +17 -4
  74. warp/fem/space/dof_mapper.py +1 -1
  75. warp/fem/space/function_space.py +2 -2
  76. warp/fem/space/grid_2d_function_space.py +4 -1
  77. warp/fem/space/hexmesh_function_space.py +4 -2
  78. warp/fem/space/nanogrid_function_space.py +3 -1
  79. warp/fem/space/partition.py +11 -2
  80. warp/fem/space/quadmesh_function_space.py +4 -1
  81. warp/fem/space/restriction.py +5 -2
  82. warp/fem/space/shape/__init__.py +10 -8
  83. warp/fem/space/tetmesh_function_space.py +4 -1
  84. warp/fem/space/topology.py +52 -21
  85. warp/fem/space/trimesh_function_space.py +4 -1
  86. warp/fem/utils.py +53 -8
  87. warp/jax.py +1 -2
  88. warp/jax_experimental/ffi.py +12 -17
  89. warp/jax_experimental/xla_ffi.py +37 -24
  90. warp/math.py +171 -1
  91. warp/native/array.h +99 -0
  92. warp/native/builtin.h +174 -31
  93. warp/native/coloring.cpp +1 -1
  94. warp/native/exports.h +118 -63
  95. warp/native/intersect.h +3 -3
  96. warp/native/mat.h +5 -10
  97. warp/native/mathdx.cpp +11 -5
  98. warp/native/matnn.h +1 -123
  99. warp/native/quat.h +28 -4
  100. warp/native/sparse.cpp +121 -258
  101. warp/native/sparse.cu +181 -274
  102. warp/native/spatial.h +305 -17
  103. warp/native/tile.h +583 -72
  104. warp/native/tile_radix_sort.h +1108 -0
  105. warp/native/tile_reduce.h +237 -2
  106. warp/native/tile_scan.h +240 -0
  107. warp/native/tuple.h +189 -0
  108. warp/native/vec.h +6 -16
  109. warp/native/warp.cpp +36 -4
  110. warp/native/warp.cu +574 -51
  111. warp/native/warp.h +47 -74
  112. warp/optim/linear.py +5 -1
  113. warp/paddle.py +7 -8
  114. warp/py.typed +0 -0
  115. warp/render/render_opengl.py +58 -29
  116. warp/render/render_usd.py +124 -61
  117. warp/sim/__init__.py +9 -0
  118. warp/sim/collide.py +252 -78
  119. warp/sim/graph_coloring.py +8 -1
  120. warp/sim/import_mjcf.py +4 -3
  121. warp/sim/import_usd.py +11 -7
  122. warp/sim/integrator.py +5 -2
  123. warp/sim/integrator_euler.py +1 -1
  124. warp/sim/integrator_featherstone.py +1 -1
  125. warp/sim/integrator_vbd.py +751 -320
  126. warp/sim/integrator_xpbd.py +1 -1
  127. warp/sim/model.py +265 -260
  128. warp/sim/utils.py +10 -7
  129. warp/sparse.py +303 -166
  130. warp/tape.py +52 -51
  131. warp/tests/cuda/test_conditional_captures.py +1046 -0
  132. warp/tests/cuda/test_streams.py +1 -1
  133. warp/tests/geometry/test_volume.py +2 -2
  134. warp/tests/interop/test_dlpack.py +9 -9
  135. warp/tests/interop/test_jax.py +0 -1
  136. warp/tests/run_coverage_serial.py +1 -1
  137. warp/tests/sim/disabled_kinematics.py +2 -2
  138. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  139. warp/tests/sim/test_collision.py +159 -51
  140. warp/tests/sim/test_coloring.py +15 -1
  141. warp/tests/test_array.py +254 -2
  142. warp/tests/test_array_reduce.py +2 -2
  143. warp/tests/test_atomic_cas.py +299 -0
  144. warp/tests/test_codegen.py +142 -19
  145. warp/tests/test_conditional.py +47 -1
  146. warp/tests/test_ctypes.py +0 -20
  147. warp/tests/test_devices.py +8 -0
  148. warp/tests/test_fabricarray.py +4 -2
  149. warp/tests/test_fem.py +58 -25
  150. warp/tests/test_func.py +42 -1
  151. warp/tests/test_grad.py +1 -1
  152. warp/tests/test_lerp.py +1 -3
  153. warp/tests/test_map.py +481 -0
  154. warp/tests/test_mat.py +1 -24
  155. warp/tests/test_quat.py +6 -15
  156. warp/tests/test_rounding.py +10 -38
  157. warp/tests/test_runlength_encode.py +7 -7
  158. warp/tests/test_smoothstep.py +1 -1
  159. warp/tests/test_sparse.py +51 -2
  160. warp/tests/test_spatial.py +507 -1
  161. warp/tests/test_struct.py +2 -2
  162. warp/tests/test_tuple.py +265 -0
  163. warp/tests/test_types.py +2 -2
  164. warp/tests/test_utils.py +24 -18
  165. warp/tests/tile/test_tile.py +420 -1
  166. warp/tests/tile/test_tile_mathdx.py +518 -14
  167. warp/tests/tile/test_tile_reduce.py +213 -0
  168. warp/tests/tile/test_tile_shared_memory.py +130 -1
  169. warp/tests/tile/test_tile_sort.py +117 -0
  170. warp/tests/unittest_suites.py +4 -6
  171. warp/types.py +462 -308
  172. warp/utils.py +647 -86
  173. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +177 -165
  175. warp/stubs.py +0 -3381
  176. warp/tests/sim/test_xpbd.py +0 -399
  177. warp/tests/test_mlp.py +0 -282
  178. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/fem/operator.py CHANGED
@@ -17,7 +17,15 @@ from typing import Any, Callable, Dict, Optional, Set
17
17
 
18
18
  import warp as wp
19
19
  from warp.fem.linalg import skew_part, symmetric_part
20
- from warp.fem.types import Coords, Domain, ElementIndex, Field, NodeIndex, Sample, make_free_sample
20
+ from warp.fem.types import (
21
+ Coords,
22
+ Domain,
23
+ ElementIndex,
24
+ Field,
25
+ NodeIndex,
26
+ Sample,
27
+ make_free_sample,
28
+ )
21
29
 
22
30
 
23
31
  class Integrand:
@@ -33,7 +41,7 @@ class Integrand:
33
41
  self.kernel_options = {} if kernel_options is None else kernel_options
34
42
 
35
43
  # Operators for each field argument. This will be populated at first integrate call
36
- self.operators: Dict[str, Set[Operator]] = None
44
+ self.operators: Optional[Dict[str, Set[Operator]]] = None
37
45
 
38
46
 
39
47
  class Operator:
@@ -41,14 +49,21 @@ class Operator:
41
49
  Operators provide syntactic sugar over Field and Domain evaluation functions and arguments
42
50
  """
43
51
 
44
- def __init__(self, func: Callable, resolver: Callable, field_result: Callable = None):
52
+ def __init__(
53
+ self,
54
+ func: Callable,
55
+ resolver: Callable,
56
+ field_result: Optional[Callable] = None,
57
+ attr: Optional[str] = None,
58
+ ):
45
59
  self.func = func
46
60
  self.name = func.__name__
47
61
  self.resolver = resolver
62
+ self.attr = attr
48
63
  self.field_result = field_result
49
64
 
50
65
 
51
- def integrand(func: Callable = None, kernel_options: Optional[Dict[str, Any]] = None):
66
+ def integrand(func: Optional[Callable] = None, kernel_options: Optional[Dict[str, Any]] = None):
52
67
  """Decorator for functions to be integrated (or interpolated) using warp.fem
53
68
 
54
69
  Args:
@@ -83,84 +98,124 @@ def operator(**kwargs):
83
98
  # Domain operators
84
99
 
85
100
 
86
- @operator(resolver=lambda dmn: dmn.element_position)
101
+ @operator(resolver=lambda dmn: dmn.element_position, attr="geo")
87
102
  def position(domain: Domain, s: Sample):
88
103
  """Evaluates the world position of the sample point `s`"""
89
104
  pass
90
105
 
91
106
 
92
- @operator(resolver=lambda dmn: dmn.element_normal)
107
+ @operator(resolver=lambda dmn: dmn.element_normal, attr="geo")
93
108
  def normal(domain: Domain, s: Sample):
94
109
  """Evaluates the element normal at the sample point `s`. Non zero if the element is a side or the geometry is embedded in a higher-dimensional space (e.g. :class:`Trimesh3D`)"""
95
110
  pass
96
111
 
97
112
 
98
- @operator(resolver=lambda dmn: dmn.element_deformation_gradient)
113
+ @operator(resolver=lambda dmn: dmn.element_deformation_gradient, attr="geo")
99
114
  def deformation_gradient(domain: Domain, s: Sample):
100
115
  """Evaluates the gradient of the domain position with respect to the element reference space at the sample point `s`"""
101
116
  pass
102
117
 
103
118
 
104
- @operator(resolver=lambda dmn: dmn.element_lookup)
119
+ @operator(resolver=lambda dmn: dmn.element_lookup, attr="geo")
105
120
  def lookup(domain: Domain, x: Any) -> Sample:
106
- """Looks-up the sample point corresponding to a world position `x`, projecting to the closest point on the domain.
121
+ """Looks-up the sample point corresponding to a world position `x`, projecting to the closest point on the geometry.
107
122
 
108
123
  Args:
109
- x: world position of the point to look-up in the geometry
110
- guess: (optional) :class:`Sample` initial guess, may help perform the query
124
+ x (vec3): world position of the point to look-up in the geometry
125
+ max_dist (float): maximum distance to look for a closest point
126
+ guess (:class:`Sample`): initial guess, may help perform the query
127
+ filter_array (wp.array): Used in conjunction with `filter_target`. Only cells such that ``filter_array[element_index]==filter_target`` will be considered.
128
+ filter_target (Any): See `filter_array`
129
+ """
130
+ pass
131
+
132
+
133
+ @operator(resolver=lambda dmn: dmn.element_partition_lookup)
134
+ def partition_lookup(domain: Domain, x: Any) -> Sample:
135
+ """Looks-up the sample point corresponding to a world position `x`, projecting to the closest point on the geometry partition.
111
136
 
112
- Note:
113
- Currently this operator is unsupported for :class:`Hexmesh`, :class:`Quadmesh2D`, :class:`Quadmesh3D` and deformed geometries.
137
+ Args:
138
+ x (vec3): world position of the point to look-up in the geometry
139
+ max_dist (float): maximum distance to look for a closest point
114
140
  """
115
141
  pass
116
142
 
117
143
 
118
- @operator(resolver=lambda dmn: dmn.element_measure)
144
+ @operator(resolver=lambda dmn: dmn.element_measure, attr="geo")
119
145
  def measure(domain: Domain, s: Sample) -> float:
120
146
  """Returns the measure (volume, area, or length) determinant of an element at a sample point `s`"""
121
147
  pass
122
148
 
123
149
 
124
- @operator(resolver=lambda dmn: dmn.element_measure_ratio)
150
+ @operator(resolver=lambda dmn: dmn.element_measure_ratio, attr="geo")
125
151
  def measure_ratio(domain: Domain, s: Sample) -> float:
126
152
  """Returns the maximum ratio between the measure of this element and that of higher-dimensional neighbors."""
127
153
  pass
128
154
 
129
155
 
156
+ @operator(resolver=lambda dmn: dmn.element_closest_point, attr="geo")
157
+ def element_closest_point(domain: Domain, element_index: ElementIndex, x: Any) -> Sample:
158
+ """
159
+ Computes the coordinates of the closest point to a world position within a given element.
160
+ Returns a tuple (closest point coordinates; squared distance to the closest point)
161
+
162
+ Args:
163
+ element_index: Index of the element to consider
164
+ x: world position of the point to compute the closest point to
165
+ """
166
+ pass
167
+
168
+
169
+ @operator(resolver=lambda dmn: dmn.element_coordinates, attr="geo")
170
+ def element_coordinates(domain: Domain, element_index: ElementIndex, x: Any) -> Sample:
171
+ """Returns the coordinates in an element reference system corresponding to a work position.
172
+ The returned coordinates may be in the element's exterior.
173
+
174
+ Args:
175
+ element_index: Index of the element to consider
176
+ x: world position of the point to find coordinates for
177
+ """
178
+ pass
179
+
180
+
130
181
  # Operators for evaluating cell-level quantities on domains defined on sides
131
182
 
132
183
 
133
184
  @operator(
134
- resolver=lambda dmn: dmn.domain_cell_arg, field_result=lambda dmn: (dmn.cell_domain(), Domain, dmn.geometry.CellArg)
185
+ resolver=lambda dmn: dmn.domain_cell_arg,
186
+ field_result=lambda dmn: (dmn.cell_domain(), Domain, dmn.cell_domain().DomainArg),
135
187
  )
136
188
  def cells(domain: Domain) -> Domain:
137
189
  """Converts a domain defined on geometry sides to a domain defined of cells."""
138
190
  pass
139
191
 
140
192
 
141
- @operator(resolver=lambda dmn: dmn.element_inner_cell_index)
193
+ @operator(resolver=lambda dmn: dmn.element_inner_cell_index, attr="geo")
142
194
  def _inner_cell_index(domain: Domain, side_index: ElementIndex, side_coords: Coords) -> Sample:
143
195
  pass
144
196
 
145
197
 
146
- @operator(resolver=lambda dmn: dmn.element_outer_cell_index)
198
+ @operator(resolver=lambda dmn: dmn.element_outer_cell_index, attr="geo")
147
199
  def _outer_cell_index(domain: Domain, side_index: ElementIndex, side_coords: Coords) -> Sample:
148
200
  pass
149
201
 
150
202
 
151
- @operator(resolver=lambda dmn: dmn.element_inner_cell_coords)
203
+ @operator(resolver=lambda dmn: dmn.element_inner_cell_coords, attr="geo")
152
204
  def _inner_cell_coords(domain: Domain, side_index: ElementIndex, side_coords: Coords) -> Sample:
153
205
  pass
154
206
 
155
207
 
156
- @operator(resolver=lambda dmn: dmn.element_outer_cell_coords)
208
+ @operator(resolver=lambda dmn: dmn.element_outer_cell_coords, attr="geo")
157
209
  def _outer_cell_coords(domain: Domain, side_index: ElementIndex, side_coords: Coords) -> Sample:
158
210
  pass
159
211
 
160
212
 
161
- @operator(resolver=lambda dmn: dmn.cell_to_element_coords)
213
+ @operator(resolver=lambda dmn: dmn.cell_to_element_coords, attr="geo")
162
214
  def _cell_to_element_coords(
163
- domain: Domain, side_index: ElementIndex, cell_index: ElementIndex, cell_coords: Coords
215
+ domain: Domain,
216
+ side_index: ElementIndex,
217
+ cell_index: ElementIndex,
218
+ cell_coords: Coords,
164
219
  ) -> Sample:
165
220
  pass
166
221
 
@@ -169,7 +224,8 @@ def _cell_to_element_coords(
169
224
  def to_inner_cell(domain: Domain, s: Sample):
170
225
  """Converts a :class:`Sample` defined on a side to a sample defined on the side's inner cell"""
171
226
  return make_free_sample(
172
- _inner_cell_index(domain, s.element_index), _inner_cell_coords(domain, s.element_index, s.element_coords)
227
+ _inner_cell_index(domain, s.element_index),
228
+ _inner_cell_coords(domain, s.element_index, s.element_coords),
173
229
  )
174
230
 
175
231
 
@@ -177,7 +233,8 @@ def to_inner_cell(domain: Domain, s: Sample):
177
233
  def to_outer_cell(domain: Domain, s: Sample):
178
234
  """Converts a :class:`Sample` defined on a side to a sample defined on the side's outer cell"""
179
235
  return make_free_sample(
180
- _outer_cell_index(domain, s.element_index), _outer_cell_coords(domain, s.element_index, s.element_coords)
236
+ _outer_cell_index(domain, s.element_index),
237
+ _outer_cell_coords(domain, s.element_index, s.element_coords),
181
238
  )
182
239
 
183
240
 
@@ -186,10 +243,26 @@ def to_cell_side(domain: Domain, cell_s: Sample, side_index: ElementIndex):
186
243
  """Converts a :class:`Sample` defined on a cell to a sample defined on one of its side.
187
244
  If the result does not lie on the side `side_index`, the resulting coordinates will be set to ``OUTSIDE``."""
188
245
  return make_free_sample(
189
- side_index, _cell_to_element_coords(domain, side_index, cell_s.element_index, cell_s.element_coords)
246
+ side_index,
247
+ _cell_to_element_coords(domain, side_index, cell_s.element_index, cell_s.element_coords),
190
248
  )
191
249
 
192
250
 
251
+ @operator(resolver=lambda dmn: dmn.element_index, attr="index")
252
+ def element_index(domain: Domain, domain_element_index: ElementIndex):
253
+ """Returns the index in the geometry of the `domain_element_index`'th domain element."""
254
+ pass
255
+
256
+
257
+ @operator(resolver=lambda dmn: dmn.element_partition_index, attr="index")
258
+ def element_partition_index(domain: Domain, cell_index: ElementIndex):
259
+ """Returns the index of the passed cell in the domain's geometry partition, or `NULL_ELEMENT_INDEX` if not part of the partition.
260
+
261
+ :note: Currently only available for `fem.ElementKind.CELL` elements
262
+ """
263
+ pass
264
+
265
+
193
266
  # Field operators
194
267
  # On a side, inner and outer are such that normal goes from inner to outer
195
268
 
@@ -236,9 +309,33 @@ def degree(f: Field):
236
309
  pass
237
310
 
238
311
 
312
+ @operator(resolver=lambda f: f.node_count)
313
+ def node_count(f: Field, s: Sample):
314
+ """
315
+ Returns the number of nodes associated to the field `f` in the element containing the sample `s`
316
+ """
317
+
318
+
239
319
  @operator(resolver=lambda f: f.at_node)
240
- def at_node(f: Field, s: Sample):
241
- """For a Test or Trial field `f`, returns a copy of the Sample `s` moved to the coordinates of the node being evaluated"""
320
+ def at_node(f: Field, s: Sample, node_index_in_elt: Optional[int] = None):
321
+ """
322
+ Returns a copy of the Sample `s` moved to the coordinates of a local node of the field `f`.
323
+
324
+ If `f` is a discrete field, `node_index_in_elt` is required and indicates the element-local index of the node to consider.
325
+ If `f` is a Test or Trial field, `node_index_in_elt` **must not** be provided, and will be automatically set
326
+ to the test (resp. trial) node currently being evaluated.
327
+ """
328
+ pass
329
+
330
+
331
+ @operator(resolver=lambda f: f.node_index)
332
+ def node_index(f: Field, s: Sample, node_index_in_elt: Optional[int] = None):
333
+ """Returns the index in the function space of a local node the field `f`.
334
+
335
+ If `f` is a discrete field, `node_index_in_elt` is required and indicates the element-local index of the node to consider.
336
+ If `f` is a Test or Trial field, `node_index_in_elt` **must not** be provided, and will be automatically set
337
+ to the test (resp. trial) node currently being evaluated.
338
+ """
242
339
  pass
243
340
 
244
341
 
@@ -18,7 +18,7 @@ from typing import Any, Optional, Tuple, Union
18
18
  import warp as wp
19
19
  from warp.fem.cache import TemporaryStore, borrow_temporary, cached_arg_value, dynamic_kernel
20
20
  from warp.fem.domain import GeometryDomain
21
- from warp.fem.types import NULL_ELEMENT_INDEX, Coords, ElementIndex, make_free_sample
21
+ from warp.fem.types import NULL_ELEMENT_INDEX, OUTSIDE, Coords, ElementIndex, make_free_sample
22
22
  from warp.fem.utils import compress_node_indices
23
23
 
24
24
  from .quadrature import Quadrature
@@ -32,10 +32,10 @@ class PicQuadrature(Quadrature):
32
32
  Args:
33
33
  domain: Underlying domain for the quadrature
34
34
  positions: Either an array containing the world positions of all particles, or a tuple of arrays containing
35
- the cell indices and coordinates for each particle. Note that the former requires the underlying geometry to
36
- define a global :meth:`Geometry.cell_lookup` method; currently this is only available for :class:`Grid2D` and :class:`Grid3D`.
35
+ the cell indices and coordinates for each particle.
37
36
  measures: Array containing the measure (area/volume) of each particle, used to defined the integration weights.
38
37
  If ``None``, defaults to the cell measure divided by the number of particles in the cell.
38
+ max_dist: When providing world positions that fall outside of the domain's geometry partition, maximum distance to look up for embedding cells
39
39
  requires_grad: Whether gradients should be allocated for the computed quantities
40
40
  temporary_store: shared pool from which to allocate temporary arrays
41
41
  """
@@ -52,12 +52,13 @@ class PicQuadrature(Quadrature):
52
52
  ],
53
53
  measures: Optional["wp.array(dtype=float)"] = None,
54
54
  requires_grad: bool = False,
55
+ max_dist: float = 0.0,
55
56
  temporary_store: TemporaryStore = None,
56
57
  ):
57
58
  super().__init__(domain)
58
59
 
59
60
  self._requires_grad = requires_grad
60
- self._bin_particles(positions, measures, temporary_store)
61
+ self._bin_particles(positions, measures, max_dist=max_dist, temporary_store=temporary_store)
61
62
  self._max_particles_per_cell: int = None
62
63
 
63
64
  @property
@@ -86,12 +87,15 @@ class PicQuadrature(Quadrature):
86
87
  @cached_arg_value
87
88
  def arg_value(self, device) -> Arg:
88
89
  arg = PicQuadrature.Arg()
89
- arg.cell_particle_offsets = self._cell_particle_offsets.array.to(device)
90
- arg.cell_particle_indices = self._cell_particle_indices.array.to(device)
91
- arg.particle_fraction = self._particle_fraction.to(device)
92
- arg.particle_coords = self.particle_coords.to(device)
90
+ self.fill_arg(arg, device)
93
91
  return arg
94
92
 
93
+ def fill_arg(self, args: Arg, device):
94
+ args.cell_particle_offsets = self._cell_particle_offsets.array.to(device)
95
+ args.cell_particle_indices = self._cell_particle_indices.array.to(device)
96
+ args.particle_fraction = self._particle_fraction.to(device)
97
+ args.particle_coords = self.particle_coords.to(device)
98
+
95
99
  def total_point_count(self):
96
100
  return self.particle_coords.shape[0]
97
101
 
@@ -179,23 +183,37 @@ class PicQuadrature(Quadrature):
179
183
  cell_particle_count = cell_particle_offsets[cell + 1] - cell_particle_offsets[cell]
180
184
  cell_fraction[p] = 1.0 / float(cell_particle_count)
181
185
 
182
- def _bin_particles(self, positions, measures, temporary_store: TemporaryStore):
186
+ def _bin_particles(self, positions, measures, max_dist: float, temporary_store: TemporaryStore):
183
187
  if wp.types.is_array(positions):
188
+ device = positions.device
189
+ if not self.domain.supports_lookup(device):
190
+ raise RuntimeError(
191
+ "Attempting to build a PicQuadrature from positions on a domain that does not support global lookups"
192
+ )
193
+
194
+ cell_lookup = self.domain.element_partition_lookup
195
+ cell_coordinates = self.domain.element_coordinates
196
+
184
197
  # Initialize from positions
185
- @dynamic_kernel(suffix=f"{self.domain.name}")
198
+ @dynamic_kernel(suffix=self.domain.name)
186
199
  def bin_particles(
187
200
  cell_arg_value: self.domain.ElementArg,
201
+ domain_index_arg_value: self.domain.ElementIndexArg,
188
202
  positions: wp.array(dtype=positions.dtype),
203
+ max_dist: float,
189
204
  cell_index: wp.array(dtype=ElementIndex),
190
205
  cell_coords: wp.array(dtype=Coords),
191
206
  ):
192
207
  p = wp.tid()
193
- sample = self.domain.element_lookup(cell_arg_value, positions[p])
208
+ sample = cell_lookup(
209
+ self.domain.DomainArg(cell_arg_value, domain_index_arg_value), positions[p], max_dist
210
+ )
194
211
 
195
212
  cell_index[p] = sample.element_index
196
- cell_coords[p] = sample.element_coords
197
-
198
- device = positions.device
213
+ if sample.element_index == NULL_ELEMENT_INDEX:
214
+ cell_coords[p] = Coords(OUTSIDE)
215
+ else:
216
+ cell_coords[p] = cell_coordinates(cell_arg_value, sample.element_index, positions[p])
199
217
 
200
218
  self._cell_index_temp = borrow_temporary(temporary_store, shape=positions.shape, dtype=int, device=device)
201
219
  self.cell_indices = self._cell_index_temp.array
@@ -210,7 +228,11 @@ class PicQuadrature(Quadrature):
210
228
  kernel=bin_particles,
211
229
  inputs=[
212
230
  self.domain.element_arg_value(device),
231
+ self.domain.element_index_arg_value(device),
213
232
  positions,
233
+ max_dist,
234
+ ],
235
+ outputs=[
214
236
  self.cell_indices,
215
237
  self.particle_coords,
216
238
  ],
@@ -13,7 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from typing import Any, Optional
16
+ from typing import Any, ClassVar, Optional
17
17
 
18
18
  import warp as wp
19
19
  from warp.fem import cache
@@ -55,6 +55,12 @@ class Quadrature:
55
55
  arg = Quadrature.Arg()
56
56
  return arg
57
57
 
58
+ def fill_arg(self, arg: Arg, device):
59
+ """
60
+ Fill the argument with the value of the argument to be passed to device
61
+ """
62
+ pass
63
+
58
64
  def total_point_count(self):
59
65
  """Number of unique quadrature points that can be indexed by this rule.
60
66
  Returns a number such that `point_index()` is always smaller than this number.
@@ -199,12 +205,16 @@ class _QuadratureWithRegularEvaluationPoints(Quadrature):
199
205
  """Helper subclass for quadrature formulas which use a uniform number of
200
206
  evaluations points per element. Avoids building explicit mapping"""
201
207
 
208
+ _dynamic_attribute_constructors: ClassVar = {
209
+ "point_evaluation_index": lambda obj: obj._make_regular_point_evaluation_index(),
210
+ "evaluation_point_element_index": lambda obj: obj._make_regular_evaluation_point_element_index(),
211
+ }
212
+
202
213
  def __init__(self, domain: GeometryDomain, N: int):
203
214
  super().__init__(domain)
204
215
  self._EVALUATION_POINTS_PER_ELEMENT = N
205
216
 
206
- self.point_evaluation_index = self._make_regular_point_evaluation_index()
207
- self.evaluation_point_element_index = self._make_regular_evaluation_point_element_index()
217
+ cache.setup_dynamic_attributes(self, cls=__class__)
208
218
 
209
219
  ElementIndexArg = Quadrature.Arg
210
220
  element_index_arg_value = Quadrature.arg_value
@@ -254,7 +264,7 @@ class RegularQuadrature(_QuadratureWithRegularEvaluationPoints):
254
264
 
255
265
  # Cache common formulas so we do dot have to do h2d transfer for each call
256
266
  class CachedFormula:
257
- _cache = {}
267
+ _cache: ClassVar = {}
258
268
 
259
269
  def __init__(self, element: Element, order: int, family: Polynomial):
260
270
  self.points, self.weights = element.instantiate_quadrature(order, family)
@@ -263,9 +273,12 @@ class RegularQuadrature(_QuadratureWithRegularEvaluationPoints):
263
273
  @cache.cached_arg_value
264
274
  def arg_value(self, device):
265
275
  arg = RegularQuadrature.Arg()
276
+ self.fill_arg(arg, device)
277
+ return arg
278
+
279
+ def fill_arg(self, arg: "RegularQuadrature.Arg", device):
266
280
  arg.points = wp.array(self.points, device=device, dtype=Coords)
267
281
  arg.weights = wp.array(self.weights, device=device, dtype=float)
268
- return arg
269
282
 
270
283
  @staticmethod
271
284
  def get(element: Element, order: int, family: Polynomial):
@@ -277,6 +290,13 @@ class RegularQuadrature(_QuadratureWithRegularEvaluationPoints):
277
290
  RegularQuadrature.CachedFormula._cache[key] = quadrature
278
291
  return quadrature
279
292
 
293
+ _dynamic_attribute_constructors: ClassVar = {
294
+ "point_count": lambda obj: obj._make_point_count(),
295
+ "point_index": lambda obj: obj._make_point_index(),
296
+ "point_coords": lambda obj: obj._make_point_coords(),
297
+ "point_weight": lambda obj: obj._make_point_weight(),
298
+ }
299
+
280
300
  def __init__(
281
301
  self,
282
302
  domain: GeometryDomain,
@@ -289,10 +309,7 @@ class RegularQuadrature(_QuadratureWithRegularEvaluationPoints):
289
309
 
290
310
  super().__init__(domain, self._formula.count)
291
311
 
292
- self.point_count = self._make_point_count()
293
- self.point_index = self._make_point_index()
294
- self.point_coords = self._make_point_coords()
295
- self.point_weight = self._make_point_weight()
312
+ cache.setup_dynamic_attributes(self)
296
313
 
297
314
  @property
298
315
  def name(self):
@@ -315,6 +332,9 @@ class RegularQuadrature(_QuadratureWithRegularEvaluationPoints):
315
332
  def arg_value(self, device):
316
333
  return self._formula.arg_value(device)
317
334
 
335
+ def fill_arg(self, arg: "RegularQuadrature.Arg", device):
336
+ self._formula.fill_arg(arg, device)
337
+
318
338
  def _make_point_count(self):
319
339
  N = self._formula.count
320
340
 
@@ -374,7 +394,7 @@ class RegularQuadrature(_QuadratureWithRegularEvaluationPoints):
374
394
  class NodalQuadrature(Quadrature):
375
395
  """Quadrature using space node points as quadrature points
376
396
 
377
- Note that in contrast to the `nodal=True` flag for :func:`integrate`, using this quadrature does not imply
397
+ Note that in contrast to the `assembly="nodal"` flag for :func:`integrate`, using this quadrature does not imply
378
398
  any assumption about orthogonality of shape functions, and is thus safe to use for arbitrary integrands.
379
399
  """
380
400
 
@@ -412,10 +432,13 @@ class NodalQuadrature(Quadrature):
412
432
  @cache.cached_arg_value
413
433
  def arg_value(self, device):
414
434
  arg = self.Arg()
415
- arg.space_arg = self._space.space_arg_value(device)
416
- arg.topo_arg = self._space.topology.topo_arg_value(device)
435
+ self.fill_arg(arg, device)
417
436
  return arg
418
437
 
438
+ def fill_arg(self, arg: "NodalQuadrature.Arg", device):
439
+ self._space.fill_space_arg(arg.space_arg, device)
440
+ self._space.topology.fill_topo_arg(arg.topo_arg, device)
441
+
419
442
  def _make_point_count(self):
420
443
  @cache.dynamic_func(suffix=self.name)
421
444
  def point_count(
@@ -533,7 +556,7 @@ class ExplicitQuadrature(_QuadratureWithRegularEvaluationPoints):
533
556
 
534
557
  @property
535
558
  def name(self):
536
- return f"{self.__class__.__name__}_{self._whole_geo}"
559
+ return f"{self.__class__.__name__}_{self._whole_geo}_{self._points_per_cell}"
537
560
 
538
561
  def total_point_count(self):
539
562
  return self._weights.size
@@ -541,15 +564,16 @@ class ExplicitQuadrature(_QuadratureWithRegularEvaluationPoints):
541
564
  def max_points_per_element(self):
542
565
  return self._points_per_cell
543
566
 
544
- @cache.cached_arg_value
545
567
  def arg_value(self, device):
546
568
  arg = self.Arg()
569
+ self.fill_arg(arg, device)
570
+ return arg
571
+
572
+ def fill_arg(self, arg: "ExplicitQuadrature.Arg", device):
547
573
  arg.points_per_cell = self._points_per_cell
548
574
  arg.points = self._points.to(device)
549
575
  arg.weights = self._weights.to(device)
550
576
 
551
- return arg
552
-
553
577
  @wp.func
554
578
  def point_count(elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex):
555
579
  return qp_arg.points.shape[1]
@@ -120,7 +120,7 @@ def make_polynomial_basis_space(
120
120
  # Degree-1 serendipity is always equivalent to Lagrange
121
121
  element_basis = ElementBasis.LAGRANGE
122
122
 
123
- shape = get_shape_function(geo.reference_cell(), geo.dimension, degree, element_basis, family)
123
+ shape = get_shape_function(geo.reference_cell().__class__, geo.dimension, degree, element_basis, family)
124
124
 
125
125
  if discontinuous or degree == 0 or element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
126
126
  return make_discontinuous_basis_space(geo, shape)