warp-lang 1.7.2__py3-none-macosx_10_13_universal2.whl → 1.8.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (181)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp-clang.dylib +0 -0
  5. warp/bin/libwarp.dylib +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +125 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +257 -101
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +657 -223
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/optim/example_drone.py +1 -1
  37. warp/examples/sim/example_cloth.py +1 -1
  38. warp/examples/sim/example_cloth_self_contact.py +48 -54
  39. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  40. warp/examples/tile/example_tile_cholesky.py +2 -1
  41. warp/examples/tile/example_tile_convolution.py +1 -1
  42. warp/examples/tile/example_tile_filtering.py +1 -1
  43. warp/examples/tile/example_tile_matmul.py +1 -1
  44. warp/examples/tile/example_tile_mlp.py +2 -0
  45. warp/fabric.py +7 -7
  46. warp/fem/__init__.py +5 -0
  47. warp/fem/adaptivity.py +1 -1
  48. warp/fem/cache.py +152 -63
  49. warp/fem/dirichlet.py +2 -2
  50. warp/fem/domain.py +136 -6
  51. warp/fem/field/field.py +141 -99
  52. warp/fem/field/nodal_field.py +85 -39
  53. warp/fem/field/virtual.py +97 -52
  54. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  55. warp/fem/geometry/closest_point.py +13 -0
  56. warp/fem/geometry/deformed_geometry.py +102 -40
  57. warp/fem/geometry/element.py +56 -2
  58. warp/fem/geometry/geometry.py +323 -22
  59. warp/fem/geometry/grid_2d.py +157 -62
  60. warp/fem/geometry/grid_3d.py +116 -20
  61. warp/fem/geometry/hexmesh.py +86 -20
  62. warp/fem/geometry/nanogrid.py +166 -86
  63. warp/fem/geometry/partition.py +59 -25
  64. warp/fem/geometry/quadmesh.py +86 -135
  65. warp/fem/geometry/tetmesh.py +47 -119
  66. warp/fem/geometry/trimesh.py +77 -270
  67. warp/fem/integrate.py +107 -52
  68. warp/fem/linalg.py +25 -58
  69. warp/fem/operator.py +124 -27
  70. warp/fem/quadrature/pic_quadrature.py +36 -14
  71. warp/fem/quadrature/quadrature.py +40 -16
  72. warp/fem/space/__init__.py +1 -1
  73. warp/fem/space/basis_function_space.py +66 -46
  74. warp/fem/space/basis_space.py +17 -4
  75. warp/fem/space/dof_mapper.py +1 -1
  76. warp/fem/space/function_space.py +2 -2
  77. warp/fem/space/grid_2d_function_space.py +4 -1
  78. warp/fem/space/hexmesh_function_space.py +4 -2
  79. warp/fem/space/nanogrid_function_space.py +3 -1
  80. warp/fem/space/partition.py +11 -2
  81. warp/fem/space/quadmesh_function_space.py +4 -1
  82. warp/fem/space/restriction.py +5 -2
  83. warp/fem/space/shape/__init__.py +10 -8
  84. warp/fem/space/tetmesh_function_space.py +4 -1
  85. warp/fem/space/topology.py +52 -21
  86. warp/fem/space/trimesh_function_space.py +4 -1
  87. warp/fem/utils.py +53 -8
  88. warp/jax.py +1 -2
  89. warp/jax_experimental/ffi.py +12 -17
  90. warp/jax_experimental/xla_ffi.py +37 -24
  91. warp/math.py +171 -1
  92. warp/native/array.h +99 -0
  93. warp/native/builtin.h +174 -31
  94. warp/native/coloring.cpp +1 -1
  95. warp/native/exports.h +118 -63
  96. warp/native/intersect.h +3 -3
  97. warp/native/mat.h +5 -10
  98. warp/native/mathdx.cpp +11 -5
  99. warp/native/matnn.h +1 -123
  100. warp/native/quat.h +28 -4
  101. warp/native/sparse.cpp +121 -258
  102. warp/native/sparse.cu +181 -274
  103. warp/native/spatial.h +305 -17
  104. warp/native/tile.h +583 -72
  105. warp/native/tile_radix_sort.h +1108 -0
  106. warp/native/tile_reduce.h +237 -2
  107. warp/native/tile_scan.h +240 -0
  108. warp/native/tuple.h +189 -0
  109. warp/native/vec.h +6 -16
  110. warp/native/warp.cpp +36 -4
  111. warp/native/warp.cu +574 -51
  112. warp/native/warp.h +47 -74
  113. warp/optim/linear.py +5 -1
  114. warp/paddle.py +7 -8
  115. warp/py.typed +0 -0
  116. warp/render/render_opengl.py +58 -29
  117. warp/render/render_usd.py +124 -61
  118. warp/sim/__init__.py +9 -0
  119. warp/sim/collide.py +252 -78
  120. warp/sim/graph_coloring.py +8 -1
  121. warp/sim/import_mjcf.py +4 -3
  122. warp/sim/import_usd.py +11 -7
  123. warp/sim/integrator.py +5 -2
  124. warp/sim/integrator_euler.py +1 -1
  125. warp/sim/integrator_featherstone.py +1 -1
  126. warp/sim/integrator_vbd.py +751 -320
  127. warp/sim/integrator_xpbd.py +1 -1
  128. warp/sim/model.py +265 -260
  129. warp/sim/utils.py +10 -7
  130. warp/sparse.py +303 -166
  131. warp/tape.py +52 -51
  132. warp/tests/cuda/test_conditional_captures.py +1046 -0
  133. warp/tests/cuda/test_streams.py +1 -1
  134. warp/tests/geometry/test_volume.py +2 -2
  135. warp/tests/interop/test_dlpack.py +9 -9
  136. warp/tests/interop/test_jax.py +0 -1
  137. warp/tests/run_coverage_serial.py +1 -1
  138. warp/tests/sim/disabled_kinematics.py +2 -2
  139. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  140. warp/tests/sim/test_collision.py +159 -51
  141. warp/tests/sim/test_coloring.py +15 -1
  142. warp/tests/test_array.py +254 -2
  143. warp/tests/test_array_reduce.py +2 -2
  144. warp/tests/test_atomic_cas.py +299 -0
  145. warp/tests/test_codegen.py +142 -19
  146. warp/tests/test_conditional.py +47 -1
  147. warp/tests/test_ctypes.py +0 -20
  148. warp/tests/test_devices.py +8 -0
  149. warp/tests/test_fabricarray.py +4 -2
  150. warp/tests/test_fem.py +58 -25
  151. warp/tests/test_func.py +42 -1
  152. warp/tests/test_grad.py +1 -1
  153. warp/tests/test_lerp.py +1 -3
  154. warp/tests/test_map.py +481 -0
  155. warp/tests/test_mat.py +1 -24
  156. warp/tests/test_quat.py +6 -15
  157. warp/tests/test_rounding.py +10 -38
  158. warp/tests/test_runlength_encode.py +7 -7
  159. warp/tests/test_smoothstep.py +1 -1
  160. warp/tests/test_sparse.py +51 -2
  161. warp/tests/test_spatial.py +507 -1
  162. warp/tests/test_struct.py +2 -2
  163. warp/tests/test_tuple.py +265 -0
  164. warp/tests/test_types.py +2 -2
  165. warp/tests/test_utils.py +24 -18
  166. warp/tests/tile/test_tile.py +420 -1
  167. warp/tests/tile/test_tile_mathdx.py +518 -14
  168. warp/tests/tile/test_tile_reduce.py +213 -0
  169. warp/tests/tile/test_tile_shared_memory.py +130 -1
  170. warp/tests/tile/test_tile_sort.py +117 -0
  171. warp/tests/unittest_suites.py +4 -6
  172. warp/types.py +462 -308
  173. warp/utils.py +647 -86
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  175. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
  176. warp/stubs.py +0 -3381
  177. warp/tests/sim/test_xpbd.py +0 -399
  178. warp/tests/test_mlp.py +0 -282
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  181. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
@@ -13,12 +13,13 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from typing import Optional
16
+ from typing import Any, Optional
17
17
 
18
18
  import warp as wp
19
- from warp.fem.cache import cached_arg_value
20
- from warp.fem.types import OUTSIDE, Coords, ElementIndex, Sample, make_free_sample
19
+ from warp.fem.cache import cached_arg_value, dynamic_func
20
+ from warp.fem.types import NULL_ELEMENT_INDEX, OUTSIDE, Coords, ElementIndex, Sample, make_free_sample
21
21
 
22
+ from .closest_point import project_on_box_at_origin
22
23
  from .element import LinearEdge, Square
23
24
  from .geometry import Geometry
24
25
 
@@ -35,8 +36,8 @@ class Grid2D(Geometry):
35
36
 
36
37
  dimension = 2
37
38
 
38
- Permutation = wp.types.matrix(shape=(2, 2), dtype=int)
39
- ROTATION = wp.constant(Permutation(0, 1, 1, 0))
39
+ ALT_AXIS = 0
40
+ LONG_AXIS = 1
40
41
 
41
42
  def __init__(self, res: wp.vec2i, bounds_lo: Optional[wp.vec2] = None, bounds_hi: Optional[wp.vec2] = None):
42
43
  """Constructs a dense 2D grid
@@ -140,28 +141,27 @@ class Grid2D(Geometry):
140
141
  SideIndexArg = SideArg
141
142
 
142
143
  @wp.func
143
- def _rotate(axis: int, vec: wp.vec2i):
144
- return wp.vec2i(
145
- vec[Grid2D.ROTATION[axis, 0]],
146
- vec[Grid2D.ROTATION[axis, 1]],
147
- )
144
+ def orient(axis: int, vec: Any):
145
+ return wp.where(axis == 0, vec, type(vec)(vec[1], vec[0]))
148
146
 
149
147
  @wp.func
150
- def _rotate(axis: int, vec: wp.vec2):
151
- return wp.vec2(
152
- vec[Grid2D.ROTATION[axis, 0]],
153
- vec[Grid2D.ROTATION[axis, 1]],
154
- )
148
+ def orient(axis: int, coord: int):
149
+ return wp.where(axis == 0, coord, 1 - coord)
150
+
151
+ @wp.func
152
+ def is_flipped(side: Side):
153
+ # Flip such that the boundary is CCW
154
+ return (side.axis == 0) == (side.origin[Grid2D.ALT_AXIS] == 0)
155
155
 
156
156
  @wp.func
157
157
  def side_index(arg: SideArg, side: Side):
158
- alt_axis = Grid2D.ROTATION[side.axis, 0]
158
+ alt_axis = Grid2D.orient(side.axis, 0)
159
159
  if side.origin[0] == arg.cell_arg.res[alt_axis]:
160
160
  # Upper-boundary side
161
161
  longitude = side.origin[1]
162
162
  return 2 * arg.cell_count + arg.axis_offsets[side.axis] + longitude
163
163
 
164
- cell_index = Grid2D.cell_index(arg.cell_arg.res, Grid2D._rotate(side.axis, side.origin))
164
+ cell_index = Grid2D.cell_index(arg.cell_arg.res, Grid2D.orient(side.axis, side.origin))
165
165
  return side.axis * arg.cell_count + cell_index
166
166
 
167
167
  @wp.func
@@ -169,13 +169,13 @@ class Grid2D(Geometry):
169
169
  if side_index < 2 * arg.cell_count:
170
170
  axis = side_index // arg.cell_count
171
171
  cell_index = side_index - axis * arg.cell_count
172
- origin = Grid2D._rotate(axis, Grid2D.get_cell(arg.cell_arg.res, cell_index))
172
+ origin = Grid2D.orient(axis, Grid2D.get_cell(arg.cell_arg.res, cell_index))
173
173
  return Grid2D.Side(axis, origin)
174
174
 
175
175
  axis_side_index = side_index - 2 * arg.cell_count
176
176
  axis = wp.where(axis_side_index < arg.axis_offsets[1], 0, 1)
177
177
 
178
- altitude = arg.cell_arg.res[Grid2D.ROTATION[axis, 0]]
178
+ altitude = arg.cell_arg.res[Grid2D.orient(axis, 0)]
179
179
  longitude = axis_side_index - arg.axis_offsets[axis]
180
180
 
181
181
  origin_loc = wp.vec2i(altitude, longitude)
@@ -186,10 +186,13 @@ class Grid2D(Geometry):
186
186
  @cached_arg_value
187
187
  def cell_arg_value(self, device) -> CellArg:
188
188
  args = self.CellArg()
189
+ self.fill_cell_arg(args, device)
190
+ return args
191
+
192
+ def fill_cell_arg(self, args: CellArg, device):
189
193
  args.res = self.res
190
194
  args.cell_size = self.cell_size
191
195
  args.origin = self.bounds_lo
192
- return args
193
196
 
194
197
  @wp.func
195
198
  def cell_position(args: CellArg, s: Sample):
@@ -211,22 +214,83 @@ class Grid2D(Geometry):
211
214
  return wp.diag(wp.cw_div(wp.vec2(1.0), args.cell_size))
212
215
 
213
216
  @wp.func
214
- def cell_lookup(args: CellArg, pos: wp.vec2):
215
- loc_pos = wp.cw_div(pos - args.origin, args.cell_size)
216
- x = wp.clamp(loc_pos[0], 0.0, float(args.res[0]))
217
- y = wp.clamp(loc_pos[1], 0.0, float(args.res[1]))
217
+ def cell_coordinates(args: Grid2DCellArg, cell_index: int, pos: wp.vec2):
218
+ uvw = wp.cw_div(pos - args.origin, args.cell_size)
219
+ ij = Grid2D.get_cell(args.res, cell_index)
220
+ return Coords(uvw[0] - float(ij[0]), uvw[1] - float(ij[1]), 0.0)
218
221
 
219
- x_cell = wp.min(wp.floor(x), float(args.res[0]) - 1.0)
220
- y_cell = wp.min(wp.floor(y), float(args.res[1]) - 1.0)
222
+ @wp.func
223
+ def cell_closest_point(args: Grid2DCellArg, cell_index: int, pos: wp.vec2):
224
+ ij_world = wp.cw_mul(wp.vec2(Grid2D.get_cell(args.res, cell_index)), args.cell_size) + args.origin
225
+ dist_sq, coords = project_on_box_at_origin(pos - ij_world, args.cell_size)
226
+ return coords, dist_sq
221
227
 
222
- coords = Coords(x - x_cell, y - y_cell, 0.0)
223
- cell_index = Grid2D.cell_index(args.res, Grid2D.Cell(int(x_cell), int(y_cell)))
228
+ def supports_cell_lookup(self, device):
229
+ return True
224
230
 
225
- return make_free_sample(cell_index, coords)
231
+ def make_filtered_cell_lookup(self, filter_func: wp.Function = None):
232
+ suffix = f"{self.name}{filter_func.key if filter_func is not None else ''}"
226
233
 
227
- @wp.func
228
- def cell_lookup(args: CellArg, pos: wp.vec2, guess: Sample):
229
- return Grid2D.cell_lookup(args, pos)
234
+ @dynamic_func(suffix=suffix)
235
+ def cell_lookup(args: self.CellArg, pos: wp.vec2, max_dist: float, filter_data: Any, filter_target: Any):
236
+ cell_size = args.cell_size
237
+ res = args.res
238
+
239
+ # Start at closest point on grid
240
+ loc_pos = wp.cw_div(pos - args.origin, cell_size)
241
+ x = wp.clamp(loc_pos[0], 0.0, float(res[0]))
242
+ y = wp.clamp(loc_pos[1], 0.0, float(res[1]))
243
+
244
+ x_cell = wp.min(wp.floor(x), float(res[0] - 1))
245
+ y_cell = wp.min(wp.floor(y), float(res[1] - 1))
246
+
247
+ coords = Coords(x - x_cell, y - y_cell, 0.0)
248
+ cell_index = Grid2D.cell_index(res, Grid2D.Cell(int(x_cell), int(y_cell)))
249
+
250
+ if wp.static(filter_func is None):
251
+ return make_free_sample(cell_index, coords)
252
+ else:
253
+ if filter_func(filter_data, cell_index) == filter_target:
254
+ return make_free_sample(cell_index, coords)
255
+
256
+ offset = float(0.5)
257
+ min_cell_size = wp.min(cell_size)
258
+ max_offset = wp.ceil(max_dist / min_cell_size)
259
+
260
+ scales = wp.cw_div(wp.vec2(min_cell_size), cell_size)
261
+
262
+ closest_cell = NULL_ELEMENT_INDEX
263
+ closest_coords = Coords()
264
+
265
+ # Iterate over increasingly larger neighborhoods
266
+ while closest_cell == NULL_ELEMENT_INDEX:
267
+ i_min = wp.max(0, int(wp.floor(x - offset * scales[0])))
268
+ i_max = wp.min(res[0], int(wp.floor(x + offset * scales[0])) + 1)
269
+ j_min = wp.max(0, int(wp.floor(y - offset * scales[1])))
270
+ j_max = wp.min(res[1], int(wp.floor(y + offset * scales[1])) + 1)
271
+
272
+ closest_dist = min_cell_size * min_cell_size * float(offset * offset)
273
+
274
+ for i in range(i_min, i_max):
275
+ for j in range(j_min, j_max):
276
+ ij = Grid2D.Cell(i, j)
277
+ cell_index = Grid2D.cell_index(res, ij)
278
+ if filter_func(filter_data, cell_index) == filter_target:
279
+ rel_pos = wp.cw_mul(loc_pos - wp.vec2(ij), cell_size)
280
+ dist, coords = project_on_box_at_origin(rel_pos, cell_size)
281
+
282
+ if dist <= closest_dist:
283
+ closest_dist = dist
284
+ closest_coords = coords
285
+ closest_cell = cell_index
286
+
287
+ if offset >= max_offset:
288
+ break
289
+ offset = wp.min(3.0 * offset, max_offset)
290
+
291
+ return make_free_sample(closest_cell, closest_coords)
292
+
293
+ return cell_lookup
230
294
 
231
295
  @wp.func
232
296
  def cell_measure(args: CellArg, s: Sample):
@@ -239,18 +303,23 @@ class Grid2D(Geometry):
239
303
  @cached_arg_value
240
304
  def side_arg_value(self, device) -> SideArg:
241
305
  args = self.SideArg()
306
+ self.fill_side_arg(args, device)
307
+ return args
242
308
 
309
+ def fill_side_arg(self, args: SideArg, device):
243
310
  args.axis_offsets = wp.vec2i(
244
311
  0,
245
312
  self.res[1],
246
313
  )
247
314
  args.cell_count = self.cell_count()
248
315
  args.cell_arg = self.cell_arg_value(device)
249
- return args
250
316
 
251
317
  def side_index_arg_value(self, device) -> SideIndexArg:
252
318
  return self.side_arg_value(device)
253
319
 
320
+ def fill_side_index_arg(self, args: SideIndexArg, device):
321
+ self.fill_side_arg(args, device)
322
+
254
323
  @wp.func
255
324
  def boundary_side_index(args: SideArg, boundary_side_index: int):
256
325
  """Boundary side to side index"""
@@ -273,14 +342,11 @@ class Grid2D(Geometry):
273
342
  def side_position(args: SideArg, s: Sample):
274
343
  side = Grid2D.get_side(args, s.element_index)
275
344
 
276
- coord = wp.where((side.origin[0] == 0) == (side.axis == 0), s.element_coords[0], 1.0 - s.element_coords[0])
277
-
278
- local_pos = wp.vec2(
279
- float(side.origin[0]),
280
- float(side.origin[1]) + coord,
281
- )
345
+ flip = Grid2D.is_flipped(side)
346
+ coord = wp.where(flip, 1.0 - s.element_coords[0], s.element_coords[0])
282
347
 
283
- pos = args.cell_arg.origin + wp.cw_mul(Grid2D._rotate(side.axis, local_pos), args.cell_arg.cell_size)
348
+ local_pos = wp.vec2(side.origin) + wp.vec2(0.0, coord)
349
+ pos = args.cell_arg.origin + wp.cw_mul(Grid2D.orient(side.axis, local_pos), args.cell_arg.cell_size)
284
350
 
285
351
  return pos
286
352
 
@@ -288,9 +354,10 @@ class Grid2D(Geometry):
288
354
  def side_deformation_gradient(args: SideArg, s: Sample):
289
355
  side = Grid2D.get_side(args, s.element_index)
290
356
 
291
- sign = wp.where((side.origin[0] == 0) == (side.axis == 0), 1.0, -1.0)
357
+ flip = Grid2D.is_flipped(side)
358
+ sign = wp.where(flip, -1.0, 1.0)
292
359
 
293
- return wp.cw_mul(Grid2D._rotate(side.axis, wp.vec2(0.0, sign)), args.cell_arg.cell_size)
360
+ return wp.cw_mul(Grid2D.orient(side.axis, wp.vec2(0.0, sign)), args.cell_arg.cell_size)
294
361
 
295
362
  @wp.func
296
363
  def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample):
@@ -303,70 +370,74 @@ class Grid2D(Geometry):
303
370
  @wp.func
304
371
  def side_measure(args: SideArg, s: Sample):
305
372
  side = Grid2D.get_side(args, s.element_index)
306
- long_axis = Grid2D.ROTATION[side.axis, 1]
373
+ long_axis = Grid2D.orient(side.axis, Grid2D.LONG_AXIS)
307
374
  return args.cell_arg.cell_size[long_axis]
308
375
 
309
376
  @wp.func
310
377
  def side_measure_ratio(args: SideArg, s: Sample):
311
378
  side = Grid2D.get_side(args, s.element_index)
312
- alt_axis = Grid2D.ROTATION[side.axis, 0]
379
+ alt_axis = Grid2D.orient(side.axis, Grid2D.ALT_AXIS)
313
380
  return 1.0 / args.cell_arg.cell_size[alt_axis]
314
381
 
315
382
  @wp.func
316
383
  def side_normal(args: SideArg, s: Sample):
317
384
  side = Grid2D.get_side(args, s.element_index)
318
385
 
319
- sign = wp.where(side.origin[0] == 0, -1.0, 1.0)
386
+ # intentionally not using is_flipped to account for normal sign switch with orient(axis=1)
387
+ flip = side.origin[Grid2D.ALT_AXIS] == 0
388
+ sign = wp.where(flip, -1.0, 1.0)
320
389
 
321
390
  local_n = wp.vec2(sign, 0.0)
322
- return Grid2D._rotate(side.axis, local_n)
391
+ return Grid2D.orient(side.axis, local_n)
323
392
 
324
393
  @wp.func
325
394
  def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
326
395
  side = Grid2D.get_side(arg, side_index)
327
396
 
328
- inner_alt = wp.where(side.origin[0] == 0, 0, side.origin[0] - 1)
397
+ inner_alt = wp.where(side.origin[Grid2D.ALT_AXIS] == 0, 0, side.origin[Grid2D.ALT_AXIS] - 1)
329
398
 
330
399
  inner_origin = wp.vec2i(inner_alt, side.origin[1])
331
400
 
332
- cell = Grid2D._rotate(side.axis, inner_origin)
401
+ cell = Grid2D.orient(side.axis, inner_origin)
333
402
  return Grid2D.cell_index(arg.cell_arg.res, cell)
334
403
 
335
404
  @wp.func
336
405
  def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
337
406
  side = Grid2D.get_side(arg, side_index)
338
407
 
339
- alt_axis = Grid2D.ROTATION[side.axis, 0]
408
+ alt_axis = Grid2D.orient(side.axis, 0)
340
409
  outer_alt = wp.where(
341
- side.origin[0] == arg.cell_arg.res[alt_axis], arg.cell_arg.res[alt_axis] - 1, side.origin[0]
410
+ side.origin[Grid2D.ALT_AXIS] == arg.cell_arg.res[alt_axis], arg.cell_arg.res[alt_axis] - 1, side.origin[0]
342
411
  )
343
412
 
344
- outer_origin = wp.vec2i(outer_alt, side.origin[1])
413
+ outer_origin = wp.vec2i(outer_alt, side.origin[Grid2D.LONG_AXIS])
345
414
 
346
- cell = Grid2D._rotate(side.axis, outer_origin)
415
+ cell = Grid2D.orient(side.axis, outer_origin)
347
416
  return Grid2D.cell_index(arg.cell_arg.res, cell)
348
417
 
349
418
  @wp.func
350
419
  def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
351
420
  side = Grid2D.get_side(args, side_index)
352
421
 
353
- inner_alt = wp.where(side.origin[0] == 0, 0.0, 1.0)
422
+ inner_alt = wp.where(side.origin[Grid2D.ALT_AXIS] == 0, 0.0, 1.0)
354
423
 
355
- side_coord = wp.where((side.origin[0] == 0) == (side.axis == 0), side_coords[0], 1.0 - side_coords[0])
424
+ flip = Grid2D.is_flipped(side)
425
+ side_coord = wp.where(flip, 1.0 - side_coords[0], side_coords[0])
356
426
 
357
- coords = Grid2D._rotate(side.axis, wp.vec2(inner_alt, side_coord))
427
+ coords = Grid2D.orient(side.axis, wp.vec2(inner_alt, side_coord))
358
428
  return Coords(coords[0], coords[1], 0.0)
359
429
 
360
430
  @wp.func
361
431
  def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
362
432
  side = Grid2D.get_side(args, side_index)
363
433
 
364
- alt_axis = Grid2D.ROTATION[side.axis, 0]
365
- outer_alt = wp.where(side.origin[0] == args.cell_arg.res[alt_axis], 1.0, 0.0)
434
+ alt_axis = Grid2D.orient(side.axis, Grid2D.ALT_AXIS)
435
+ outer_alt = wp.where(side.origin[Grid2D.ALT_AXIS] == args.cell_arg.res[alt_axis], 1.0, 0.0)
366
436
 
367
- side_coord = wp.where((side.origin[0] == 0) == (side.axis == 0), side_coords[0], 1.0 - side_coords[0])
437
+ flip = Grid2D.is_flipped(side)
438
+ side_coord = wp.where(flip, 1.0 - side_coords[0], side_coords[0])
368
439
 
369
- coords = Grid2D._rotate(side.axis, wp.vec2(outer_alt, side_coord))
440
+ coords = Grid2D.orient(side.axis, wp.vec2(outer_alt, side_coord))
370
441
  return Coords(coords[0], coords[1], 0.0)
371
442
 
372
443
  @wp.func
@@ -379,10 +450,11 @@ class Grid2D(Geometry):
379
450
  side = Grid2D.get_side(args, side_index)
380
451
  cell = Grid2D.get_cell(args.cell_arg.res, element_index)
381
452
 
382
- if float(side.origin[0] - cell[side.axis]) == element_coords[side.axis]:
383
- long_axis = Grid2D.ROTATION[side.axis, 1]
453
+ if float(side.origin[Grid2D.ALT_AXIS] - cell[side.axis]) == element_coords[side.axis]:
454
+ long_axis = Grid2D.orient(side.axis, Grid2D.LONG_AXIS)
384
455
  axis_coord = element_coords[long_axis]
385
- side_coord = wp.where((side.origin[0] == 0) == (side.axis == 0), axis_coord, 1.0 - axis_coord)
456
+ flip = Grid2D.is_flipped(side)
457
+ side_coord = wp.where(flip, 1.0 - axis_coord, axis_coord)
386
458
  return Coords(side_coord, 0.0, 0.0)
387
459
 
388
460
  return Coords(OUTSIDE)
@@ -390,3 +462,26 @@ class Grid2D(Geometry):
390
462
  @wp.func
391
463
  def side_to_cell_arg(side_arg: SideArg):
392
464
  return side_arg.cell_arg
465
+
466
+ @wp.func
467
+ def side_coordinates(args: SideArg, side_index: int, pos: wp.vec2):
468
+ cell_arg = args.cell_arg
469
+ side = Grid2D.get_side(args, side_index)
470
+ long_axis = Grid2D.orient(side.axis, Grid2D.LONG_AXIS)
471
+ flip = Grid2D.is_flipped(side)
472
+
473
+ long_loc = (pos[long_axis] - cell_arg.origin[long_axis]) / cell_arg.cell_size[long_axis] - float(side.origin[1])
474
+ coord = wp.where(flip, 1.0 - long_loc, long_loc)
475
+
476
+ return Coords(coord, 0.0, 0.0)
477
+
478
+ @wp.func
479
+ def side_closest_point(args: SideArg, side_index: int, pos: wp.vec2):
480
+ coord = Grid2D.side_coordinates(args, side_index, pos)
481
+
482
+ cell_arg = args.cell_arg
483
+ side = Grid2D.get_side(args, side_index)
484
+ long_axis = Grid2D.orient(side.axis, Grid2D.LONG_AXIS)
485
+ proj_coord = wp.clamp(coord, 0.0, 1.0)
486
+ dist = (coord - proj_coord) * cell_arg.cell_size[long_axis]
487
+ return Coords(proj_coord, 0.0, 0.0), dist * dist
@@ -16,9 +16,10 @@
16
16
  from typing import Any, Optional
17
17
 
18
18
  import warp as wp
19
- from warp.fem.cache import cached_arg_value
20
- from warp.fem.types import OUTSIDE, Coords, ElementIndex, Sample, make_free_sample
19
+ from warp.fem.cache import cached_arg_value, dynamic_func
20
+ from warp.fem.types import NULL_ELEMENT_INDEX, OUTSIDE, Coords, ElementIndex, Sample, make_free_sample
21
21
 
22
+ from .closest_point import project_on_box_at_origin
22
23
  from .element import Cube, Square
23
24
  from .geometry import Geometry
24
25
 
@@ -225,10 +226,13 @@ class Grid3D(Geometry):
225
226
  @cached_arg_value
226
227
  def cell_arg_value(self, device) -> CellArg:
227
228
  args = self.CellArg()
229
+ self.fill_cell_arg(args, device)
230
+ return args
231
+
232
+ def fill_cell_arg(self, args: CellArg, device):
228
233
  args.res = self.res
229
234
  args.origin = self.bounds_lo
230
235
  args.cell_size = self.cell_size
231
- return args
232
236
 
233
237
  @wp.func
234
238
  def cell_position(args: CellArg, s: Sample):
@@ -251,24 +255,87 @@ class Grid3D(Geometry):
251
255
  return wp.diag(wp.cw_div(wp.vec3(1.0), args.cell_size))
252
256
 
253
257
  @wp.func
254
- def cell_lookup(args: CellArg, pos: wp.vec3):
255
- loc_pos = wp.cw_div(pos - args.origin, args.cell_size)
256
- x = wp.clamp(loc_pos[0], 0.0, float(args.res[0]))
257
- y = wp.clamp(loc_pos[1], 0.0, float(args.res[1]))
258
- z = wp.clamp(loc_pos[2], 0.0, float(args.res[2]))
259
-
260
- x_cell = wp.min(wp.floor(x), float(args.res[0]) - 1.0)
261
- y_cell = wp.min(wp.floor(y), float(args.res[1]) - 1.0)
262
- z_cell = wp.min(wp.floor(z), float(args.res[2]) - 1.0)
263
-
264
- coords = Coords(x - x_cell, y - y_cell, z - z_cell)
265
- cell_index = Grid3D.cell_index(args.res, Grid3D.Cell(int(x_cell), int(y_cell), int(z_cell)))
266
-
267
- return make_free_sample(cell_index, coords)
258
+ def cell_coordinates(args: Grid3DCellArg, cell_index: int, pos: wp.vec3):
259
+ uvw = wp.cw_div(pos - args.origin, args.cell_size)
260
+ ijk = Grid3D.get_cell(args.res, cell_index)
261
+ return uvw - wp.vec3(ijk)
268
262
 
269
263
  @wp.func
270
- def cell_lookup(args: CellArg, pos: wp.vec3, guess: Sample):
271
- return Grid3D.cell_lookup(args, pos)
264
+ def cell_closest_point(args: Grid3DCellArg, cell_index: int, pos: wp.vec3):
265
+ ijk_world = wp.cw_mul(wp.vec3(Grid3D.get_cell(args.res, cell_index)), args.cell_size) + args.origin
266
+ dist_sq, coords = project_on_box_at_origin(pos - ijk_world, args.cell_size)
267
+ return coords, dist_sq
268
+
269
+ def supports_cell_lookup(self, device):
270
+ return True
271
+
272
+ def make_filtered_cell_lookup(self, filter_func: wp.Function = None):
273
+ suffix = f"{self.name}{filter_func.key if filter_func is not None else ''}"
274
+
275
+ @dynamic_func(suffix=suffix)
276
+ def cell_lookup(args: self.CellArg, pos: wp.vec3, max_dist: float, filter_data: Any, filter_target: Any):
277
+ cell_size = args.cell_size
278
+ res = args.res
279
+
280
+ # Start at closest point on grid
281
+ loc_pos = wp.cw_div(pos - args.origin, cell_size)
282
+ x = wp.clamp(loc_pos[0], 0.0, float(res[0]))
283
+ y = wp.clamp(loc_pos[1], 0.0, float(res[1]))
284
+ z = wp.clamp(loc_pos[2], 0.0, float(res[2]))
285
+
286
+ x_cell = wp.min(wp.floor(x), float(res[0]) - 1.0)
287
+ y_cell = wp.min(wp.floor(y), float(res[1]) - 1.0)
288
+ z_cell = wp.min(wp.floor(z), float(res[2]) - 1.0)
289
+
290
+ coords = Coords(x - x_cell, y - y_cell, z - z_cell)
291
+ cell_index = Grid3D.cell_index(res, Grid3D.Cell(int(x_cell), int(y_cell), int(z_cell)))
292
+
293
+ if wp.static(filter_func is None):
294
+ return make_free_sample(cell_index, coords)
295
+ else:
296
+ if filter_func(filter_data, cell_index) == filter_target:
297
+ return make_free_sample(cell_index, coords)
298
+
299
+ offset = float(0.5)
300
+ min_cell_size = wp.min(cell_size)
301
+ max_offset = wp.ceil(max_dist / min_cell_size)
302
+ scales = wp.cw_div(wp.vec3(min_cell_size), cell_size)
303
+
304
+ closest_cell = NULL_ELEMENT_INDEX
305
+ closest_coords = Coords()
306
+
307
+ # Iterate over increasingly larger neighborhoods
308
+ while closest_cell == NULL_ELEMENT_INDEX:
309
+ i_min = wp.max(0, int(wp.floor(x - offset * scales[0])))
310
+ i_max = wp.min(res[0], int(wp.floor(x + offset * scales[0])) + 1)
311
+ j_min = wp.max(0, int(wp.floor(y - offset * scales[1])))
312
+ j_max = wp.min(res[1], int(wp.floor(y + offset * scales[1])) + 1)
313
+ k_min = wp.max(0, int(wp.floor(z - offset * scales[2])))
314
+ k_max = wp.min(res[2], int(wp.floor(z + offset * scales[2])) + 1)
315
+
316
+ closest_dist = min_cell_size * min_cell_size * float(offset * offset)
317
+
318
+ for i in range(i_min, i_max):
319
+ for j in range(j_min, j_max):
320
+ for k in range(k_min, k_max):
321
+ ijk = Grid3D.Cell(i, j, k)
322
+ cell_index = Grid3D.cell_index(res, ijk)
323
+ if filter_func(filter_data, cell_index) == filter_target:
324
+ rel_pos = wp.cw_mul(loc_pos - wp.vec3(ijk), cell_size)
325
+ dist, coords = project_on_box_at_origin(rel_pos, cell_size)
326
+
327
+ if dist <= closest_dist:
328
+ closest_dist = dist
329
+ closest_coords = coords
330
+ closest_cell = cell_index
331
+
332
+ if offset >= max_offset:
333
+ break
334
+ offset = wp.min(3.0 * offset, max_offset)
335
+
336
+ return make_free_sample(closest_cell, closest_coords)
337
+
338
+ return cell_lookup
272
339
 
273
340
  @wp.func
274
341
  def cell_measure(args: CellArg, s: Sample):
@@ -281,7 +348,10 @@ class Grid3D(Geometry):
281
348
  @cached_arg_value
282
349
  def side_arg_value(self, device) -> SideArg:
283
350
  args = self.SideArg()
351
+ self.fill_side_arg(args, device)
352
+ return args
284
353
 
354
+ def fill_side_arg(self, args: SideArg, device):
285
355
  axis_dims = wp.vec3i(
286
356
  self.res[1] * self.res[2],
287
357
  self.res[2] * self.res[0],
@@ -294,11 +364,13 @@ class Grid3D(Geometry):
294
364
  )
295
365
  args.cell_count = self.cell_count()
296
366
  args.cell_arg = self.cell_arg_value(device)
297
- return args
298
367
 
299
368
  def side_index_arg_value(self, device) -> SideIndexArg:
300
369
  return self.side_arg_value(device)
301
370
 
371
+ def fill_side_index_arg(self, args: SideIndexArg, device):
372
+ self.fill_side_arg(args, device)
373
+
302
374
  @wp.func
303
375
  def boundary_side_index(args: SideArg, boundary_side_index: int):
304
376
  """Boundary side to side index"""
@@ -450,3 +522,27 @@ class Grid3D(Geometry):
450
522
  @wp.func
451
523
  def side_to_cell_arg(side_arg: SideArg):
452
524
  return side_arg.cell_arg
525
+
526
+ @wp.func
527
+ def side_coordinates(args: SideArg, side_index: int, pos: wp.vec3):
528
+ cell_arg = args.cell_arg
529
+ side = Grid3D.get_side(args, side_index)
530
+
531
+ pos_loc = Grid3D._world_to_local(side.axis, wp.cw_div(pos - cell_arg.origin, cell_arg.cell_size)) - wp.vec3(
532
+ side.origin
533
+ )
534
+
535
+ coord0 = wp.where(side.origin[0] == 0, 1.0 - pos_loc[1], pos_loc[1])
536
+ return Coords(coord0, pos_loc[2], 0.0)
537
+
538
+ @wp.func
539
+ def side_closest_point(args: SideArg, side_index: int, pos: wp.vec3):
540
+ coord = Grid3D.side_coordinates(args, side_index, pos)
541
+
542
+ cell_arg = args.cell_arg
543
+ side = Grid3D.get_side(args, side_index)
544
+
545
+ loc_cell_size = Grid3D._world_to_local(side.axis, cell_arg.cell_size)
546
+ long_lat_sizes = wp.vec2(loc_cell_size[1], loc_cell_size[2])
547
+ dist, proj_coord = project_on_box_at_origin(wp.vec2(coord[0], coord[1]), long_lat_sizes)
548
+ return proj_coord, dist