warp-lang 1.7.2__py3-none-manylinux_2_34_aarch64.whl → 1.8.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (180)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +125 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +257 -101
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +657 -223
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  35. warp/examples/optim/example_drone.py +1 -1
  36. warp/examples/sim/example_cloth.py +1 -1
  37. warp/examples/sim/example_cloth_self_contact.py +48 -54
  38. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  39. warp/examples/tile/example_tile_cholesky.py +2 -1
  40. warp/examples/tile/example_tile_convolution.py +1 -1
  41. warp/examples/tile/example_tile_filtering.py +1 -1
  42. warp/examples/tile/example_tile_matmul.py +1 -1
  43. warp/examples/tile/example_tile_mlp.py +2 -0
  44. warp/fabric.py +7 -7
  45. warp/fem/__init__.py +5 -0
  46. warp/fem/adaptivity.py +1 -1
  47. warp/fem/cache.py +152 -63
  48. warp/fem/dirichlet.py +2 -2
  49. warp/fem/domain.py +136 -6
  50. warp/fem/field/field.py +141 -99
  51. warp/fem/field/nodal_field.py +85 -39
  52. warp/fem/field/virtual.py +97 -52
  53. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  54. warp/fem/geometry/closest_point.py +13 -0
  55. warp/fem/geometry/deformed_geometry.py +102 -40
  56. warp/fem/geometry/element.py +56 -2
  57. warp/fem/geometry/geometry.py +323 -22
  58. warp/fem/geometry/grid_2d.py +157 -62
  59. warp/fem/geometry/grid_3d.py +116 -20
  60. warp/fem/geometry/hexmesh.py +86 -20
  61. warp/fem/geometry/nanogrid.py +166 -86
  62. warp/fem/geometry/partition.py +59 -25
  63. warp/fem/geometry/quadmesh.py +86 -135
  64. warp/fem/geometry/tetmesh.py +47 -119
  65. warp/fem/geometry/trimesh.py +77 -270
  66. warp/fem/integrate.py +107 -52
  67. warp/fem/linalg.py +25 -58
  68. warp/fem/operator.py +124 -27
  69. warp/fem/quadrature/pic_quadrature.py +36 -14
  70. warp/fem/quadrature/quadrature.py +40 -16
  71. warp/fem/space/__init__.py +1 -1
  72. warp/fem/space/basis_function_space.py +66 -46
  73. warp/fem/space/basis_space.py +17 -4
  74. warp/fem/space/dof_mapper.py +1 -1
  75. warp/fem/space/function_space.py +2 -2
  76. warp/fem/space/grid_2d_function_space.py +4 -1
  77. warp/fem/space/hexmesh_function_space.py +4 -2
  78. warp/fem/space/nanogrid_function_space.py +3 -1
  79. warp/fem/space/partition.py +11 -2
  80. warp/fem/space/quadmesh_function_space.py +4 -1
  81. warp/fem/space/restriction.py +5 -2
  82. warp/fem/space/shape/__init__.py +10 -8
  83. warp/fem/space/tetmesh_function_space.py +4 -1
  84. warp/fem/space/topology.py +52 -21
  85. warp/fem/space/trimesh_function_space.py +4 -1
  86. warp/fem/utils.py +53 -8
  87. warp/jax.py +1 -2
  88. warp/jax_experimental/ffi.py +12 -17
  89. warp/jax_experimental/xla_ffi.py +37 -24
  90. warp/math.py +171 -1
  91. warp/native/array.h +99 -0
  92. warp/native/builtin.h +174 -31
  93. warp/native/coloring.cpp +1 -1
  94. warp/native/exports.h +118 -63
  95. warp/native/intersect.h +3 -3
  96. warp/native/mat.h +5 -10
  97. warp/native/mathdx.cpp +11 -5
  98. warp/native/matnn.h +1 -123
  99. warp/native/quat.h +28 -4
  100. warp/native/sparse.cpp +121 -258
  101. warp/native/sparse.cu +181 -274
  102. warp/native/spatial.h +305 -17
  103. warp/native/tile.h +583 -72
  104. warp/native/tile_radix_sort.h +1108 -0
  105. warp/native/tile_reduce.h +237 -2
  106. warp/native/tile_scan.h +240 -0
  107. warp/native/tuple.h +189 -0
  108. warp/native/vec.h +6 -16
  109. warp/native/warp.cpp +36 -4
  110. warp/native/warp.cu +574 -51
  111. warp/native/warp.h +47 -74
  112. warp/optim/linear.py +5 -1
  113. warp/paddle.py +7 -8
  114. warp/py.typed +0 -0
  115. warp/render/render_opengl.py +58 -29
  116. warp/render/render_usd.py +124 -61
  117. warp/sim/__init__.py +9 -0
  118. warp/sim/collide.py +252 -78
  119. warp/sim/graph_coloring.py +8 -1
  120. warp/sim/import_mjcf.py +4 -3
  121. warp/sim/import_usd.py +11 -7
  122. warp/sim/integrator.py +5 -2
  123. warp/sim/integrator_euler.py +1 -1
  124. warp/sim/integrator_featherstone.py +1 -1
  125. warp/sim/integrator_vbd.py +751 -320
  126. warp/sim/integrator_xpbd.py +1 -1
  127. warp/sim/model.py +265 -260
  128. warp/sim/utils.py +10 -7
  129. warp/sparse.py +303 -166
  130. warp/tape.py +52 -51
  131. warp/tests/cuda/test_conditional_captures.py +1046 -0
  132. warp/tests/cuda/test_streams.py +1 -1
  133. warp/tests/geometry/test_volume.py +2 -2
  134. warp/tests/interop/test_dlpack.py +9 -9
  135. warp/tests/interop/test_jax.py +0 -1
  136. warp/tests/run_coverage_serial.py +1 -1
  137. warp/tests/sim/disabled_kinematics.py +2 -2
  138. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  139. warp/tests/sim/test_collision.py +159 -51
  140. warp/tests/sim/test_coloring.py +15 -1
  141. warp/tests/test_array.py +254 -2
  142. warp/tests/test_array_reduce.py +2 -2
  143. warp/tests/test_atomic_cas.py +299 -0
  144. warp/tests/test_codegen.py +142 -19
  145. warp/tests/test_conditional.py +47 -1
  146. warp/tests/test_ctypes.py +0 -20
  147. warp/tests/test_devices.py +8 -0
  148. warp/tests/test_fabricarray.py +4 -2
  149. warp/tests/test_fem.py +58 -25
  150. warp/tests/test_func.py +42 -1
  151. warp/tests/test_grad.py +1 -1
  152. warp/tests/test_lerp.py +1 -3
  153. warp/tests/test_map.py +481 -0
  154. warp/tests/test_mat.py +1 -24
  155. warp/tests/test_quat.py +6 -15
  156. warp/tests/test_rounding.py +10 -38
  157. warp/tests/test_runlength_encode.py +7 -7
  158. warp/tests/test_smoothstep.py +1 -1
  159. warp/tests/test_sparse.py +51 -2
  160. warp/tests/test_spatial.py +507 -1
  161. warp/tests/test_struct.py +2 -2
  162. warp/tests/test_tuple.py +265 -0
  163. warp/tests/test_types.py +2 -2
  164. warp/tests/test_utils.py +24 -18
  165. warp/tests/tile/test_tile.py +420 -1
  166. warp/tests/tile/test_tile_mathdx.py +518 -14
  167. warp/tests/tile/test_tile_reduce.py +213 -0
  168. warp/tests/tile/test_tile_shared_memory.py +130 -1
  169. warp/tests/tile/test_tile_sort.py +117 -0
  170. warp/tests/unittest_suites.py +4 -6
  171. warp/types.py +462 -308
  172. warp/utils.py +647 -86
  173. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +177 -165
  175. warp/stubs.py +0 -3381
  176. warp/tests/sim/test_xpbd.py +0 -399
  177. warp/tests/test_mlp.py +0 -282
  178. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
@@ -33,9 +33,8 @@ class HexmeshCellArg:
33
33
  hex_vertex_indices: wp.array2d(dtype=int)
34
34
  positions: wp.array(dtype=wp.vec3)
35
35
 
36
- # for neighbor cell lookup
37
- vertex_hex_offsets: wp.array(dtype=int)
38
- vertex_hex_indices: wp.array(dtype=int)
36
+ # for global cell lookup
37
+ hex_bvh: wp.uint64
39
38
 
40
39
 
41
40
  @wp.struct
@@ -131,20 +130,28 @@ class Hexmesh(Geometry):
131
130
  dimension = 3
132
131
 
133
132
  def __init__(
134
- self, hex_vertex_indices: wp.array, positions: wp.array, temporary_store: Optional[TemporaryStore] = None
133
+ self,
134
+ hex_vertex_indices: wp.array,
135
+ positions: wp.array,
136
+ assume_parallelepiped_cells=False,
137
+ build_bvh: bool = False,
138
+ temporary_store: Optional[TemporaryStore] = None,
135
139
  ):
136
140
  """
137
- Constructs a tetrahedral mesh.
141
+ Constructs a hexahedral mesh.
138
142
 
139
143
  Args:
140
144
  hex_vertex_indices: warp array of shape (num_hexes, 8) containing vertex indices for each hex
141
145
  following standard ordering (bottom face vertices in counter-clockwise order, then similarly for upper face)
142
146
  positions: warp array of shape (num_vertices, 3) containing 3d position for each vertex
147
+ assume_parallelepiped_cells: If true, assume that all cells are parallelepipeds (cheaper position/gradient evaluations)
148
+ build_bvh: Whether to also build the hex BVH, which is necessary for the global `fem.lookup` operator
143
149
  temporary_store: shared pool from which to allocate temporary arrays
144
150
  """
145
151
 
146
152
  self.hex_vertex_indices = hex_vertex_indices
147
153
  self.positions = positions
154
+ self.parallelepiped_cells = assume_parallelepiped_cells
148
155
 
149
156
  self._face_vertex_indices: wp.array = None
150
157
  self._face_hex_indices: wp.array = None
@@ -155,7 +162,25 @@ class Hexmesh(Geometry):
155
162
  self._edge_count = 0
156
163
  self._build_topology(temporary_store)
157
164
 
165
+ # Use cheaper variants if we know that cells are parallelepipeds (i.e. linearly transformed)
166
+ # (Cells only, not as much difference for sides)
167
+ self.cell_position = (
168
+ self._cell_position_parallelepiped if assume_parallelepiped_cells else self._cell_position_generic
169
+ )
170
+ self.cell_deformation_gradient = (
171
+ self._cell_deformation_gradient_parallelepiped
172
+ if assume_parallelepiped_cells
173
+ else self._cell_deformation_gradient_generic
174
+ )
175
+
158
176
  self._make_default_dependent_implementations()
177
+ self.cell_coordinates = self._make_cell_coordinates(assume_linear=assume_parallelepiped_cells)
178
+ self.side_coordinates = self._make_side_coordinates(assume_linear=assume_parallelepiped_cells)
179
+ self.cell_closest_point = self._make_cell_closest_point(assume_linear=assume_parallelepiped_cells)
180
+ self.side_closest_point = self._make_side_closest_point(assume_linear=assume_parallelepiped_cells)
181
+
182
+ if build_bvh:
183
+ self.build_bvh(self.positions.device)
159
184
 
160
185
  def cell_count(self):
161
186
  return self.hex_vertex_indices.shape[0]
@@ -203,19 +228,18 @@ class Hexmesh(Geometry):
203
228
 
204
229
  # Geometry device interface
205
230
 
206
- @cached_arg_value
207
231
  def cell_arg_value(self, device) -> CellArg:
208
232
  args = self.CellArg()
233
+ self.fill_cell_arg(args, device)
234
+ return args
209
235
 
236
+ def fill_cell_arg(self, args: CellArg, device):
210
237
  args.hex_vertex_indices = self.hex_vertex_indices.to(device)
211
238
  args.positions = self.positions.to(device)
212
- args.vertex_hex_offsets = self._vertex_hex_offsets.to(device)
213
- args.vertex_hex_indices = self._vertex_hex_indices.to(device)
214
-
215
- return args
239
+ args.hex_bvh = self.bvh_id(device)
216
240
 
217
241
  @wp.func
218
- def cell_position(args: CellArg, s: Sample):
242
+ def _cell_position_generic(args: CellArg, s: Sample):
219
243
  hex_idx = args.hex_vertex_indices[s.element_index]
220
244
 
221
245
  w_p = s.element_coords
@@ -242,9 +266,18 @@ class Hexmesh(Geometry):
242
266
  )
243
267
 
244
268
  @wp.func
245
- def cell_deformation_gradient(cell_arg: CellArg, s: Sample):
269
+ def _cell_position_parallelepiped(args: CellArg, s: Sample):
270
+ hex_idx = args.hex_vertex_indices[s.element_index]
271
+ w = s.element_coords
272
+ p0 = args.positions[hex_idx[0]]
273
+ p1 = args.positions[hex_idx[1]]
274
+ p2 = args.positions[hex_idx[3]]
275
+ p3 = args.positions[hex_idx[4]]
276
+ return w[0] * p1 + w[1] * p2 + w[2] * p3 + (1.0 - w[0] - w[1] - w[2]) * p0
277
+
278
+ @wp.func
279
+ def _cell_deformation_gradient_generic(cell_arg: CellArg, s: Sample):
246
280
  """Deformation gradient at `coords`"""
247
- """Transposed deformation gradient at `coords`"""
248
281
  hex_idx = cell_arg.hex_vertex_indices[s.element_index]
249
282
 
250
283
  w_p = s.element_coords
@@ -261,31 +294,43 @@ class Hexmesh(Geometry):
261
294
  + wp.outer(cell_arg.positions[hex_idx[7]], wp.vec3(-w_p[1] * w_p[2], w_m[0] * w_p[2], w_m[0] * w_p[1]))
262
295
  )
263
296
 
297
+ @wp.func
298
+ def _cell_deformation_gradient_parallelepiped(cell_arg: CellArg, s: Sample):
299
+ """Deformation gradient at `coords`"""
300
+ hex_idx = cell_arg.hex_vertex_indices[s.element_index]
301
+
302
+ p0 = cell_arg.positions[hex_idx[0]]
303
+ p1 = cell_arg.positions[hex_idx[1]]
304
+ p2 = cell_arg.positions[hex_idx[3]]
305
+ p3 = cell_arg.positions[hex_idx[4]]
306
+ return wp.matrix_from_cols(p1 - p0, p2 - p0, p3 - p0)
307
+
264
308
  @cached_arg_value
265
309
  def side_index_arg_value(self, device) -> SideIndexArg:
266
310
  args = self.SideIndexArg()
311
+ self.fill_side_index_arg(args, device)
312
+ return args
267
313
 
314
+ def fill_side_index_arg(self, args: SideIndexArg, device):
268
315
  args.boundary_face_indices = self._boundary_face_indices.to(device)
269
316
 
270
- return args
271
-
272
317
  @wp.func
273
318
  def boundary_side_index(args: SideIndexArg, boundary_side_index: int):
274
319
  """Boundary side to side index"""
275
320
 
276
321
  return args.boundary_face_indices[boundary_side_index]
277
322
 
278
- @cached_arg_value
279
323
  def side_arg_value(self, device) -> CellArg:
280
324
  args = self.SideArg()
325
+ self.fill_side_arg(args, device)
326
+ return args
281
327
 
282
- args.cell_arg = self.cell_arg_value(device)
328
+ def fill_side_arg(self, args: SideArg, device):
329
+ self.fill_cell_arg(args.cell_arg, device)
283
330
  args.face_vertex_indices = self._face_vertex_indices.to(device)
284
331
  args.face_hex_indices = self._face_hex_indices.to(device)
285
332
  args.face_hex_face_orientation = self._face_hex_face_orientation.to(device)
286
333
 
287
- return args
288
-
289
334
  @wp.func
290
335
  def side_position(args: SideArg, s: Sample):
291
336
  face_idx = args.face_vertex_indices[s.element_index]
@@ -332,7 +377,7 @@ class Hexmesh(Geometry):
332
377
 
333
378
  @wp.func
334
379
  def _hex_local_face_coords(hex_coords: Coords, face_index: int):
335
- # Coordinatex in local face coordinates system
380
+ # Coordinates in local face coordinates system
336
381
  # Sign of last coordinate (out of face)
337
382
 
338
383
  face_coords = wp.vec2(
@@ -909,3 +954,24 @@ class Hexmesh(Geometry):
909
954
  + unique_beg
910
955
  )
911
956
  hex_edge_indices[t][k] = edge_id
957
+
958
+ @wp.func
959
+ def cell_bvh_id(cell_arg: HexmeshCellArg):
960
+ return cell_arg.hex_bvh
961
+
962
+ @wp.func
963
+ def cell_bounds(cell_arg: HexmeshCellArg, cell_index: ElementIndex):
964
+ vidx = cell_arg.hex_vertex_indices[cell_index]
965
+ p0 = cell_arg.positions[vidx[0]]
966
+ p1 = cell_arg.positions[vidx[1]]
967
+ p2 = cell_arg.positions[vidx[2]]
968
+ p3 = cell_arg.positions[vidx[3]]
969
+ lo0, up0 = wp.min(wp.min(p0, p1), wp.min(p2, p3)), wp.max(wp.max(p0, p1), wp.max(p2, p3))
970
+
971
+ p4 = cell_arg.positions[vidx[4]]
972
+ p5 = cell_arg.positions[vidx[5]]
973
+ p6 = cell_arg.positions[vidx[6]]
974
+ p7 = cell_arg.positions[vidx[7]]
975
+ lo1, up1 = wp.min(wp.min(p4, p5), wp.min(p6, p7)), wp.max(wp.max(p4, p5), wp.max(p6, p7))
976
+
977
+ return wp.min(lo0, lo1), wp.max(up0, up1)
@@ -13,13 +13,12 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from typing import Optional
16
+ from typing import Any, Optional
17
17
 
18
18
  import numpy as np
19
19
 
20
20
  import warp as wp
21
21
  from warp.fem import cache, utils
22
- from warp.fem.linalg import basis_element
23
22
  from warp.fem.types import NULL_ELEMENT_INDEX, OUTSIDE, Coords, ElementIndex, Sample, make_free_sample
24
23
 
25
24
  from .element import Cube, Square
@@ -111,6 +110,13 @@ class Nanogrid(Geometry):
111
110
  self._edge_grid = None
112
111
  self._edge_count = 0
113
112
 
113
+ transform = np.array(self._cell_grid_info.transform_matrix).reshape(3, 3)
114
+ self._inverse_transform = wp.mat33f(np.linalg.inv(transform))
115
+ self._cell_volume = abs(np.linalg.det(transform))
116
+ self._face_areas = wp.vec3(
117
+ tuple(np.linalg.norm(np.cross(transform[:, k - 2], transform[:, k - 1])) for k in range(3))
118
+ )
119
+
114
120
  @property
115
121
  def cell_grid(self) -> wp.Volume:
116
122
  return self._cell_grid
@@ -158,14 +164,15 @@ class Nanogrid(Geometry):
158
164
  @cache.cached_arg_value
159
165
  def cell_arg_value(self, device) -> CellArg:
160
166
  args = self.CellArg()
161
- args.cell_grid = self._cell_grid.id
162
- args.cell_ijk = self._cell_ijk
167
+ self.fill_cell_arg(args, device)
168
+ return args
163
169
 
164
- transform = np.array(self._cell_grid_info.transform_matrix).reshape(3, 3)
165
- args.inverse_transform = wp.mat33f(np.linalg.inv(transform))
166
- args.cell_volume = abs(np.linalg.det(transform))
170
+ def fill_cell_arg(self, arg, device):
171
+ arg.cell_grid = self._cell_grid.id
172
+ arg.cell_ijk = self._cell_ijk
167
173
 
168
- return args
174
+ arg.inverse_transform = self._inverse_transform
175
+ arg.cell_volume = self._cell_volume
169
176
 
170
177
  @wp.func
171
178
  def cell_position(args: CellArg, s: Sample):
@@ -180,66 +187,106 @@ class Nanogrid(Geometry):
180
187
  def cell_inverse_deformation_gradient(args: CellArg, s: Sample):
181
188
  return args.inverse_transform
182
189
 
183
- @wp.func
184
- def cell_lookup(args: CellArg, pos: wp.vec3):
185
- uvw = wp.volume_world_to_index(args.cell_grid, pos) + wp.vec3(0.5)
186
- ijk = wp.vec3i(int(wp.floor(uvw[0])), int(wp.floor(uvw[1])), int(wp.floor(uvw[2])))
187
- cell_index = wp.volume_lookup_index(args.cell_grid, ijk[0], ijk[1], ijk[2])
188
-
189
- coords = uvw - wp.vec3(ijk)
190
- if cell_index == -1:
191
- if wp.min(coords) == 0.0 or wp.max(coords) == 1.0:
192
- il = wp.where(coords[0] > 0.5, 0, -1)
193
- jl = wp.where(coords[1] > 0.5, 0, -1)
194
- kl = wp.where(coords[2] > 0.5, 0, -1)
195
-
196
- for n in range(8):
197
- ni = n >> 2
198
- nj = (n & 2) >> 1
199
- nk = n & 1
200
- nijk = ijk + wp.vec3i(ni + il, nj + jl, nk + kl)
201
-
202
- coords = uvw - wp.vec3(nijk)
203
- if wp.min(coords) >= 0.0 and wp.max(coords) <= 1.0:
204
- cell_index = wp.volume_lookup_index(args.cell_grid, nijk[0], nijk[1], nijk[2])
205
- if cell_index != -1:
206
- return make_free_sample(cell_index, coords)
207
-
208
- return make_free_sample(NULL_ELEMENT_INDEX, Coords(OUTSIDE))
209
-
210
- return make_free_sample(cell_index, coords)
190
+ def supports_cell_lookup(self, device):
191
+ return True
211
192
 
212
193
  @wp.func
213
- def _project_on_voxel_at_origin(coords: wp.vec3):
214
- proj_coords = wp.min(wp.max(coords, wp.vec3(0.0)), wp.vec3(1.0))
215
- return wp.length_sq(coords - proj_coords), proj_coords
194
+ def _lookup_cell_index(args: NanogridCellArg, i: int, j: int, k: int):
195
+ return wp.volume_lookup_index(args.cell_grid, i, j, k)
216
196
 
217
197
  @wp.func
218
- def cell_lookup(args: CellArg, pos: wp.vec3, guess: Sample):
219
- s_global = Nanogrid.cell_lookup(args, pos)
220
-
221
- if s_global.element_index != NULL_ELEMENT_INDEX:
222
- return s_global
198
+ def _cell_coordinates_local(args: NanogridCellArg, cell_index: int, uvw: wp.vec3):
199
+ ijk = wp.vec3(args.cell_ijk[cell_index])
200
+ rel_pos = uvw - ijk
201
+ return rel_pos
223
202
 
224
- closest_voxel = int(NULL_ELEMENT_INDEX)
225
- closest_coords = Coords(OUTSIDE)
226
- closest_dist = float(1.0e8)
203
+ @wp.func
204
+ def _cell_closest_point_local(args: NanogridCellArg, cell_index: int, uvw: wp.vec3):
205
+ ijk = wp.vec3(args.cell_ijk[cell_index])
206
+ rel_pos = uvw - ijk
207
+ coords = wp.min(wp.max(rel_pos, wp.vec3(0.0)), wp.vec3(1.0))
208
+ return wp.length_sq(wp.volume_index_to_world_dir(args.cell_grid, coords - rel_pos)), coords
227
209
 
228
- # project to closest in stencil
210
+ @wp.func
211
+ def cell_coordinates(args: NanogridCellArg, cell_index: int, pos: wp.vec3):
229
212
  uvw = wp.volume_world_to_index(args.cell_grid, pos) + wp.vec3(0.5)
230
- cell_ijk = args.cell_ijk[guess.element_index]
231
- for ni in range(-1, 2):
232
- for nj in range(-1, 2):
233
- for nk in range(-1, 2):
234
- nijk = cell_ijk + wp.vec3i(ni, nj, nk)
235
- cell_idx = wp.volume_lookup_index(args.cell_grid, nijk[0], nijk[1], nijk[2])
236
- dist, coords = Nanogrid._project_on_voxel_at_origin(uvw - wp.vec3(nijk))
237
- if cell_idx != -1 and dist <= closest_dist:
238
- closest_dist = dist
239
- closest_voxel = cell_idx
240
- closest_coords = coords
213
+ return Nanogrid._cell_coordinates_local(args, cell_index, uvw)
241
214
 
242
- return make_free_sample(closest_voxel, closest_coords)
215
+ @wp.func
216
+ def cell_closest_point(args: NanogridCellArg, cell_index: int, pos: wp.vec3):
217
+ uvw = wp.volume_world_to_index(args.cell_grid, pos) + wp.vec3(0.5)
218
+ dist, coords = Nanogrid._cell_closest_point_local(args, cell_index, uvw)
219
+ return coords, dist
220
+
221
+ @staticmethod
222
+ def _make_filtered_cell_lookup(grid_geo, filter_func: wp.Function = None):
223
+ suffix = f"{grid_geo.name}{filter_func.key if filter_func is not None else ''}"
224
+
225
+ @cache.dynamic_func(suffix=suffix)
226
+ def cell_lookup(args: grid_geo.CellArg, pos: wp.vec3, max_dist: float, filter_data: Any, filter_target: Any):
227
+ grid = args.cell_grid
228
+
229
+ # Start at corresponding voxel
230
+ uvw = wp.volume_world_to_index(grid, pos) + wp.vec3(0.5)
231
+ i, j, k = int(wp.floor(uvw[0])), int(wp.floor(uvw[1])), int(wp.floor(uvw[2]))
232
+ cell_index = grid_geo._lookup_cell_index(args, i, j, k)
233
+
234
+ if cell_index != -1:
235
+ coords = grid_geo._cell_coordinates_local(args, cell_index, uvw)
236
+ if wp.static(filter_func is None):
237
+ return make_free_sample(cell_index, coords)
238
+ else:
239
+ if filter_func(filter_data, cell_index) == filter_target:
240
+ return make_free_sample(cell_index, coords)
241
+
242
+ # Iterate over increasingly larger neighborhoods
243
+ cell_size = wp.vec3(
244
+ wp.length(wp.volume_index_to_world_dir(grid, wp.vec3(1.0, 0.0, 0.0))),
245
+ wp.length(wp.volume_index_to_world_dir(grid, wp.vec3(0.0, 1.0, 0.0))),
246
+ wp.length(wp.volume_index_to_world_dir(grid, wp.vec3(0.0, 0.0, 1.0))),
247
+ )
248
+
249
+ offset = float(0.5)
250
+ min_cell_size = wp.min(cell_size)
251
+ max_offset = wp.ceil(max_dist / min_cell_size)
252
+ scales = wp.cw_div(wp.vec3(min_cell_size), wp.vec3(cell_size))
253
+
254
+ closest_cell = NULL_ELEMENT_INDEX
255
+ closest_coords = Coords()
256
+
257
+ while closest_cell == NULL_ELEMENT_INDEX:
258
+ uvw_min = wp.vec3i(uvw - offset * scales)
259
+ uvw_max = wp.vec3i(uvw + offset * scales) + wp.vec3i(1)
260
+
261
+ closest_dist = min_cell_size * min_cell_size * float(offset * offset)
262
+
263
+ for i in range(uvw_min[0], uvw_max[0]):
264
+ for j in range(uvw_min[1], uvw_max[1]):
265
+ for k in range(uvw_min[2], uvw_max[2]):
266
+ cell_index = grid_geo._lookup_cell_index(args, i, j, k)
267
+ if cell_index == -1:
268
+ continue
269
+
270
+ if wp.static(filter_func is not None):
271
+ if filter_func(filter_data, cell_index) != filter_target:
272
+ continue
273
+ dist, coords = grid_geo._cell_closest_point_local(args, cell_index, uvw)
274
+
275
+ if dist <= closest_dist:
276
+ closest_dist = dist
277
+ closest_coords = coords
278
+ closest_cell = cell_index
279
+
280
+ if offset >= max_offset:
281
+ break
282
+ offset = wp.min(3.0 * offset, max_offset)
283
+
284
+ return make_free_sample(closest_cell, closest_coords)
285
+
286
+ return cell_lookup
287
+
288
+ def make_filtered_cell_lookup(self, filter_func):
289
+ return Nanogrid._make_filtered_cell_lookup(self, filter_func)
243
290
 
244
291
  @wp.func
245
292
  def cell_measure(args: CellArg, s: Sample):
@@ -253,43 +300,47 @@ class Nanogrid(Geometry):
253
300
 
254
301
  @cache.cached_arg_value
255
302
  def side_arg_value(self, device) -> SideArg:
256
- self._ensure_face_grid()
257
-
258
303
  args = self.SideArg()
259
- args.cell_arg = self.cell_arg_value(device)
260
- args.face_ijk = self._face_ijk.to(device)
261
- args.face_flags = self._face_flags.to(device)
262
- transform = np.array(self._cell_grid_info.transform_matrix).reshape(3, 3)
263
- args.face_areas = wp.vec3(
264
- tuple(np.linalg.norm(np.cross(transform[:, k - 2], transform[:, k - 1])) for k in range(3))
265
- )
266
-
304
+ self.fill_side_arg(args, device)
267
305
  return args
268
306
 
307
+ def fill_side_arg(self, arg: SideArg, device):
308
+ self._ensure_face_grid()
309
+ self.fill_cell_arg(arg.cell_arg, device)
310
+ arg.face_ijk = self._face_ijk.to(device)
311
+ arg.face_flags = self._face_flags.to(device)
312
+ arg.face_areas = self._face_areas
313
+
269
314
  @wp.struct
270
315
  class SideIndexArg:
271
316
  boundary_face_indices: wp.array(dtype=int)
272
317
 
273
318
  @cache.cached_arg_value
274
319
  def side_index_arg_value(self, device) -> SideIndexArg:
275
- self._ensure_face_grid()
276
-
277
320
  args = self.SideIndexArg()
278
- args.boundary_face_indices = self._boundary_face_indices.to(device)
321
+ self.fill_side_index_arg(args, device)
279
322
  return args
280
323
 
324
+ def fill_side_index_arg(self, arg: SideIndexArg, device):
325
+ self._ensure_face_grid()
326
+ arg.boundary_face_indices = self._boundary_face_indices.to(device)
327
+
281
328
  @wp.func
282
329
  def boundary_side_index(args: SideIndexArg, boundary_side_index: int):
283
330
  return args.boundary_face_indices[boundary_side_index]
284
331
 
285
332
  @wp.func
286
- def _side_to_cell_coords(axis: int, inner: float, side_coords: Coords):
333
+ def _side_to_cell_coords(axis: int, flip: int, inner: float, side_coords: Coords):
287
334
  uvw = wp.vec3()
288
335
  uvw[axis] = inner
289
- uvw[(axis + 1) % 3] = side_coords[0]
290
- uvw[(axis + 2) % 3] = side_coords[1]
336
+ uvw[(axis + 1 + flip) % 3] = side_coords[0]
337
+ uvw[(axis + 2 - flip) % 3] = side_coords[1]
291
338
  return uvw
292
339
 
340
+ @wp.func
341
+ def _cell_to_side_coords(axis: int, flip: int, cell_coords: Coords):
342
+ return Coords(cell_coords[(axis + 1 + flip) % 3], cell_coords[(axis + 2 - flip) % 3], 0.0)
343
+
293
344
  @wp.func
294
345
  def _get_face_axis(flags: wp.uint8):
295
346
  return wp.int32(flags & FACE_AXIS_MASK)
@@ -305,18 +356,21 @@ class Nanogrid(Geometry):
305
356
  @wp.func
306
357
  def side_position(args: SideArg, s: Sample):
307
358
  ijk = args.face_ijk[s.element_index]
308
- axis = Nanogrid._get_face_axis(args.face_flags[s.element_index])
359
+ flags = args.face_flags[s.element_index]
360
+ axis = Nanogrid._get_face_axis(flags)
361
+ flip = Nanogrid._get_face_inner_offset(flags)
309
362
 
310
- uvw = wp.vec3(ijk) + Nanogrid._side_to_cell_coords(axis, 0.0, s.element_coords)
363
+ uvw = wp.vec3(ijk) + Nanogrid._side_to_cell_coords(axis, flip, 0.0, s.element_coords)
311
364
 
312
365
  cell_grid = args.cell_arg.cell_grid
313
366
  return wp.volume_index_to_world(cell_grid, uvw - wp.vec3(0.5))
314
367
 
315
368
  @wp.func
316
369
  def _face_tangent_vecs(cell_grid: wp.uint64, axis: int, flip: int):
317
- u_axis = basis_element(wp.vec3(), (axis + 1 + flip) % 3)
318
- v_axis = basis_element(wp.vec3(), (axis + 2 - flip) % 3)
319
-
370
+ u_axis = wp.vec3()
371
+ v_axis = wp.vec3()
372
+ u_axis[(axis + 1 + flip) % 3] = 1.0
373
+ v_axis[(axis + 2 - flip) % 3] = 1.0
320
374
  return wp.volume_index_to_world_dir(cell_grid, u_axis), wp.volume_index_to_world_dir(cell_grid, v_axis)
321
375
 
322
376
  @wp.func
@@ -382,15 +436,17 @@ class Nanogrid(Geometry):
382
436
  def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
383
437
  flags = args.face_flags[side_index]
384
438
  axis = Nanogrid._get_face_axis(flags)
439
+ flip = Nanogrid._get_face_inner_offset(flags)
385
440
  offset = float(Nanogrid._get_face_inner_offset(flags))
386
- return Nanogrid._side_to_cell_coords(axis, 1.0 - offset, side_coords)
441
+ return Nanogrid._side_to_cell_coords(axis, flip, 1.0 - offset, side_coords)
387
442
 
388
443
  @wp.func
389
444
  def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
390
445
  flags = args.face_flags[side_index]
391
446
  axis = Nanogrid._get_face_axis(flags)
447
+ flip = Nanogrid._get_face_inner_offset(flags)
392
448
  offset = float(Nanogrid._get_face_outer_offset(flags))
393
- return Nanogrid._side_to_cell_coords(axis, offset, side_coords)
449
+ return Nanogrid._side_to_cell_coords(axis, flip, offset, side_coords)
394
450
 
395
451
  @wp.func
396
452
  def side_from_cell_coords(
@@ -401,20 +457,44 @@ class Nanogrid(Geometry):
401
457
  ):
402
458
  flags = args.face_flags[side_index]
403
459
  axis = Nanogrid._get_face_axis(flags)
460
+ flip = Nanogrid._get_face_inner_offset(flags)
404
461
 
405
462
  cell_ijk = args.cell_arg.cell_ijk[element_index]
406
463
  side_ijk = args.face_ijk[side_index]
407
464
 
408
465
  on_side = float(side_ijk[axis] - cell_ijk[axis]) == element_coords[axis]
409
466
 
410
- return wp.where(
411
- on_side, Coords(element_coords[(axis + 1) % 3], element_coords[(axis + 2) % 3], 0.0), Coords(OUTSIDE)
412
- )
467
+ return wp.where(on_side, Nanogrid._cell_to_side_coords(axis, flip, element_coords), Coords(OUTSIDE))
413
468
 
414
469
  @wp.func
415
470
  def side_to_cell_arg(side_arg: SideArg):
416
471
  return side_arg.cell_arg
417
472
 
473
+ @wp.func
474
+ def side_coordinates(args: SideArg, side_index: int, pos: wp.vec3):
475
+ cell_arg = args.cell_arg
476
+
477
+ ijk = args.face_ijk[side_index]
478
+ cell_coords = wp.volume_world_to_index(cell_arg.cell_grid, pos) + wp.vec3(0.5) - wp.vec3(ijk)
479
+
480
+ flags = args.face_flags[side_index]
481
+ axis = Nanogrid._get_face_axis(flags)
482
+ flip = Nanogrid._get_face_inner_offset(flags)
483
+ return Nanogrid._cell_to_side_coords(axis, flip, cell_coords)
484
+
485
+ @wp.func
486
+ def side_closest_point(args: SideArg, side_index: int, pos: wp.vec3):
487
+ coords = Nanogrid.side_coordinates(args, side_index, pos)
488
+
489
+ proj_coords = Coords(wp.clamp(coords[0], 0.0, 1.0), wp.clamp(coords[1], 0.0, 1.0), 0.0)
490
+
491
+ flags = args.face_flags[side_index]
492
+ axis = Nanogrid._get_face_axis(flags)
493
+ flip = Nanogrid._get_face_inner_offset(flags)
494
+ cell_coord_offset = Nanogrid._side_to_cell_coords(axis, flip, 0, coords - proj_coords)
495
+
496
+ return proj_coords, wp.length_sq(wp.volume_index_to_world_dir(args.cell_grid, cell_coord_offset))
497
+
418
498
  def _build_face_grid(self, temporary_store: Optional[cache.TemporaryStore] = None):
419
499
  device = self._cell_grid.device
420
500
  self._face_grid = _build_face_grid(self._cell_ijk, self._cell_grid, temporary_store)