warp-lang 1.7.2rc1__py3-none-macosx_10_13_universal2.whl → 1.8.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (180)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +125 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +257 -101
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +657 -223
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  35. warp/examples/optim/example_drone.py +1 -1
  36. warp/examples/sim/example_cloth.py +1 -1
  37. warp/examples/sim/example_cloth_self_contact.py +48 -54
  38. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  39. warp/examples/tile/example_tile_cholesky.py +2 -1
  40. warp/examples/tile/example_tile_convolution.py +1 -1
  41. warp/examples/tile/example_tile_filtering.py +1 -1
  42. warp/examples/tile/example_tile_matmul.py +1 -1
  43. warp/examples/tile/example_tile_mlp.py +2 -0
  44. warp/fabric.py +7 -7
  45. warp/fem/__init__.py +5 -0
  46. warp/fem/adaptivity.py +1 -1
  47. warp/fem/cache.py +152 -63
  48. warp/fem/dirichlet.py +2 -2
  49. warp/fem/domain.py +136 -6
  50. warp/fem/field/field.py +141 -99
  51. warp/fem/field/nodal_field.py +85 -39
  52. warp/fem/field/virtual.py +97 -52
  53. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  54. warp/fem/geometry/closest_point.py +13 -0
  55. warp/fem/geometry/deformed_geometry.py +102 -40
  56. warp/fem/geometry/element.py +56 -2
  57. warp/fem/geometry/geometry.py +323 -22
  58. warp/fem/geometry/grid_2d.py +157 -62
  59. warp/fem/geometry/grid_3d.py +116 -20
  60. warp/fem/geometry/hexmesh.py +86 -20
  61. warp/fem/geometry/nanogrid.py +166 -86
  62. warp/fem/geometry/partition.py +59 -25
  63. warp/fem/geometry/quadmesh.py +86 -135
  64. warp/fem/geometry/tetmesh.py +47 -119
  65. warp/fem/geometry/trimesh.py +77 -270
  66. warp/fem/integrate.py +107 -52
  67. warp/fem/linalg.py +25 -58
  68. warp/fem/operator.py +124 -27
  69. warp/fem/quadrature/pic_quadrature.py +36 -14
  70. warp/fem/quadrature/quadrature.py +40 -16
  71. warp/fem/space/__init__.py +1 -1
  72. warp/fem/space/basis_function_space.py +66 -46
  73. warp/fem/space/basis_space.py +17 -4
  74. warp/fem/space/dof_mapper.py +1 -1
  75. warp/fem/space/function_space.py +2 -2
  76. warp/fem/space/grid_2d_function_space.py +4 -1
  77. warp/fem/space/hexmesh_function_space.py +4 -2
  78. warp/fem/space/nanogrid_function_space.py +3 -1
  79. warp/fem/space/partition.py +11 -2
  80. warp/fem/space/quadmesh_function_space.py +4 -1
  81. warp/fem/space/restriction.py +5 -2
  82. warp/fem/space/shape/__init__.py +10 -8
  83. warp/fem/space/tetmesh_function_space.py +4 -1
  84. warp/fem/space/topology.py +52 -21
  85. warp/fem/space/trimesh_function_space.py +4 -1
  86. warp/fem/utils.py +53 -8
  87. warp/jax.py +1 -2
  88. warp/jax_experimental/ffi.py +12 -17
  89. warp/jax_experimental/xla_ffi.py +37 -24
  90. warp/math.py +171 -1
  91. warp/native/array.h +99 -0
  92. warp/native/builtin.h +174 -31
  93. warp/native/coloring.cpp +1 -1
  94. warp/native/exports.h +118 -63
  95. warp/native/intersect.h +3 -3
  96. warp/native/mat.h +5 -10
  97. warp/native/mathdx.cpp +11 -5
  98. warp/native/matnn.h +1 -123
  99. warp/native/quat.h +28 -4
  100. warp/native/sparse.cpp +121 -258
  101. warp/native/sparse.cu +181 -274
  102. warp/native/spatial.h +305 -17
  103. warp/native/tile.h +583 -72
  104. warp/native/tile_radix_sort.h +1108 -0
  105. warp/native/tile_reduce.h +237 -2
  106. warp/native/tile_scan.h +240 -0
  107. warp/native/tuple.h +189 -0
  108. warp/native/vec.h +6 -16
  109. warp/native/warp.cpp +36 -4
  110. warp/native/warp.cu +574 -51
  111. warp/native/warp.h +47 -74
  112. warp/optim/linear.py +5 -1
  113. warp/paddle.py +7 -8
  114. warp/py.typed +0 -0
  115. warp/render/render_opengl.py +58 -29
  116. warp/render/render_usd.py +124 -61
  117. warp/sim/__init__.py +9 -0
  118. warp/sim/collide.py +252 -78
  119. warp/sim/graph_coloring.py +8 -1
  120. warp/sim/import_mjcf.py +4 -3
  121. warp/sim/import_usd.py +11 -7
  122. warp/sim/integrator.py +5 -2
  123. warp/sim/integrator_euler.py +1 -1
  124. warp/sim/integrator_featherstone.py +1 -1
  125. warp/sim/integrator_vbd.py +751 -320
  126. warp/sim/integrator_xpbd.py +1 -1
  127. warp/sim/model.py +265 -260
  128. warp/sim/utils.py +10 -7
  129. warp/sparse.py +303 -166
  130. warp/tape.py +52 -51
  131. warp/tests/cuda/test_conditional_captures.py +1046 -0
  132. warp/tests/cuda/test_streams.py +1 -1
  133. warp/tests/geometry/test_volume.py +2 -2
  134. warp/tests/interop/test_dlpack.py +9 -9
  135. warp/tests/interop/test_jax.py +0 -1
  136. warp/tests/run_coverage_serial.py +1 -1
  137. warp/tests/sim/disabled_kinematics.py +2 -2
  138. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  139. warp/tests/sim/test_collision.py +159 -51
  140. warp/tests/sim/test_coloring.py +15 -1
  141. warp/tests/test_array.py +254 -2
  142. warp/tests/test_array_reduce.py +2 -2
  143. warp/tests/test_atomic_cas.py +299 -0
  144. warp/tests/test_codegen.py +142 -19
  145. warp/tests/test_conditional.py +47 -1
  146. warp/tests/test_ctypes.py +0 -20
  147. warp/tests/test_devices.py +8 -0
  148. warp/tests/test_fabricarray.py +4 -2
  149. warp/tests/test_fem.py +58 -25
  150. warp/tests/test_func.py +42 -1
  151. warp/tests/test_grad.py +1 -1
  152. warp/tests/test_lerp.py +1 -3
  153. warp/tests/test_map.py +481 -0
  154. warp/tests/test_mat.py +1 -24
  155. warp/tests/test_quat.py +6 -15
  156. warp/tests/test_rounding.py +10 -38
  157. warp/tests/test_runlength_encode.py +7 -7
  158. warp/tests/test_smoothstep.py +1 -1
  159. warp/tests/test_sparse.py +51 -2
  160. warp/tests/test_spatial.py +507 -1
  161. warp/tests/test_struct.py +2 -2
  162. warp/tests/test_tuple.py +265 -0
  163. warp/tests/test_types.py +2 -2
  164. warp/tests/test_utils.py +24 -18
  165. warp/tests/tile/test_tile.py +420 -1
  166. warp/tests/tile/test_tile_mathdx.py +518 -14
  167. warp/tests/tile/test_tile_reduce.py +213 -0
  168. warp/tests/tile/test_tile_shared_memory.py +130 -1
  169. warp/tests/tile/test_tile_sort.py +117 -0
  170. warp/tests/unittest_suites.py +4 -6
  171. warp/types.py +462 -308
  172. warp/utils.py +647 -86
  173. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  174. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/RECORD +177 -165
  175. warp/stubs.py +0 -3381
  176. warp/tests/sim/test_xpbd.py +0 -399
  177. warp/tests/test_mlp.py +0 -282
  178. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  179. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  180. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
@@ -13,16 +13,24 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from typing import Any
16
+ from functools import cached_property
17
+ from typing import Any, ClassVar
17
18
 
18
19
  import warp as wp
19
20
  from warp.fem import cache
20
- from warp.fem.types import Coords, ElementIndex, Sample, make_free_sample
21
+ from warp.fem.types import NULL_ELEMENT_INDEX, OUTSIDE, Coords, ElementIndex, ElementKind, Sample, make_free_sample
21
22
 
22
23
  from .element import Element
23
24
 
24
25
  _mat32 = wp.mat(shape=(3, 2), dtype=float)
25
26
 
27
+ _NULL_BVH_ID = wp.uint64(0)
28
+ _COORD_LOOKUP_ITERATIONS = 24
29
+ _COORD_LOOKUP_STEP = 1.0
30
+ _COORD_LOOKUP_EPS = float(2**-20)
31
+ _BVH_MIN_PADDING = float(2**-16)
32
+ _BVH_MAX_PADDING = float(2**16)
33
+
26
34
 
27
35
  class Geometry:
28
36
  """
@@ -33,6 +41,8 @@ class Geometry:
33
41
 
34
42
  dimension: int = 0
35
43
 
44
+ _bvhs = None
45
+
36
46
  def cell_count(self):
37
47
  """Number of cells in the geometry"""
38
48
  raise NotImplementedError
@@ -83,6 +93,10 @@ class Geometry:
83
93
  """Value of the arguments to be passed to cell-related device functions"""
84
94
  raise NotImplementedError
85
95
 
96
+ def fill_cell_arg(self, args: "Geometry.CellArg", device):
97
+ """Fill the arguments to be passed to cell-related device functions"""
98
+ raise NotImplementedError
99
+
86
100
  @staticmethod
87
101
  def cell_position(args: "Geometry.CellArg", s: "Sample"):
88
102
  """Device function returning the world position of a cell sample point"""
@@ -100,16 +114,6 @@ class Geometry:
100
114
  """
101
115
  raise NotImplementedError
102
116
 
103
- @staticmethod
104
- def cell_lookup(args: "Geometry.CellArg", pos: Any):
105
- """Device function returning the cell sample point corresponding to a world position"""
106
- raise NotImplementedError
107
-
108
- @staticmethod
109
- def cell_lookup(args: "Geometry.CellArg", pos: Any, guess: "Sample"):
110
- """Device function returning the cell sample point corresponding to a world position. Can use guess for faster lookup"""
111
- raise NotImplementedError
112
-
113
117
  @staticmethod
114
118
  def cell_measure(args: "Geometry.CellArg", s: "Sample"):
115
119
  """Device function returning the measure determinant (e.g. volume, area) at a given point"""
@@ -130,6 +134,10 @@ class Geometry:
130
134
  """Value of the arguments to be passed to side-related device functions"""
131
135
  raise NotImplementedError
132
136
 
137
+ def fill_side_arg(self, args: "Geometry.SideArg", device):
138
+ """Fill the arguments to be passed to side-related device functions"""
139
+ raise NotImplementedError
140
+
133
141
  @staticmethod
134
142
  def boundary_side_index(args: "Geometry.SideIndexArg", boundary_side_index: int):
135
143
  """Device function returning the side index corresponding to a boundary side"""
@@ -212,16 +220,21 @@ class Geometry:
212
220
  # Default implementations for dependent quantities
213
221
  # Can be overridden in derived classes if more efficient implementations exist
214
222
 
215
- def _make_default_dependent_implementations(self):
216
- self.cell_inverse_deformation_gradient = self._make_cell_inverse_deformation_gradient()
217
- self.cell_measure = self._make_cell_measure()
218
- self.cell_normal = self._make_cell_normal()
223
+ _dynamic_attribute_constructors: ClassVar = {
224
+ "cell_inverse_deformation_gradient": lambda obj: obj._make_cell_inverse_deformation_gradient(),
225
+ "cell_measure": lambda obj: obj._make_cell_measure(),
226
+ "cell_normal": lambda obj: obj._make_cell_normal(),
227
+ "side_inverse_deformation_gradient": lambda obj: obj._make_side_inverse_deformation_gradient(),
228
+ "side_inner_inverse_deformation_gradient": lambda obj: obj._make_side_inner_inverse_deformation_gradient(),
229
+ "side_outer_inverse_deformation_gradient": lambda obj: obj._make_side_outer_inverse_deformation_gradient(),
230
+ "side_measure": lambda obj: obj._make_side_measure(),
231
+ "side_measure_ratio": lambda obj: obj._make_side_measure_ratio(),
232
+ "side_normal": lambda obj: obj._make_side_normal(),
233
+ "compute_cell_bounds": lambda obj: obj._make_compute_cell_bounds(),
234
+ }
219
235
 
220
- self.side_inner_inverse_deformation_gradient = self._make_side_inner_inverse_deformation_gradient()
221
- self.side_outer_inverse_deformation_gradient = self._make_side_outer_inverse_deformation_gradient()
222
- self.side_measure = self._make_side_measure()
223
- self.side_measure_ratio = self._make_side_measure_ratio()
224
- self.side_normal = self._make_side_normal()
236
+ def _make_default_dependent_implementations(self):
237
+ cache.setup_dynamic_attributes(self, cls=__class__)
225
238
 
226
239
  @wp.func
227
240
  def _element_measure(F: wp.vec2):
@@ -247,7 +260,7 @@ class Geometry:
247
260
 
248
261
  @wp.func
249
262
  def _element_normal(F: wp.vec2):
250
- return wp.normalize(wp.vec2(-F[1], F[0]))
263
+ return wp.normalize(wp.vec2(F[1], -F[0]))
251
264
 
252
265
  @wp.func
253
266
  def _element_normal(F: _mat32):
@@ -302,6 +315,35 @@ class Geometry:
302
315
 
303
316
  return cell_inverse_deformation_gradient if cell_dim == geo_dim else cell_pseudoinverse_deformation_gradient
304
317
 
318
+ def _make_side_inverse_deformation_gradient(self):
319
+ side_dim = self.reference_side().dimension
320
+ geo_dim = self.dimension
321
+
322
+ if side_dim == geo_dim:
323
+
324
+ @cache.dynamic_func(suffix=self.name)
325
+ def side_inverse_deformation_gradient(side_arg: self.SideArg, s: Sample):
326
+ return wp.inverse(self.side_deformation_gradient(side_arg, s))
327
+
328
+ return side_inverse_deformation_gradient
329
+
330
+ if side_dim == 1:
331
+
332
+ @cache.dynamic_func(suffix=self.name)
333
+ def edge_pseudoinverse_deformation_gradient(side_arg: self.SideArg, s: Sample):
334
+ F = self.side_deformation_gradient(side_arg, s)
335
+ return wp.matrix_from_rows(F / wp.dot(F, F))
336
+
337
+ return edge_pseudoinverse_deformation_gradient
338
+
339
+ @cache.dynamic_func(suffix=self.name)
340
+ def side_pseudoinverse_deformation_gradient(side_arg: self.SideArg, s: Sample):
341
+ F = self.side_deformation_gradient(side_arg, s)
342
+ Ft = wp.transpose(F)
343
+ return wp.inverse(Ft * F) * Ft
344
+
345
+ return side_pseudoinverse_deformation_gradient
346
+
305
347
  def _make_side_measure(self):
306
348
  REF_MEASURE = wp.constant(self.reference_side().measure())
307
349
 
@@ -360,3 +402,262 @@ class Geometry:
360
402
  return self.cell_inverse_deformation_gradient(cell_arg, make_free_sample(cell_index, cell_coords))
361
403
 
362
404
  return side_outer_inverse_deformation_gradient
405
+
406
+ def _make_element_coordinates(self, element_kind: ElementKind, assume_linear: bool = False):
407
+ pos_type = cache.cached_vec_type(self.dimension, dtype=float)
408
+
409
+ if element_kind == ElementKind.CELL:
410
+ ref_elt = self.reference_cell()
411
+ arg_type = self.CellArg
412
+ elt_pos = self.cell_position
413
+ elt_inv_grad = self.cell_inverse_deformation_gradient
414
+ else:
415
+ ref_elt = self.reference_side()
416
+ arg_type = self.SideArg
417
+ elt_pos = self.side_position
418
+ elt_inv_grad = self.side_inverse_deformation_gradient
419
+
420
+ elt_center = Coords(ref_elt.center())
421
+
422
+ ITERATIONS = 1 if assume_linear else _COORD_LOOKUP_ITERATIONS
423
+ STEP = 1.0 if assume_linear else _COORD_LOOKUP_STEP
424
+
425
+ @cache.dynamic_func(suffix=f"{self.name}{element_kind}{assume_linear}")
426
+ def element_coordinates(args: arg_type, element_index: ElementIndex, pos: pos_type):
427
+ coords = elt_center
428
+
429
+ # Newton loop (single iteration in linear case)
430
+ for _k in range(ITERATIONS):
431
+ s = make_free_sample(element_index, coords)
432
+ x = elt_pos(args, s)
433
+ dc = elt_inv_grad(args, s) * (pos - x)
434
+ if wp.static(not assume_linear):
435
+ if wp.length_sq(dc) < _COORD_LOOKUP_EPS:
436
+ break
437
+ coords = coords + ref_elt.coord_delta(STEP * dc)
438
+
439
+ return coords
440
+
441
+ return element_coordinates
442
+
443
+ def _make_cell_coordinates(self, assume_linear: bool = False):
444
+ return self._make_element_coordinates(element_kind=ElementKind.CELL, assume_linear=assume_linear)
445
+
446
+ def _make_side_coordinates(self, assume_linear: bool = False):
447
+ return self._make_element_coordinates(element_kind=ElementKind.SIDE, assume_linear=assume_linear)
448
+
449
+ def _make_element_closest_point(self, element_kind: ElementKind, assume_linear: bool = False):
450
+ pos_type = cache.cached_vec_type(self.dimension, dtype=float)
451
+
452
+ element_coordinates = self._make_element_coordinates(element_kind=element_kind, assume_linear=assume_linear)
453
+
454
+ if element_kind == ElementKind.CELL:
455
+ ref_elt = self.reference_cell()
456
+ arg_type = self.CellArg
457
+ elt_pos = self.cell_position
458
+ elt_def_grad = self.cell_deformation_gradient
459
+ else:
460
+ ref_elt = self.reference_side()
461
+ arg_type = self.SideArg
462
+ elt_pos = self.side_position
463
+ elt_def_grad = self.side_deformation_gradient
464
+
465
+ @cache.dynamic_func(suffix=f"{self.name}{element_kind}{assume_linear}")
466
+ def cell_closest_point(args: arg_type, cell_index: ElementIndex, pos: pos_type):
467
+ # First get unconstrained coordinates, may use newton for this
468
+ coords = element_coordinates(args, cell_index, pos)
469
+
470
+ # Now do projected gradient
471
+ # For interior points should exit at first iteration
472
+ for _k in range(_COORD_LOOKUP_ITERATIONS):
473
+ cur_coords = coords
474
+ s = make_free_sample(cell_index, cur_coords)
475
+ x = elt_pos(args, s)
476
+
477
+ F = elt_def_grad(args, s)
478
+ F_scale = wp.ddot(F, F)
479
+
480
+ dc = (pos - x) @ F # gradient step
481
+ coords = ref_elt.project(cur_coords + ref_elt.coord_delta(dc / F_scale))
482
+
483
+ if wp.length_sq(coords - cur_coords) < _COORD_LOOKUP_EPS:
484
+ break
485
+
486
+ return cur_coords, wp.length_sq(pos - x)
487
+
488
+ return cell_closest_point
489
+
490
+ def _make_cell_closest_point(self, assume_linear: bool = False):
491
+ return self._make_element_closest_point(element_kind=ElementKind.CELL, assume_linear=assume_linear)
492
+
493
+ def _make_side_closest_point(self, assume_linear: bool = False):
494
+ return self._make_element_closest_point(element_kind=ElementKind.SIDE, assume_linear=assume_linear)
495
+
496
+ def make_filtered_cell_lookup(self, filter_func: wp.Function = None):
497
+ suffix = f"{self.name}{filter_func.key if filter_func is not None else ''}"
498
+ pos_type = cache.cached_vec_type(self.dimension, dtype=float)
499
+
500
+ @cache.dynamic_func(suffix=suffix)
501
+ def cell_lookup(args: self.CellArg, pos: pos_type, max_dist: float, filter_data: Any, filter_target: Any):
502
+ closest_cell = int(NULL_ELEMENT_INDEX)
503
+ closest_coords = Coords(OUTSIDE)
504
+
505
+ bvh_id = self.cell_bvh_id(args)
506
+ if bvh_id != _NULL_BVH_ID:
507
+ pad = wp.max(max_dist, 1.0) * _BVH_MIN_PADDING
508
+
509
+ # query with increasing bbox size until we find an element
510
+ # or reach the max distance bound
511
+ while closest_cell == NULL_ELEMENT_INDEX:
512
+ query = wp.bvh_query_aabb(bvh_id, _bvh_vec(pos) - wp.vec3(pad), _bvh_vec(pos) + wp.vec3(pad))
513
+ cell_index = int(0)
514
+ closest_dist = float(pad * pad)
515
+
516
+ while wp.bvh_query_next(query, cell_index):
517
+ if wp.static(filter_func is not None):
518
+ if filter_func(filter_data, cell_index) != filter_target:
519
+ continue
520
+
521
+ coords, dist = self.cell_closest_point(args, cell_index, pos)
522
+ if dist <= closest_dist:
523
+ closest_dist = dist
524
+ closest_cell = cell_index
525
+ closest_coords = coords
526
+
527
+ if pad >= _BVH_MAX_PADDING:
528
+ break
529
+ pad = wp.min(4.0 * pad, _BVH_MAX_PADDING)
530
+
531
+ return make_free_sample(closest_cell, closest_coords)
532
+
533
+ return cell_lookup
534
+
535
+ @cached_property
536
+ def cell_lookup(self) -> wp.Function:
537
+ unfiltered_cell_lookup = self.make_filtered_cell_lookup(filter_func=None)
538
+
539
+ # overloads
540
+ null_filter_data = 0
541
+ null_filter_target = 0
542
+
543
+ pos_type = cache.cached_vec_type(self.dimension, dtype=float)
544
+
545
+ @cache.dynamic_func(suffix=self.name)
546
+ def cell_lookup(args: self.CellArg, pos: pos_type, max_dist: float):
547
+ return unfiltered_cell_lookup(args, pos, max_dist, null_filter_data, null_filter_target)
548
+
549
+ @cache.dynamic_func(suffix=self.name)
550
+ def cell_lookup(args: self.CellArg, pos: pos_type, guess: Sample):
551
+ guess_pos = self.cell_position(args, guess)
552
+ max_dist = wp.length(guess_pos - pos)
553
+ return unfiltered_cell_lookup(args, pos, max_dist, null_filter_data, null_filter_target)
554
+
555
+ @cache.dynamic_func(suffix=self.name)
556
+ def cell_lookup(args: self.CellArg, pos: pos_type):
557
+ max_dist = 0.0
558
+ return unfiltered_cell_lookup(args, pos, max_dist, null_filter_data, null_filter_target)
559
+
560
+ # array filtering variants
561
+ filtered_cell_lookup = self.make_filtered_cell_lookup(filter_func=_array_load)
562
+ pos_type = cache.cached_vec_type(self.dimension, dtype=float)
563
+
564
+ @cache.dynamic_func(suffix=self.name)
565
+ def cell_lookup(
566
+ args: self.CellArg, pos: pos_type, max_dist: float, filter_array: wp.array(dtype=Any), filter_target: Any
567
+ ):
568
+ return filtered_cell_lookup(args, pos, max_dist, filter_array, filter_target)
569
+
570
+ @cache.dynamic_func(suffix=self.name)
571
+ def cell_lookup(args: self.CellArg, pos: pos_type, filter_array: wp.array(dtype=Any), filter_target: Any):
572
+ max_dist = 0.0
573
+ return filtered_cell_lookup(args, pos, max_dist, filter_array, filter_target)
574
+
575
+ return cell_lookup
576
+
577
+ def _make_compute_cell_bounds(self):
578
+ @cache.dynamic_kernel(suffix=self.name)
579
+ def compute_cell_bounds(
580
+ args: self.CellArg,
581
+ lowers: wp.array(dtype=wp.vec3),
582
+ uppers: wp.array(dtype=wp.vec3),
583
+ ):
584
+ i = wp.tid()
585
+ lo, up = self.cell_bounds(args, i)
586
+ lowers[i] = _bvh_vec(lo)
587
+ uppers[i] = _bvh_vec(up)
588
+
589
+ return compute_cell_bounds
590
+
591
+ def supports_cell_lookup(self, device) -> bool:
592
+ return self.bvh_id(device) != _NULL_BVH_ID
593
+
594
+ def update_bvh(self, device=None):
595
+ """
596
+ Refits the BVH for the given device, or builds it from scratch if no BVH exists for that device yet.
597
+ """
598
+
599
+ if self._bvhs is None:
600
+ return self.build_bvh(device)
601
+
602
+ device = wp.get_device(device)
603
+ bvh = self._bvhs.get(device.ordinal)
604
+ if bvh is None:
605
+ return self.build_bvh(device)
606
+
607
+ wp.launch(
608
+ self.compute_cell_bounds,
609
+ dim=self.cell_count(),
610
+ device=device,
611
+ inputs=[self.cell_arg_value(device=device)],
612
+ outputs=[
613
+ bvh.lowers,
614
+ bvh.uppers,
615
+ ],
616
+ )
617
+
618
+ bvh.refit()
619
+
620
+ def build_bvh(self, device=None):
621
+ device = wp.get_device(device)
622
+
623
+ lowers = wp.array(shape=self.cell_count(), dtype=wp.vec3, device=device)
624
+ uppers = wp.array(shape=self.cell_count(), dtype=wp.vec3, device=device)
625
+
626
+ wp.launch(
627
+ self.compute_cell_bounds,
628
+ dim=self.cell_count(),
629
+ device=device,
630
+ inputs=[self.cell_arg_value(device=device)],
631
+ outputs=[
632
+ lowers,
633
+ uppers,
634
+ ],
635
+ )
636
+
637
+ if self._bvhs is None:
638
+ self._bvhs = {}
639
+ self._bvhs[device.ordinal] = wp.Bvh(lowers, uppers)
640
+
641
+ def bvh_id(self, device):
642
+ if self._bvhs is None:
643
+ return _NULL_BVH_ID
644
+
645
+ bvh = self._bvhs.get(wp.get_device(device).ordinal)
646
+ if bvh is None:
647
+ return _NULL_BVH_ID
648
+ return bvh.id
649
+
650
+
651
+ @wp.func
652
+ def _bvh_vec(v: wp.vec3):
653
+ return v
654
+
655
+
656
+ @wp.func
657
+ def _bvh_vec(v: wp.vec2):
658
+ return wp.vec3(v[0], v[1], 0.0)
659
+
660
+
661
+ @wp.func
662
+ def _array_load(arr: wp.array(dtype=Any), idx: int):
663
+ return arr[idx]