warp-lang 1.2.1__py3-none-manylinux2014_aarch64.whl → 1.3.0__py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (194):
  1. warp/__init__.py +8 -6
  2. warp/autograd.py +823 -0
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +6 -2
  6. warp/builtins.py +1410 -886
  7. warp/codegen.py +503 -166
  8. warp/config.py +48 -18
  9. warp/context.py +401 -199
  10. warp/dlpack.py +8 -0
  11. warp/examples/assets/bunny.usd +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
  13. warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
  14. warp/examples/benchmarks/benchmark_launches.py +1 -1
  15. warp/examples/core/example_cupy.py +78 -0
  16. warp/examples/fem/example_apic_fluid.py +17 -36
  17. warp/examples/fem/example_burgers.py +9 -18
  18. warp/examples/fem/example_convection_diffusion.py +7 -17
  19. warp/examples/fem/example_convection_diffusion_dg.py +27 -47
  20. warp/examples/fem/example_deformed_geometry.py +11 -22
  21. warp/examples/fem/example_diffusion.py +7 -18
  22. warp/examples/fem/example_diffusion_3d.py +24 -28
  23. warp/examples/fem/example_diffusion_mgpu.py +7 -14
  24. warp/examples/fem/example_magnetostatics.py +190 -0
  25. warp/examples/fem/example_mixed_elasticity.py +111 -80
  26. warp/examples/fem/example_navier_stokes.py +30 -34
  27. warp/examples/fem/example_nonconforming_contact.py +290 -0
  28. warp/examples/fem/example_stokes.py +17 -32
  29. warp/examples/fem/example_stokes_transfer.py +12 -21
  30. warp/examples/fem/example_streamlines.py +350 -0
  31. warp/examples/fem/utils.py +936 -0
  32. warp/fabric.py +5 -2
  33. warp/fem/__init__.py +13 -3
  34. warp/fem/cache.py +161 -11
  35. warp/fem/dirichlet.py +37 -28
  36. warp/fem/domain.py +105 -14
  37. warp/fem/field/__init__.py +14 -3
  38. warp/fem/field/field.py +454 -11
  39. warp/fem/field/nodal_field.py +33 -18
  40. warp/fem/geometry/deformed_geometry.py +50 -15
  41. warp/fem/geometry/hexmesh.py +12 -24
  42. warp/fem/geometry/nanogrid.py +106 -31
  43. warp/fem/geometry/quadmesh_2d.py +6 -11
  44. warp/fem/geometry/tetmesh.py +103 -61
  45. warp/fem/geometry/trimesh_2d.py +98 -47
  46. warp/fem/integrate.py +231 -186
  47. warp/fem/operator.py +14 -9
  48. warp/fem/quadrature/pic_quadrature.py +35 -9
  49. warp/fem/quadrature/quadrature.py +119 -32
  50. warp/fem/space/basis_space.py +98 -22
  51. warp/fem/space/collocated_function_space.py +3 -1
  52. warp/fem/space/function_space.py +7 -2
  53. warp/fem/space/grid_2d_function_space.py +3 -3
  54. warp/fem/space/grid_3d_function_space.py +4 -4
  55. warp/fem/space/hexmesh_function_space.py +3 -2
  56. warp/fem/space/nanogrid_function_space.py +12 -14
  57. warp/fem/space/partition.py +45 -47
  58. warp/fem/space/restriction.py +19 -16
  59. warp/fem/space/shape/cube_shape_function.py +91 -3
  60. warp/fem/space/shape/shape_function.py +7 -0
  61. warp/fem/space/shape/square_shape_function.py +32 -0
  62. warp/fem/space/shape/tet_shape_function.py +11 -7
  63. warp/fem/space/shape/triangle_shape_function.py +10 -1
  64. warp/fem/space/topology.py +116 -42
  65. warp/fem/types.py +8 -1
  66. warp/fem/utils.py +301 -83
  67. warp/native/array.h +16 -0
  68. warp/native/builtin.h +0 -15
  69. warp/native/cuda_util.cpp +14 -6
  70. warp/native/exports.h +1348 -1308
  71. warp/native/quat.h +79 -0
  72. warp/native/rand.h +27 -4
  73. warp/native/sparse.cpp +83 -81
  74. warp/native/sparse.cu +381 -453
  75. warp/native/vec.h +64 -0
  76. warp/native/volume.cpp +40 -49
  77. warp/native/volume_builder.cu +2 -3
  78. warp/native/volume_builder.h +12 -17
  79. warp/native/warp.cu +3 -3
  80. warp/native/warp.h +69 -59
  81. warp/render/render_opengl.py +17 -9
  82. warp/sim/articulation.py +117 -17
  83. warp/sim/collide.py +35 -29
  84. warp/sim/model.py +123 -18
  85. warp/sim/render.py +3 -1
  86. warp/sparse.py +867 -203
  87. warp/stubs.py +312 -541
  88. warp/tape.py +29 -1
  89. warp/tests/disabled_kinematics.py +1 -1
  90. warp/tests/test_adam.py +1 -1
  91. warp/tests/test_arithmetic.py +1 -1
  92. warp/tests/test_array.py +58 -1
  93. warp/tests/test_array_reduce.py +1 -1
  94. warp/tests/test_async.py +1 -1
  95. warp/tests/test_atomic.py +1 -1
  96. warp/tests/test_bool.py +1 -1
  97. warp/tests/test_builtins_resolution.py +1 -1
  98. warp/tests/test_bvh.py +6 -1
  99. warp/tests/test_closest_point_edge_edge.py +1 -1
  100. warp/tests/test_codegen.py +66 -1
  101. warp/tests/test_compile_consts.py +1 -1
  102. warp/tests/test_conditional.py +1 -1
  103. warp/tests/test_copy.py +1 -1
  104. warp/tests/test_ctypes.py +1 -1
  105. warp/tests/test_dense.py +1 -1
  106. warp/tests/test_devices.py +1 -1
  107. warp/tests/test_dlpack.py +1 -1
  108. warp/tests/test_examples.py +33 -4
  109. warp/tests/test_fabricarray.py +5 -2
  110. warp/tests/test_fast_math.py +1 -1
  111. warp/tests/test_fem.py +213 -6
  112. warp/tests/test_fp16.py +1 -1
  113. warp/tests/test_func.py +1 -1
  114. warp/tests/test_future_annotations.py +90 -0
  115. warp/tests/test_generics.py +1 -1
  116. warp/tests/test_grad.py +1 -1
  117. warp/tests/test_grad_customs.py +1 -1
  118. warp/tests/test_grad_debug.py +247 -0
  119. warp/tests/test_hash_grid.py +6 -1
  120. warp/tests/test_implicit_init.py +354 -0
  121. warp/tests/test_import.py +1 -1
  122. warp/tests/test_indexedarray.py +1 -1
  123. warp/tests/test_intersect.py +1 -1
  124. warp/tests/test_jax.py +1 -1
  125. warp/tests/test_large.py +1 -1
  126. warp/tests/test_launch.py +1 -1
  127. warp/tests/test_lerp.py +1 -1
  128. warp/tests/test_linear_solvers.py +1 -1
  129. warp/tests/test_lvalue.py +1 -1
  130. warp/tests/test_marching_cubes.py +5 -2
  131. warp/tests/test_mat.py +34 -35
  132. warp/tests/test_mat_lite.py +2 -1
  133. warp/tests/test_mat_scalar_ops.py +1 -1
  134. warp/tests/test_math.py +1 -1
  135. warp/tests/test_matmul.py +20 -16
  136. warp/tests/test_matmul_lite.py +1 -1
  137. warp/tests/test_mempool.py +1 -1
  138. warp/tests/test_mesh.py +5 -2
  139. warp/tests/test_mesh_query_aabb.py +1 -1
  140. warp/tests/test_mesh_query_point.py +1 -1
  141. warp/tests/test_mesh_query_ray.py +1 -1
  142. warp/tests/test_mlp.py +1 -1
  143. warp/tests/test_model.py +1 -1
  144. warp/tests/test_module_hashing.py +77 -1
  145. warp/tests/test_modules_lite.py +1 -1
  146. warp/tests/test_multigpu.py +1 -1
  147. warp/tests/test_noise.py +1 -1
  148. warp/tests/test_operators.py +1 -1
  149. warp/tests/test_options.py +1 -1
  150. warp/tests/test_overwrite.py +542 -0
  151. warp/tests/test_peer.py +1 -1
  152. warp/tests/test_pinned.py +1 -1
  153. warp/tests/test_print.py +1 -1
  154. warp/tests/test_quat.py +15 -1
  155. warp/tests/test_rand.py +1 -1
  156. warp/tests/test_reload.py +1 -1
  157. warp/tests/test_rounding.py +1 -1
  158. warp/tests/test_runlength_encode.py +1 -1
  159. warp/tests/test_scalar_ops.py +95 -0
  160. warp/tests/test_sim_grad.py +1 -1
  161. warp/tests/test_sim_kinematics.py +1 -1
  162. warp/tests/test_smoothstep.py +1 -1
  163. warp/tests/test_sparse.py +82 -15
  164. warp/tests/test_spatial.py +1 -1
  165. warp/tests/test_special_values.py +2 -11
  166. warp/tests/test_streams.py +11 -1
  167. warp/tests/test_struct.py +1 -1
  168. warp/tests/test_tape.py +1 -1
  169. warp/tests/test_torch.py +194 -1
  170. warp/tests/test_transient_module.py +1 -1
  171. warp/tests/test_types.py +1 -1
  172. warp/tests/test_utils.py +1 -1
  173. warp/tests/test_vec.py +15 -63
  174. warp/tests/test_vec_lite.py +2 -1
  175. warp/tests/test_vec_scalar_ops.py +122 -39
  176. warp/tests/test_verify_fp.py +1 -1
  177. warp/tests/test_volume.py +28 -2
  178. warp/tests/test_volume_write.py +1 -1
  179. warp/tests/unittest_serial.py +1 -1
  180. warp/tests/unittest_suites.py +9 -1
  181. warp/tests/walkthrough_debug.py +1 -1
  182. warp/thirdparty/unittest_parallel.py +2 -5
  183. warp/torch.py +103 -41
  184. warp/types.py +344 -227
  185. warp/utils.py +11 -2
  186. {warp_lang-1.2.1.dist-info → warp_lang-1.3.0.dist-info}/METADATA +99 -46
  187. warp_lang-1.3.0.dist-info/RECORD +368 -0
  188. warp/examples/fem/bsr_utils.py +0 -378
  189. warp/examples/fem/mesh_utils.py +0 -133
  190. warp/examples/fem/plot_utils.py +0 -292
  191. warp_lang-1.2.1.dist-info/RECORD +0 -359
  192. {warp_lang-1.2.1.dist-info → warp_lang-1.3.0.dist-info}/LICENSE.md +0 -0
  193. {warp_lang-1.2.1.dist-info → warp_lang-1.3.0.dist-info}/WHEEL +0 -0
  194. {warp_lang-1.2.1.dist-info → warp_lang-1.3.0.dist-info}/top_level.txt +0 -0
@@ -30,8 +30,8 @@ class TetmeshCellArg:
30
30
  vertex_tet_offsets: wp.array(dtype=int)
31
31
  vertex_tet_indices: wp.array(dtype=int)
32
32
 
33
- # for transforming reference gradient
34
- deformation_gradients: wp.array(dtype=wp.mat33f)
33
+ # for global cell lookup
34
+ tet_bvh: wp.uint64
35
35
 
36
36
 
37
37
  @wp.struct
@@ -42,6 +42,7 @@ class TetmeshSideArg:
42
42
 
43
43
 
44
44
  _mat32 = wp.mat(shape=(3, 2), dtype=float)
45
+ _NULL_BVH = wp.constant(wp.uint64(-1))
45
46
 
46
47
 
47
48
  class Tetmesh(Geometry):
@@ -50,7 +51,11 @@ class Tetmesh(Geometry):
50
51
  dimension = 3
51
52
 
52
53
  def __init__(
53
- self, tet_vertex_indices: wp.array, positions: wp.array, temporary_store: Optional[TemporaryStore] = None
54
+ self,
55
+ tet_vertex_indices: wp.array,
56
+ positions: wp.array,
57
+ build_bvh: bool = False,
58
+ temporary_store: Optional[TemporaryStore] = None,
54
59
  ):
55
60
  """
56
61
  Constructs a tetrahedral mesh.
@@ -59,6 +64,7 @@ class Tetmesh(Geometry):
59
64
  tet_vertex_indices: warp array of shape (num_tets, 4) containing vertex indices for each tet
60
65
  positions: warp array of shape (num_vertices, 3) containing 3d position for each vertex
61
66
  temporary_store: shared pool from which to allocate temporary arrays
67
+ build_bvh: Whether to also build the tet BVH, which is necessary for the global `fem.lookup` operator to function without initial guess
62
68
  """
63
69
 
64
70
  self.tet_vertex_indices = tet_vertex_indices
@@ -72,8 +78,37 @@ class Tetmesh(Geometry):
72
78
  self._edge_count = 0
73
79
  self._build_topology(temporary_store)
74
80
 
75
- self._deformation_gradients: wp.array = None
76
- self._compute_deformation_gradients()
81
+ self._tet_bvh: wp.Bvh = None
82
+ if build_bvh:
83
+ self._build_bvh()
84
+
85
+ def update_bvh(self, force_rebuild: bool = False):
86
+ """
87
+ Refits the BVH, or rebuilds it from scratch if `force_rebuild` is ``True``.
88
+ """
89
+
90
+ if self._tet_bvh is None or force_rebuild:
91
+ return self.build_bvh()
92
+
93
+ wp.launch(
94
+ Tetmesh._compute_tet_bounds,
95
+ self.tet_vertex_indices,
96
+ self.positions,
97
+ self._tet_bvh.lowers,
98
+ self._tet_bvh.uppers,
99
+ )
100
+ self._tet_bvh.refit()
101
+
102
+ def _build_bvh(self, temporary_store: Optional[TemporaryStore] = None):
103
+ lowers = wp.array(shape=self.cell_count(), dtype=wp.vec3, device=self.positions.device)
104
+ uppers = wp.array(shape=self.cell_count(), dtype=wp.vec3, device=self.positions.device)
105
+ wp.launch(
106
+ Tetmesh._compute_tet_bounds,
107
+ device=self.positions.device,
108
+ dim=self.cell_count(),
109
+ inputs=[self.tet_vertex_indices, self.positions, lowers, uppers],
110
+ )
111
+ self._tet_bvh = wp.Bvh(lowers, uppers)
77
112
 
78
113
  def cell_count(self):
79
114
  return self.tet_vertex_indices.shape[0]
@@ -129,7 +164,8 @@ class Tetmesh(Geometry):
129
164
  args.positions = self.positions.to(device)
130
165
  args.vertex_tet_offsets = self._vertex_tet_offsets.to(device)
131
166
  args.vertex_tet_indices = self._vertex_tet_indices.to(device)
132
- args.deformation_gradients = self._deformation_gradients.to(device)
167
+
168
+ args.tet_bvh = _NULL_BVH if self._tet_bvh is None else self._tet_bvh.id
133
169
 
134
170
  return args
135
171
 
@@ -146,11 +182,15 @@ class Tetmesh(Geometry):
146
182
 
147
183
  @wp.func
148
184
  def cell_deformation_gradient(args: CellArg, s: Sample):
149
- return args.deformation_gradients[s.element_index]
185
+ p0 = args.positions[args.tet_vertex_indices[s.element_index, 0]]
186
+ p1 = args.positions[args.tet_vertex_indices[s.element_index, 1]]
187
+ p2 = args.positions[args.tet_vertex_indices[s.element_index, 2]]
188
+ p3 = args.positions[args.tet_vertex_indices[s.element_index, 3]]
189
+ return wp.mat33(p1 - p0, p2 - p0, p3 - p0)
150
190
 
151
191
  @wp.func
152
192
  def cell_inverse_deformation_gradient(args: CellArg, s: Sample):
153
- return wp.inverse(args.deformation_gradients[s.element_index])
193
+ return wp.inverse(Tetmesh.cell_deformation_gradient(args, s))
154
194
 
155
195
  @wp.func
156
196
  def _project_on_tet(args: CellArg, pos: wp.vec3, tet_index: int):
@@ -165,29 +205,54 @@ class Tetmesh(Geometry):
165
205
  return dist, coords
166
206
 
167
207
  @wp.func
168
- def cell_lookup(args: CellArg, pos: wp.vec3, guess: Sample):
208
+ def _bvh_lookup(args: CellArg, pos: wp.vec3):
169
209
  closest_tet = int(NULL_ELEMENT_INDEX)
170
210
  closest_coords = Coords(OUTSIDE)
171
211
  closest_dist = float(1.0e8)
172
212
 
173
- for v in range(4):
174
- vtx = args.tet_vertex_indices[guess.element_index, v]
175
- tet_beg = args.vertex_tet_offsets[vtx]
176
- tet_end = args.vertex_tet_offsets[vtx + 1]
177
-
178
- for t in range(tet_beg, tet_end):
179
- tet = args.vertex_tet_indices[t]
213
+ if args.tet_bvh != _NULL_BVH:
214
+ query = wp.bvh_query_aabb(args.tet_bvh, pos, pos)
215
+ tet = int(0)
216
+ while wp.bvh_query_next(query, tet):
180
217
  dist, coords = Tetmesh._project_on_tet(args, pos, tet)
181
218
  if dist <= closest_dist:
182
219
  closest_dist = dist
183
220
  closest_tet = tet
184
221
  closest_coords = coords
185
222
 
223
+ return closest_dist, closest_tet, closest_coords
224
+
225
+ @wp.func
226
+ def cell_lookup(args: CellArg, pos: wp.vec3):
227
+ closest_dist, closest_tet, closest_coords = Tetmesh._bvh_lookup(args, pos)
228
+
229
+ return make_free_sample(closest_tet, closest_coords)
230
+
231
+ @wp.func
232
+ def cell_lookup(args: CellArg, pos: wp.vec3, guess: Sample):
233
+ closest_dist, closest_tet, closest_coords = Tetmesh._bvh_lookup(args, pos)
234
+ return make_free_sample(closest_tet, closest_coords)
235
+
236
+ if closest_tet == NULL_ELEMENT_INDEX:
237
+ # nothing found yet, bvh may not be available or outside mesh
238
+ for v in range(4):
239
+ vtx = args.tet_vertex_indices[guess.element_index, v]
240
+ tet_beg = args.vertex_tet_offsets[vtx]
241
+ tet_end = args.vertex_tet_offsets[vtx + 1]
242
+
243
+ for t in range(tet_beg, tet_end):
244
+ tet = args.vertex_tet_indices[t]
245
+ dist, coords = Tetmesh._project_on_tet(args, pos, tet)
246
+ if dist <= closest_dist:
247
+ closest_dist = dist
248
+ closest_tet = tet
249
+ closest_coords = coords
250
+
186
251
  return make_free_sample(closest_tet, closest_coords)
187
252
 
188
253
  @wp.func
189
254
  def cell_measure(args: CellArg, s: Sample):
190
- return wp.abs(wp.determinant(args.deformation_gradients[s.element_index])) / 6.0
255
+ return wp.abs(wp.determinant(Tetmesh.cell_deformation_gradient(args, s))) / 6.0
191
256
 
192
257
  @wp.func
193
258
  def cell_measure_ratio(args: CellArg, s: Sample):
@@ -247,12 +312,14 @@ class Tetmesh(Geometry):
247
312
  @wp.func
248
313
  def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample):
249
314
  cell_index = Tetmesh.side_inner_cell_index(args, s.element_index)
250
- return wp.inverse(args.cell_arg.deformation_gradients[cell_index])
315
+ s_cell = make_free_sample(cell_index, Coords())
316
+ return Tetmesh.cell_inverse_deformation_gradient(args.cell_arg, s_cell)
251
317
 
252
318
  @wp.func
253
319
  def side_outer_inverse_deformation_gradient(args: SideArg, s: Sample):
254
320
  cell_index = Tetmesh.side_outer_cell_index(args, s.element_index)
255
- return wp.inverse(args.cell_arg.deformation_gradients[cell_index])
321
+ s_cell = make_free_sample(cell_index, Coords())
322
+ return Tetmesh.cell_inverse_deformation_gradient(args.cell_arg, s_cell)
256
323
 
257
324
  @wp.func
258
325
  def side_measure(args: SideArg, s: Sample):
@@ -355,12 +422,12 @@ class Tetmesh(Geometry):
355
422
  return side_arg.cell_arg
356
423
 
357
424
  def _build_topology(self, temporary_store: TemporaryStore):
358
- from warp.fem.utils import compress_node_indices, masked_indices
425
+ from warp.fem.utils import compress_node_indices, host_read_at_index, masked_indices
359
426
  from warp.utils import array_scan
360
427
 
361
428
  device = self.tet_vertex_indices.device
362
429
 
363
- vertex_tet_offsets, vertex_tet_indices, _, __ = compress_node_indices(
430
+ vertex_tet_offsets, vertex_tet_indices = compress_node_indices(
364
431
  self.vertex_count(), self.tet_vertex_indices, temporary_store=temporary_store
365
432
  )
366
433
  self._vertex_tet_offsets = vertex_tet_offsets.detach()
@@ -406,16 +473,11 @@ class Tetmesh(Geometry):
406
473
  array_scan(in_array=vertex_start_face_count.array, out_array=vertex_unique_face_offsets.array, inclusive=False)
407
474
 
408
475
  # Get back edge count to host
409
- if device.is_cuda:
410
- face_count = borrow_temporary(temporary_store, shape=(1,), dtype=int, device="cpu", pinned=True)
411
- # Last vertex will not own any edge, so its count will be zero; just fetching last prefix count is ok
412
- wp.copy(
413
- dest=face_count.array, src=vertex_unique_face_offsets.array, src_offset=self.vertex_count() - 1, count=1
476
+ face_count = int(
477
+ host_read_at_index(
478
+ vertex_unique_face_offsets.array, self.vertex_count() - 1, temporary_store=temporary_store
414
479
  )
415
- wp.synchronize_stream(wp.get_stream(device))
416
- face_count = int(face_count.array.numpy()[0])
417
- else:
418
- face_count = int(vertex_unique_face_offsets.array.numpy()[self.vertex_count() - 1])
480
+ )
419
481
 
420
482
  self._face_vertex_indices = wp.empty(shape=(face_count,), dtype=wp.vec3i, device=device)
421
483
  self._face_tet_indices = wp.empty(shape=(face_count,), dtype=wp.vec2i, device=device)
@@ -457,6 +519,7 @@ class Tetmesh(Geometry):
457
519
  self._boundary_face_indices = boundary_face_indices.detach()
458
520
 
459
521
  def _compute_tet_edges(self, temporary_store: Optional[TemporaryStore] = None):
522
+ from warp.fem.utils import host_read_at_index
460
523
  from warp.utils import array_scan
461
524
 
462
525
  device = self.tet_vertex_indices.device
@@ -499,19 +562,11 @@ class Tetmesh(Geometry):
499
562
  array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_unique_edge_offsets.array, inclusive=False)
500
563
 
501
564
  # Get back edge count to host
502
- if device.is_cuda:
503
- edge_count = borrow_temporary(temporary_store, shape=(1,), dtype=int, device="cpu", pinned=True)
504
- # Last vertex will not own any edge, so its count will be zero; just fetching last prefix count is ok
505
- wp.copy(
506
- dest=edge_count.array,
507
- src=vertex_unique_edge_offsets.array,
508
- src_offset=self.vertex_count() - 1,
509
- count=1,
565
+ self._edge_count = int(
566
+ host_read_at_index(
567
+ vertex_unique_edge_offsets.array, self.vertex_count() - 1, temporary_store=temporary_store
510
568
  )
511
- wp.synchronize_stream(wp.get_stream(device))
512
- self._edge_count = int(edge_count.array.numpy()[0])
513
- else:
514
- self._edge_count = int(vertex_unique_edge_offsets.array.numpy()[self.vertex_count() - 1])
569
+ )
515
570
 
516
571
  self._tet_edge_indices = wp.empty(
517
572
  dtype=int, device=self.tet_vertex_indices.device, shape=(self.cell_count(), 6)
@@ -539,16 +594,6 @@ class Tetmesh(Geometry):
539
594
  vertex_unique_edge_count.release()
540
595
  vertex_edge_ends.release()
541
596
 
542
- def _compute_deformation_gradients(self):
543
- self._deformation_gradients = wp.empty(dtype=wp.mat33f, device=self.positions.device, shape=(self.cell_count()))
544
-
545
- wp.launch(
546
- kernel=Tetmesh._compute_deformation_gradients_kernel,
547
- dim=self._deformation_gradients.shape,
548
- device=self._deformation_gradients.device,
549
- inputs=[self.tet_vertex_indices, self.positions, self._deformation_gradients],
550
- )
551
-
552
597
  @wp.kernel
553
598
  def _count_starting_faces_kernel(
554
599
  tet_vertex_indices: wp.array2d(dtype=int), vertex_start_face_count: wp.array(dtype=int)
@@ -821,20 +866,17 @@ class Tetmesh(Geometry):
821
866
  tet_edge_indices[t][k + 3] = edge_id
822
867
 
823
868
  @wp.kernel
824
- def _compute_deformation_gradients_kernel(
869
+ def _compute_tet_bounds(
825
870
  tet_vertex_indices: wp.array2d(dtype=int),
826
- positions: wp.array(dtype=wp.vec3f),
827
- transforms: wp.array(dtype=wp.mat33f),
871
+ positions: wp.array(dtype=wp.vec3),
872
+ lowers: wp.array(dtype=wp.vec3),
873
+ uppers: wp.array(dtype=wp.vec3),
828
874
  ):
829
875
  t = wp.tid()
830
-
831
876
  p0 = positions[tet_vertex_indices[t, 0]]
832
877
  p1 = positions[tet_vertex_indices[t, 1]]
833
878
  p2 = positions[tet_vertex_indices[t, 2]]
834
879
  p3 = positions[tet_vertex_indices[t, 3]]
835
880
 
836
- e1 = p1 - p0
837
- e2 = p2 - p0
838
- e3 = p3 - p0
839
-
840
- transforms[t] = wp.mat33(e1, e2, e3)
881
+ lowers[t] = wp.min(wp.min(p0, p1), wp.min(p2, p3))
882
+ uppers[t] = wp.max(wp.max(p0, p1), wp.max(p2, p3))
@@ -30,7 +30,8 @@ class Trimesh2DCellArg:
30
30
  vertex_tri_offsets: wp.array(dtype=int)
31
31
  vertex_tri_indices: wp.array(dtype=int)
32
32
 
33
- deformation_gradients: wp.array(dtype=wp.mat22f)
33
+ # for global cell lookup
34
+ tri_bvh: wp.uint64
34
35
 
35
36
 
36
37
  @wp.struct
@@ -40,13 +41,20 @@ class Trimesh2DSideArg:
40
41
  edge_tri_indices: wp.array(dtype=wp.vec2i)
41
42
 
42
43
 
44
+ _NULL_BVH = wp.constant(wp.uint64(-1))
45
+
46
+
43
47
  class Trimesh2D(Geometry):
44
48
  """Two-dimensional triangular mesh geometry"""
45
49
 
46
50
  dimension = 2
47
51
 
48
52
  def __init__(
49
- self, tri_vertex_indices: wp.array, positions: wp.array, temporary_store: Optional[TemporaryStore] = None
53
+ self,
54
+ tri_vertex_indices: wp.array,
55
+ positions: wp.array,
56
+ build_bvh: bool = False,
57
+ temporary_store: Optional[TemporaryStore] = None,
50
58
  ):
51
59
  """
52
60
  Constructs a two-dimensional triangular mesh.
@@ -55,6 +63,7 @@ class Trimesh2D(Geometry):
55
63
  tri_vertex_indices: warp array of shape (num_tris, 3) containing vertex indices for each tri
56
64
  positions: warp array of shape (num_vertices, 2) containing 2d position for each vertex
57
65
  temporary_store: shared pool from which to allocate temporary arrays
66
+ build_bvh: Whether to also build the triangle BVH, which is necessary for the global `fem.lookup` operator to function without initial guess
58
67
  """
59
68
 
60
69
  self.tri_vertex_indices = tri_vertex_indices
@@ -66,8 +75,37 @@ class Trimesh2D(Geometry):
66
75
  self._vertex_tri_indices: wp.array = None
67
76
  self._build_topology(temporary_store)
68
77
 
69
- self._deformation_gradients: wp.array = None
70
- self._compute_deformation_gradients()
78
+ self._tri_bvh: wp.Bvh = None
79
+ if build_bvh:
80
+ self._build_bvh()
81
+
82
+ def update_bvh(self, force_rebuild: bool = False):
83
+ """
84
+ Refits the BVH, or rebuilds it from scratch if `force_rebuild` is ``True``.
85
+ """
86
+
87
+ if self._tri_bvh is None or force_rebuild:
88
+ return self.build_bvh()
89
+
90
+ wp.launch(
91
+ Trimesh2D._compute_tri_bounds,
92
+ self.tri_vertex_indices,
93
+ self.positions,
94
+ self._tri_bvh.lowers,
95
+ self._tri_bvh.uppers,
96
+ )
97
+ self._tri_bvh.refit()
98
+
99
+ def _build_bvh(self, temporary_store: Optional[TemporaryStore] = None):
100
+ lowers = wp.array(shape=self.cell_count(), dtype=wp.vec3, device=self.positions.device)
101
+ uppers = wp.array(shape=self.cell_count(), dtype=wp.vec3, device=self.positions.device)
102
+ wp.launch(
103
+ Trimesh2D._compute_tri_bounds,
104
+ device=self.positions.device,
105
+ dim=self.cell_count(),
106
+ inputs=[self.tri_vertex_indices, self.positions, lowers, uppers],
107
+ )
108
+ self._tri_bvh = wp.Bvh(lowers, uppers)
71
109
 
72
110
  def cell_count(self):
73
111
  return self.tri_vertex_indices.shape[0]
@@ -112,7 +150,7 @@ class Trimesh2D(Geometry):
112
150
  args.positions = self.positions.to(device)
113
151
  args.vertex_tri_offsets = self._vertex_tri_offsets.to(device)
114
152
  args.vertex_tri_indices = self._vertex_tri_indices.to(device)
115
- args.deformation_gradients = self._deformation_gradients.to(device)
153
+ args.tri_bvh = _NULL_BVH if self._tri_bvh is None else self._tri_bvh.id
116
154
 
117
155
  return args
118
156
 
@@ -127,11 +165,14 @@ class Trimesh2D(Geometry):
127
165
 
128
166
  @wp.func
129
167
  def cell_deformation_gradient(args: CellArg, s: Sample):
130
- return args.deformation_gradients[s.element_index]
168
+ p0 = args.positions[args.tri_vertex_indices[s.element_index, 0]]
169
+ p1 = args.positions[args.tri_vertex_indices[s.element_index, 1]]
170
+ p2 = args.positions[args.tri_vertex_indices[s.element_index, 2]]
171
+ return wp.mat22(p1 - p0, p2 - p0)
131
172
 
132
173
  @wp.func
133
174
  def cell_inverse_deformation_gradient(args: CellArg, s: Sample):
134
- return wp.inverse(args.deformation_gradients[s.element_index])
175
+ return wp.inverse(Trimesh2D.cell_deformation_gradient(args, s))
135
176
 
136
177
  @wp.func
137
178
  def _project_on_tri(args: CellArg, pos: wp.vec2, tri_index: int):
@@ -145,29 +186,54 @@ class Trimesh2D(Geometry):
145
186
  return dist, coords
146
187
 
147
188
  @wp.func
148
- def cell_lookup(args: CellArg, pos: wp.vec2, guess: Sample):
189
+ def _bvh_lookup(args: CellArg, pos: wp.vec2):
149
190
  closest_tri = int(NULL_ELEMENT_INDEX)
150
191
  closest_coords = Coords(OUTSIDE)
151
192
  closest_dist = float(1.0e8)
152
193
 
153
- for v in range(3):
154
- vtx = args.tri_vertex_indices[guess.element_index, v]
155
- tri_beg = args.vertex_tri_offsets[vtx]
156
- tri_end = args.vertex_tri_offsets[vtx + 1]
157
-
158
- for t in range(tri_beg, tri_end):
159
- tri = args.vertex_tri_indices[t]
194
+ if args.tri_bvh != _NULL_BVH:
195
+ pos3 = wp.vec3(pos[0], pos[1], 0.0)
196
+ query = wp.bvh_query_aabb(args.tri_bvh, pos3, pos3)
197
+ tri = int(0)
198
+ while wp.bvh_query_next(query, tri):
160
199
  dist, coords = Trimesh2D._project_on_tri(args, pos, tri)
161
200
  if dist <= closest_dist:
162
201
  closest_dist = dist
163
202
  closest_tri = tri
164
203
  closest_coords = coords
165
204
 
205
+ return closest_dist, closest_tri, closest_coords
206
+
207
+ @wp.func
208
+ def cell_lookup(args: CellArg, pos: wp.vec2):
209
+ closest_dist, closest_tri, closest_coords = Trimesh2D._bvh_lookup(args, pos)
210
+
211
+ return make_free_sample(closest_tri, closest_coords)
212
+
213
+ @wp.func
214
+ def cell_lookup(args: CellArg, pos: wp.vec2, guess: Sample):
215
+ closest_dist, closest_tri, closest_coords = Trimesh2D._bvh_lookup(args, pos)
216
+
217
+ if closest_tri == NULL_ELEMENT_INDEX:
218
+ # nothing found yet, bvh may not be available or outside mesh
219
+ for v in range(3):
220
+ vtx = args.tri_vertex_indices[guess.element_index, v]
221
+ tri_beg = args.vertex_tri_offsets[vtx]
222
+ tri_end = args.vertex_tri_offsets[vtx + 1]
223
+
224
+ for t in range(tri_beg, tri_end):
225
+ tri = args.vertex_tri_indices[t]
226
+ dist, coords = Trimesh2D._project_on_tri(args, pos, tri)
227
+ if dist <= closest_dist:
228
+ closest_dist = dist
229
+ closest_tri = tri
230
+ closest_coords = coords
231
+
166
232
  return make_free_sample(closest_tri, closest_coords)
167
233
 
168
234
  @wp.func
169
235
  def cell_measure(args: CellArg, s: Sample):
170
- return 0.5 * wp.abs(wp.determinant(args.deformation_gradients[s.element_index]))
236
+ return 0.5 * wp.abs(wp.determinant(Trimesh2D.cell_deformation_gradient(args, s)))
171
237
 
172
238
  @wp.func
173
239
  def cell_normal(args: CellArg, s: Sample):
@@ -214,12 +280,14 @@ class Trimesh2D(Geometry):
214
280
  @wp.func
215
281
  def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample):
216
282
  cell_index = Trimesh2D.side_inner_cell_index(args, s.element_index)
217
- return wp.inverse(args.cell_arg.deformation_gradients[cell_index])
283
+ s_cell = make_free_sample(cell_index, Coords())
284
+ return Trimesh2D.cell_inverse_deformation_gradient(args.cell_arg, s_cell)
218
285
 
219
286
  @wp.func
220
287
  def side_outer_inverse_deformation_gradient(args: SideArg, s: Sample):
221
288
  cell_index = Trimesh2D.side_outer_cell_index(args, s.element_index)
222
- return wp.inverse(args.cell_arg.deformation_gradients[cell_index])
289
+ s_cell = make_free_sample(cell_index, Coords())
290
+ return Trimesh2D.cell_inverse_deformation_gradient(args.cell_arg, s_cell)
223
291
 
224
292
  @wp.func
225
293
  def side_measure(args: SideArg, s: Sample):
@@ -321,12 +389,12 @@ class Trimesh2D(Geometry):
321
389
  return side_arg.cell_arg
322
390
 
323
391
  def _build_topology(self, temporary_store: TemporaryStore):
324
- from warp.fem.utils import compress_node_indices, masked_indices
392
+ from warp.fem.utils import compress_node_indices, host_read_at_index, masked_indices
325
393
  from warp.utils import array_scan
326
394
 
327
395
  device = self.tri_vertex_indices.device
328
396
 
329
- vertex_tri_offsets, vertex_tri_indices, _, __ = compress_node_indices(
397
+ vertex_tri_offsets, vertex_tri_indices = compress_node_indices(
330
398
  self.vertex_count(), self.tri_vertex_indices, temporary_store=temporary_store
331
399
  )
332
400
  self._vertex_tri_offsets = vertex_tri_offsets.detach()
@@ -370,16 +438,11 @@ class Trimesh2D(Geometry):
370
438
  array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_unique_edge_offsets.array, inclusive=False)
371
439
 
372
440
  # Get back edge count to host
373
- if device.is_cuda:
374
- edge_count = borrow_temporary(temporary_store, shape=(1,), dtype=int, device="cpu", pinned=True)
375
- # Last vertex will not own any edge, so its count will be zero; just fetching last prefix count is ok
376
- wp.copy(
377
- dest=edge_count.array, src=vertex_unique_edge_offsets.array, src_offset=self.vertex_count() - 1, count=1
441
+ edge_count = int(
442
+ host_read_at_index(
443
+ vertex_unique_edge_offsets.array, self.vertex_count() - 1, temporary_store=temporary_store
378
444
  )
379
- wp.synchronize_stream(wp.get_stream(device))
380
- edge_count = int(edge_count.array.numpy()[0])
381
- else:
382
- edge_count = int(vertex_unique_edge_offsets.array.numpy()[self.vertex_count() - 1])
445
+ )
383
446
 
384
447
  self._edge_vertex_indices = wp.empty(shape=(edge_count,), dtype=wp.vec2i, device=device)
385
448
  self._edge_tri_indices = wp.empty(shape=(edge_count,), dtype=wp.vec2i, device=device)
@@ -422,16 +485,6 @@ class Trimesh2D(Geometry):
422
485
 
423
486
  boundary_mask.release()
424
487
 
425
- def _compute_deformation_gradients(self):
426
- self._deformation_gradients = wp.empty(dtype=wp.mat22f, device=self.positions.device, shape=(self.cell_count()))
427
-
428
- wp.launch(
429
- kernel=Trimesh2D._compute_deformation_gradients_kernel,
430
- dim=self._deformation_gradients.shape,
431
- device=self._deformation_gradients.device,
432
- inputs=[self.tri_vertex_indices, self.positions, self._deformation_gradients],
433
- )
434
-
435
488
  @wp.kernel
436
489
  def _count_starting_edges_kernel(
437
490
  tri_vertex_indices: wp.array2d(dtype=int), vertex_start_edge_count: wp.array(dtype=int)
@@ -560,18 +613,16 @@ class Trimesh2D(Geometry):
560
613
  edge_vertex_indices[e] = wp.vec2i(edge_vidx[1], edge_vidx[0])
561
614
 
562
615
  @wp.kernel
563
- def _compute_deformation_gradients_kernel(
616
+ def _compute_tri_bounds(
564
617
  tri_vertex_indices: wp.array2d(dtype=int),
565
- positions: wp.array(dtype=wp.vec2f),
566
- transforms: wp.array(dtype=wp.mat22f),
618
+ positions: wp.array(dtype=wp.vec2),
619
+ lowers: wp.array(dtype=wp.vec3),
620
+ uppers: wp.array(dtype=wp.vec3),
567
621
  ):
568
622
  t = wp.tid()
569
-
570
623
  p0 = positions[tri_vertex_indices[t, 0]]
571
624
  p1 = positions[tri_vertex_indices[t, 1]]
572
625
  p2 = positions[tri_vertex_indices[t, 2]]
573
626
 
574
- e1 = p1 - p0
575
- e2 = p2 - p0
576
-
577
- transforms[t] = wp.mat22(e1, e2)
627
+ lowers[t] = wp.vec3(wp.min(wp.min(p0[0], p1[0]), p2[0]), wp.min(wp.min(p0[1], p1[1]), p2[1]), 0.0)
628
+ uppers[t] = wp.vec3(wp.max(wp.max(p0[0], p1[0]), p2[0]), wp.max(wp.max(p0[1], p1[1]), p2[1]), 0.0)