warp-lang 1.4.1__py3-none-macosx_10_13_universal2.whl → 1.5.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (164) hide show
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1920 -111
  8. warp/codegen.py +186 -62
  9. warp/config.py +2 -2
  10. warp/context.py +322 -73
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/core/example_dem.py +2 -1
  17. warp/examples/core/example_mesh_intersect.py +3 -3
  18. warp/examples/fem/example_adaptive_grid.py +37 -10
  19. warp/examples/fem/example_apic_fluid.py +3 -2
  20. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  21. warp/examples/fem/example_deformed_geometry.py +1 -1
  22. warp/examples/fem/example_diffusion_3d.py +47 -4
  23. warp/examples/fem/example_distortion_energy.py +220 -0
  24. warp/examples/fem/example_magnetostatics.py +127 -85
  25. warp/examples/fem/example_nonconforming_contact.py +5 -5
  26. warp/examples/fem/example_stokes.py +3 -1
  27. warp/examples/fem/example_streamlines.py +12 -19
  28. warp/examples/fem/utils.py +38 -15
  29. warp/examples/optim/example_walker.py +2 -2
  30. warp/examples/sim/example_cloth.py +2 -25
  31. warp/examples/sim/example_jacobian_ik.py +6 -2
  32. warp/examples/sim/example_quadruped.py +2 -1
  33. warp/examples/tile/example_tile_convolution.py +58 -0
  34. warp/examples/tile/example_tile_fft.py +47 -0
  35. warp/examples/tile/example_tile_filtering.py +105 -0
  36. warp/examples/tile/example_tile_matmul.py +79 -0
  37. warp/examples/tile/example_tile_mlp.py +375 -0
  38. warp/fem/__init__.py +8 -0
  39. warp/fem/cache.py +16 -12
  40. warp/fem/dirichlet.py +1 -1
  41. warp/fem/domain.py +44 -1
  42. warp/fem/field/__init__.py +1 -2
  43. warp/fem/field/field.py +31 -19
  44. warp/fem/field/nodal_field.py +101 -49
  45. warp/fem/field/virtual.py +794 -0
  46. warp/fem/geometry/__init__.py +2 -2
  47. warp/fem/geometry/deformed_geometry.py +3 -105
  48. warp/fem/geometry/element.py +13 -0
  49. warp/fem/geometry/geometry.py +165 -5
  50. warp/fem/geometry/grid_2d.py +3 -6
  51. warp/fem/geometry/grid_3d.py +31 -28
  52. warp/fem/geometry/hexmesh.py +3 -46
  53. warp/fem/geometry/nanogrid.py +3 -2
  54. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  55. warp/fem/geometry/tetmesh.py +2 -43
  56. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  57. warp/fem/integrate.py +683 -261
  58. warp/fem/linalg.py +404 -0
  59. warp/fem/operator.py +101 -18
  60. warp/fem/polynomial.py +5 -5
  61. warp/fem/quadrature/quadrature.py +45 -21
  62. warp/fem/space/__init__.py +45 -11
  63. warp/fem/space/basis_function_space.py +451 -0
  64. warp/fem/space/basis_space.py +58 -11
  65. warp/fem/space/function_space.py +146 -5
  66. warp/fem/space/grid_2d_function_space.py +80 -66
  67. warp/fem/space/grid_3d_function_space.py +113 -68
  68. warp/fem/space/hexmesh_function_space.py +96 -108
  69. warp/fem/space/nanogrid_function_space.py +62 -110
  70. warp/fem/space/quadmesh_function_space.py +208 -0
  71. warp/fem/space/shape/__init__.py +45 -7
  72. warp/fem/space/shape/cube_shape_function.py +328 -54
  73. warp/fem/space/shape/shape_function.py +10 -1
  74. warp/fem/space/shape/square_shape_function.py +328 -60
  75. warp/fem/space/shape/tet_shape_function.py +269 -19
  76. warp/fem/space/shape/triangle_shape_function.py +238 -19
  77. warp/fem/space/tetmesh_function_space.py +69 -37
  78. warp/fem/space/topology.py +38 -0
  79. warp/fem/space/trimesh_function_space.py +179 -0
  80. warp/fem/utils.py +6 -331
  81. warp/jax_experimental.py +3 -1
  82. warp/native/array.h +55 -40
  83. warp/native/builtin.h +124 -43
  84. warp/native/bvh.h +4 -0
  85. warp/native/coloring.cpp +600 -0
  86. warp/native/cuda_util.cpp +14 -0
  87. warp/native/cuda_util.h +2 -1
  88. warp/native/fabric.h +8 -0
  89. warp/native/hashgrid.h +4 -0
  90. warp/native/marching.cu +8 -0
  91. warp/native/mat.h +14 -3
  92. warp/native/mathdx.cpp +59 -0
  93. warp/native/mesh.h +4 -0
  94. warp/native/range.h +13 -1
  95. warp/native/reduce.cpp +9 -1
  96. warp/native/reduce.cu +7 -0
  97. warp/native/runlength_encode.cpp +9 -1
  98. warp/native/runlength_encode.cu +7 -1
  99. warp/native/scan.cpp +8 -0
  100. warp/native/scan.cu +8 -0
  101. warp/native/scan.h +8 -1
  102. warp/native/sparse.cpp +8 -0
  103. warp/native/sparse.cu +8 -0
  104. warp/native/temp_buffer.h +7 -0
  105. warp/native/tile.h +1857 -0
  106. warp/native/tile_gemm.h +341 -0
  107. warp/native/tile_reduce.h +210 -0
  108. warp/native/volume_builder.cu +8 -0
  109. warp/native/volume_builder.h +8 -0
  110. warp/native/warp.cpp +10 -2
  111. warp/native/warp.cu +369 -15
  112. warp/native/warp.h +12 -2
  113. warp/optim/adam.py +39 -4
  114. warp/paddle.py +29 -12
  115. warp/render/render_opengl.py +137 -65
  116. warp/sim/graph_coloring.py +292 -0
  117. warp/sim/integrator_euler.py +4 -2
  118. warp/sim/integrator_featherstone.py +115 -44
  119. warp/sim/integrator_vbd.py +6 -0
  120. warp/sim/model.py +90 -17
  121. warp/stubs.py +651 -85
  122. warp/tape.py +12 -7
  123. warp/tests/assets/pixel.npy +0 -0
  124. warp/tests/aux_test_instancing_gc.py +18 -0
  125. warp/tests/test_array.py +207 -48
  126. warp/tests/test_closest_point_edge_edge.py +8 -8
  127. warp/tests/test_codegen.py +120 -1
  128. warp/tests/test_codegen_instancing.py +30 -0
  129. warp/tests/test_collision.py +110 -0
  130. warp/tests/test_coloring.py +241 -0
  131. warp/tests/test_context.py +34 -0
  132. warp/tests/test_examples.py +18 -4
  133. warp/tests/test_fabricarray.py +33 -0
  134. warp/tests/test_fem.py +453 -113
  135. warp/tests/test_func.py +48 -1
  136. warp/tests/test_generics.py +52 -0
  137. warp/tests/test_iter.py +68 -0
  138. warp/tests/test_mat_scalar_ops.py +1 -1
  139. warp/tests/test_mesh_query_point.py +5 -4
  140. warp/tests/test_module_hashing.py +23 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +191 -1
  143. warp/tests/test_spatial.py +1 -1
  144. warp/tests/test_tile.py +700 -0
  145. warp/tests/test_tile_mathdx.py +144 -0
  146. warp/tests/test_tile_mlp.py +383 -0
  147. warp/tests/test_tile_reduce.py +374 -0
  148. warp/tests/test_tile_shared_memory.py +190 -0
  149. warp/tests/test_vbd.py +12 -20
  150. warp/tests/test_volume.py +43 -0
  151. warp/tests/unittest_suites.py +23 -2
  152. warp/tests/unittest_utils.py +4 -0
  153. warp/types.py +339 -73
  154. warp/utils.py +22 -1
  155. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
  156. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/RECORD +159 -132
  157. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
  158. warp/fem/field/test.py +0 -180
  159. warp/fem/field/trial.py +0 -183
  160. warp/fem/space/collocated_function_space.py +0 -102
  161. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  162. warp/fem/space/trimesh_2d_function_space.py +0 -153
  163. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
  164. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- from typing import Optional
1
+ from typing import Any, Optional
2
2
 
3
3
  import warp as wp
4
4
  from warp.fem.cache import (
@@ -22,9 +22,8 @@ from .geometry import Geometry
22
22
 
23
23
 
24
24
  @wp.struct
25
- class Trimesh2DCellArg:
25
+ class TrimeshCellArg:
26
26
  tri_vertex_indices: wp.array2d(dtype=int)
27
- positions: wp.array(dtype=wp.vec2)
28
27
 
29
28
  # for neighbor cell lookup
30
29
  vertex_tri_offsets: wp.array(dtype=int)
@@ -35,19 +34,27 @@ class Trimesh2DCellArg:
35
34
 
36
35
 
37
36
  @wp.struct
38
- class Trimesh2DSideArg:
39
- cell_arg: Trimesh2DCellArg
37
+ class TrimeshSideArg:
38
+ cell_arg: TrimeshCellArg
40
39
  edge_vertex_indices: wp.array(dtype=wp.vec2i)
41
40
  edge_tri_indices: wp.array(dtype=wp.vec2i)
42
41
 
43
42
 
44
- _NULL_BVH = wp.constant(wp.uint64(-1))
43
+ _NULL_BVH = wp.constant(wp.uint64(0))
45
44
 
46
45
 
47
- class Trimesh2D(Geometry):
48
- """Two-dimensional triangular mesh geometry"""
46
+ @wp.func
47
+ def _bvh_vec(v: wp.vec3):
48
+ return v
49
49
 
50
- dimension = 2
50
+
51
+ @wp.func
52
+ def _bvh_vec(v: wp.vec2):
53
+ return wp.vec3(v[0], v[1], 0.0)
54
+
55
+
56
+ class Trimesh(Geometry):
57
+ """Triangular mesh geometry"""
51
58
 
52
59
  def __init__(
53
60
  self,
@@ -57,11 +64,11 @@ class Trimesh2D(Geometry):
57
64
  temporary_store: Optional[TemporaryStore] = None,
58
65
  ):
59
66
  """
60
- Constructs a two-dimensional triangular mesh.
67
+ Constructs a D-dimensional triangular mesh.
61
68
 
62
69
  Args:
63
70
  tri_vertex_indices: warp array of shape (num_tris, 3) containing vertex indices for each tri
64
- positions: warp array of shape (num_vertices, 2) containing 2d position for each vertex
71
+ positions: warp array of shape (num_vertices, D) containing the position of each vertex
65
72
  temporary_store: shared pool from which to allocate temporary arrays
66
73
  build_bvh: Whether to also build the triangle BVH, which is necessary for the global `fem.lookup` operator to function without initial guess
67
74
  """
@@ -79,6 +86,16 @@ class Trimesh2D(Geometry):
79
86
  if build_bvh:
80
87
  self._build_bvh()
81
88
 
89
+ # Flip edges so that normals point away from inner cell
90
+ wp.launch(
91
+ kernel=self._orient_edges,
92
+ device=positions.device,
93
+ dim=self.side_count(),
94
+ inputs=[self._edge_vertex_indices, self._edge_tri_indices, self.tri_vertex_indices, self.positions],
95
+ )
96
+
97
+ self._make_default_dependent_implementations()
98
+
82
99
  def update_bvh(self, force_rebuild: bool = False):
83
100
  """
84
101
  Refits the BVH, or rebuilds it from scratch if `force_rebuild` is ``True``.
@@ -88,7 +105,7 @@ class Trimesh2D(Geometry):
88
105
  return self.build_bvh()
89
106
 
90
107
  wp.launch(
91
- Trimesh2D._compute_tri_bounds,
108
+ Trimesh._compute_tri_bounds,
92
109
  self.tri_vertex_indices,
93
110
  self.positions,
94
111
  self._tri_bvh.lowers,
@@ -100,7 +117,7 @@ class Trimesh2D(Geometry):
100
117
  lowers = wp.array(shape=self.cell_count(), dtype=wp.vec3, device=self.positions.device)
101
118
  uppers = wp.array(shape=self.cell_count(), dtype=wp.vec3, device=self.positions.device)
102
119
  wp.launch(
103
- Trimesh2D._compute_tri_bounds,
120
+ Trimesh._compute_tri_bounds,
104
121
  device=self.positions.device,
105
122
  dim=self.cell_count(),
106
123
  inputs=[self.tri_vertex_indices, self.positions, lowers, uppers],
@@ -133,70 +150,76 @@ class Trimesh2D(Geometry):
133
150
  def edge_vertex_indices(self) -> wp.array:
134
151
  return self._edge_vertex_indices
135
152
 
136
- CellArg = Trimesh2DCellArg
137
- SideArg = Trimesh2DSideArg
138
-
139
153
  @wp.struct
140
154
  class SideIndexArg:
141
155
  boundary_edge_indices: wp.array(dtype=int)
142
156
 
143
- # Geometry device interface
144
-
145
157
  @cached_arg_value
146
- def cell_arg_value(self, device) -> CellArg:
147
- args = self.CellArg()
158
+ def _cell_topo_arg_value(self, device):
159
+ args = TrimeshCellArg()
148
160
 
149
161
  args.tri_vertex_indices = self.tri_vertex_indices.to(device)
150
- args.positions = self.positions.to(device)
151
162
  args.vertex_tri_offsets = self._vertex_tri_offsets.to(device)
152
163
  args.vertex_tri_indices = self._vertex_tri_indices.to(device)
153
- args.tri_bvh = _NULL_BVH if self._tri_bvh is None else self._tri_bvh.id
154
164
 
155
165
  return args
156
166
 
157
- @wp.func
158
- def cell_position(args: CellArg, s: Sample):
159
- tri_idx = args.tri_vertex_indices[s.element_index]
160
- return (
161
- s.element_coords[0] * args.positions[tri_idx[0]]
162
- + s.element_coords[1] * args.positions[tri_idx[1]]
163
- + s.element_coords[2] * args.positions[tri_idx[2]]
164
- )
167
+ @cached_arg_value
168
+ def _side_topo_arg_value(self, device):
169
+ args = TrimeshSideArg()
165
170
 
166
- @wp.func
167
- def cell_deformation_gradient(args: CellArg, s: Sample):
168
- p0 = args.positions[args.tri_vertex_indices[s.element_index, 0]]
169
- p1 = args.positions[args.tri_vertex_indices[s.element_index, 1]]
170
- p2 = args.positions[args.tri_vertex_indices[s.element_index, 2]]
171
- return wp.mat22(p1 - p0, p2 - p0)
171
+ args.cell_arg = self._cell_topo_arg_value(device)
172
+ args.edge_vertex_indices = self._edge_vertex_indices.to(device)
173
+ args.edge_tri_indices = self._edge_tri_indices.to(device)
172
174
 
173
- @wp.func
174
- def cell_inverse_deformation_gradient(args: CellArg, s: Sample):
175
- return wp.inverse(Trimesh2D.cell_deformation_gradient(args, s))
175
+ return args
176
+
177
+ def _bvh_id(self, device):
178
+ if self._tri_bvh is None or self._tri_bvh.device != device:
179
+ return _NULL_BVH
180
+ return self._tri_bvh.id
181
+
182
+ def cell_arg_value(self, device):
183
+ args = self.CellArg()
184
+
185
+ args.topology = self._cell_topo_arg_value(device)
186
+ args.positions = self.positions.to(device)
187
+ args.topology.tri_bvh = self._bvh_id(device)
188
+
189
+ return args
190
+
191
+ def side_arg_value(self, device):
192
+ args = self.SideArg()
193
+
194
+ args.topology = self._side_topo_arg_value(device)
195
+ args.positions = self.positions.to(device)
196
+ args.topology.cell_arg.tri_bvh = self._bvh_id(device)
197
+
198
+ return args
176
199
 
177
200
  @wp.func
178
- def _project_on_tri(args: CellArg, pos: wp.vec2, tri_index: int):
179
- p0 = args.positions[args.tri_vertex_indices[tri_index, 0]]
201
+ def _project_on_tri(args: TrimeshCellArg, positions: wp.array(dtype=Any), pos: Any, tri_index: int):
202
+ p0 = positions[args.tri_vertex_indices[tri_index, 0]]
180
203
 
181
204
  q = pos - p0
182
- e1 = args.positions[args.tri_vertex_indices[tri_index, 1]] - p0
183
- e2 = args.positions[args.tri_vertex_indices[tri_index, 2]] - p0
205
+ e1 = positions[args.tri_vertex_indices[tri_index, 1]] - p0
206
+ e2 = positions[args.tri_vertex_indices[tri_index, 2]] - p0
184
207
 
185
208
  dist, coords = project_on_tri_at_origin(q, e1, e2)
186
209
  return dist, coords
187
210
 
188
211
  @wp.func
189
- def _bvh_lookup(args: CellArg, pos: wp.vec2):
212
+ def _bvh_lookup(args: TrimeshCellArg, positions: wp.array(dtype=Any), pos: Any):
190
213
  closest_tri = int(NULL_ELEMENT_INDEX)
191
214
  closest_coords = Coords(OUTSIDE)
192
215
  closest_dist = float(1.0e8)
193
216
 
194
217
  if args.tri_bvh != _NULL_BVH:
195
- pos3 = wp.vec3(pos[0], pos[1], 0.0)
196
- query = wp.bvh_query_aabb(args.tri_bvh, pos3, pos3)
218
+ bvh_pos = _bvh_vec(pos)
219
+ query = wp.bvh_query_aabb(args.tri_bvh, bvh_pos, bvh_pos)
197
220
  tri = int(0)
198
221
  while wp.bvh_query_next(query, tri):
199
- dist, coords = Trimesh2D._project_on_tri(args, pos, tri)
222
+ dist, coords = Trimesh._project_on_tri(args, positions, pos, tri)
200
223
  if dist <= closest_dist:
201
224
  closest_dist = dist
202
225
  closest_tri = tri
@@ -205,39 +228,23 @@ class Trimesh2D(Geometry):
205
228
  return closest_dist, closest_tri, closest_coords
206
229
 
207
230
  @wp.func
208
- def cell_lookup(args: CellArg, pos: wp.vec2):
209
- closest_dist, closest_tri, closest_coords = Trimesh2D._bvh_lookup(args, pos)
210
-
211
- return make_free_sample(closest_tri, closest_coords)
212
-
213
- @wp.func
214
- def cell_lookup(args: CellArg, pos: wp.vec2, guess: Sample):
215
- closest_dist, closest_tri, closest_coords = Trimesh2D._bvh_lookup(args, pos)
216
-
217
- if closest_tri == NULL_ELEMENT_INDEX:
218
- # nothing found yet, bvh may not be available or outside mesh
219
- for v in range(3):
220
- vtx = args.tri_vertex_indices[guess.element_index, v]
221
- tri_beg = args.vertex_tri_offsets[vtx]
222
- tri_end = args.vertex_tri_offsets[vtx + 1]
223
-
224
- for t in range(tri_beg, tri_end):
225
- tri = args.vertex_tri_indices[t]
226
- dist, coords = Trimesh2D._project_on_tri(args, pos, tri)
227
- if dist <= closest_dist:
228
- closest_dist = dist
229
- closest_tri = tri
230
- closest_coords = coords
231
+ def _cell_neighbor_lookup(args: TrimeshCellArg, positions: wp.array(dtype=Any), pos: Any, cell_index: int):
232
+ closest_dist = float(1.0e8)
231
233
 
232
- return make_free_sample(closest_tri, closest_coords)
234
+ for v in range(3):
235
+ vtx = args.tri_vertex_indices[cell_index, v]
236
+ tri_beg = args.vertex_tri_offsets[vtx]
237
+ tri_end = args.vertex_tri_offsets[vtx + 1]
233
238
 
234
- @wp.func
235
- def cell_measure(args: CellArg, s: Sample):
236
- return 0.5 * wp.abs(wp.determinant(Trimesh2D.cell_deformation_gradient(args, s)))
239
+ for t in range(tri_beg, tri_end):
240
+ tri = args.vertex_tri_indices[t]
241
+ dist, coords = Trimesh._project_on_tri(args, positions, pos, tri)
242
+ if dist <= closest_dist:
243
+ closest_dist = dist
244
+ closest_tri = tri
245
+ closest_coords = coords
237
246
 
238
- @wp.func
239
- def cell_normal(args: CellArg, s: Sample):
240
- return wp.vec2(0.0)
247
+ return closest_dist, closest_tri, closest_coords
241
248
 
242
249
  @cached_arg_value
243
250
  def side_index_arg_value(self, device) -> SideIndexArg:
@@ -253,77 +260,10 @@ class Trimesh2D(Geometry):
253
260
 
254
261
  return args.boundary_edge_indices[boundary_side_index]
255
262
 
256
- @cached_arg_value
257
- def side_arg_value(self, device) -> CellArg:
258
- args = self.SideArg()
259
-
260
- args.cell_arg = self.cell_arg_value(device)
261
- args.edge_vertex_indices = self._edge_vertex_indices.to(device)
262
- args.edge_tri_indices = self._edge_tri_indices.to(device)
263
-
264
- return args
265
-
266
- @wp.func
267
- def side_position(args: SideArg, s: Sample):
268
- edge_idx = args.edge_vertex_indices[s.element_index]
269
- return (1.0 - s.element_coords[0]) * args.cell_arg.positions[edge_idx[0]] + s.element_coords[
270
- 0
271
- ] * args.cell_arg.positions[edge_idx[1]]
272
-
273
263
  @wp.func
274
- def side_deformation_gradient(args: SideArg, s: Sample):
275
- edge_idx = args.edge_vertex_indices[s.element_index]
276
- v0 = args.cell_arg.positions[edge_idx[0]]
277
- v1 = args.cell_arg.positions[edge_idx[1]]
278
- return v1 - v0
279
-
280
- @wp.func
281
- def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample):
282
- cell_index = Trimesh2D.side_inner_cell_index(args, s.element_index)
283
- s_cell = make_free_sample(cell_index, Coords())
284
- return Trimesh2D.cell_inverse_deformation_gradient(args.cell_arg, s_cell)
285
-
286
- @wp.func
287
- def side_outer_inverse_deformation_gradient(args: SideArg, s: Sample):
288
- cell_index = Trimesh2D.side_outer_cell_index(args, s.element_index)
289
- s_cell = make_free_sample(cell_index, Coords())
290
- return Trimesh2D.cell_inverse_deformation_gradient(args.cell_arg, s_cell)
291
-
292
- @wp.func
293
- def side_measure(args: SideArg, s: Sample):
294
- edge_idx = args.edge_vertex_indices[s.element_index]
295
- v0 = args.cell_arg.positions[edge_idx[0]]
296
- v1 = args.cell_arg.positions[edge_idx[1]]
297
- return wp.length(v1 - v0)
298
-
299
- @wp.func
300
- def side_measure_ratio(args: SideArg, s: Sample):
301
- inner = Trimesh2D.side_inner_cell_index(args, s.element_index)
302
- outer = Trimesh2D.side_outer_cell_index(args, s.element_index)
303
- return Trimesh2D.side_measure(args, s) / wp.min(
304
- Trimesh2D.cell_measure(args.cell_arg, make_free_sample(inner, Coords())),
305
- Trimesh2D.cell_measure(args.cell_arg, make_free_sample(outer, Coords())),
306
- )
307
-
308
- @wp.func
309
- def side_normal(args: SideArg, s: Sample):
310
- edge_idx = args.edge_vertex_indices[s.element_index]
311
- v0 = args.cell_arg.positions[edge_idx[0]]
312
- v1 = args.cell_arg.positions[edge_idx[1]]
313
- e = v1 - v0
314
-
315
- return wp.normalize(wp.vec2(-e[1], e[0]))
316
-
317
- @wp.func
318
- def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
319
- return arg.edge_tri_indices[side_index][0]
320
-
321
- @wp.func
322
- def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
323
- return arg.edge_tri_indices[side_index][1]
324
-
325
- @wp.func
326
- def edge_to_tri_coords(args: SideArg, side_index: ElementIndex, tri_index: ElementIndex, side_coords: Coords):
264
+ def _edge_to_tri_coords(
265
+ args: TrimeshSideArg, side_index: ElementIndex, tri_index: ElementIndex, side_coords: Coords
266
+ ):
327
267
  edge_vidx = args.edge_vertex_indices[side_index]
328
268
  tri_vidx = args.cell_arg.tri_vertex_indices[tri_index]
329
269
 
@@ -351,18 +291,8 @@ class Trimesh2D(Geometry):
351
291
  return Coords(cx, cy, cz)
352
292
 
353
293
  @wp.func
354
- def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
355
- inner_cell_index = Trimesh2D.side_inner_cell_index(args, side_index)
356
- return Trimesh2D.edge_to_tri_coords(args, side_index, inner_cell_index, side_coords)
357
-
358
- @wp.func
359
- def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
360
- outer_cell_index = Trimesh2D.side_outer_cell_index(args, side_index)
361
- return Trimesh2D.edge_to_tri_coords(args, side_index, outer_cell_index, side_coords)
362
-
363
- @wp.func
364
- def side_from_cell_coords(
365
- args: SideArg,
294
+ def _tri_to_edge_coords(
295
+ args: TrimeshSideArg,
366
296
  side_index: ElementIndex,
367
297
  tri_index: ElementIndex,
368
298
  tri_coords: Coords,
@@ -384,10 +314,6 @@ class Trimesh2D(Geometry):
384
314
  tri_coords[start] + tri_coords[end] > 0.999, Coords(OUTSIDE), Coords(tri_coords[end], 0.0, 0.0)
385
315
  )
386
316
 
387
- @wp.func
388
- def side_to_cell_arg(side_arg: SideArg):
389
- return side_arg.cell_arg
390
-
391
317
  def _build_topology(self, temporary_store: TemporaryStore):
392
318
  from warp.fem.utils import compress_node_indices, host_read_at_index, masked_indices
393
319
  from warp.utils import array_scan
@@ -409,7 +335,7 @@ class Trimesh2D(Geometry):
409
335
 
410
336
  # Count face edges starting at each vertex
411
337
  wp.launch(
412
- kernel=Trimesh2D._count_starting_edges_kernel,
338
+ kernel=Trimesh._count_starting_edges_kernel,
413
339
  device=device,
414
340
  dim=self.cell_count(),
415
341
  inputs=[self.tri_vertex_indices, vertex_start_edge_count.array],
@@ -420,7 +346,7 @@ class Trimesh2D(Geometry):
420
346
  # Count number of unique edges (deduplicate across faces)
421
347
  vertex_unique_edge_count = vertex_start_edge_count
422
348
  wp.launch(
423
- kernel=Trimesh2D._count_unique_starting_edges_kernel,
349
+ kernel=Trimesh._count_unique_starting_edges_kernel,
424
350
  device=device,
425
351
  dim=self.vertex_count(),
426
352
  inputs=[
@@ -451,7 +377,7 @@ class Trimesh2D(Geometry):
451
377
 
452
378
  # Compress edge data
453
379
  wp.launch(
454
- kernel=Trimesh2D._compress_edges_kernel,
380
+ kernel=Trimesh._compress_edges_kernel,
455
381
  device=device,
456
382
  dim=self.vertex_count(),
457
383
  inputs=[
@@ -472,14 +398,6 @@ class Trimesh2D(Geometry):
472
398
  vertex_edge_ends.release()
473
399
  vertex_edge_tris.release()
474
400
 
475
- # Flip normals if necessary
476
- wp.launch(
477
- kernel=Trimesh2D._flip_edge_normals,
478
- device=device,
479
- dim=self.side_count(),
480
- inputs=[self._edge_vertex_indices, self._edge_tri_indices, self.tri_vertex_indices, self.positions],
481
- )
482
-
483
401
  boundary_edge_indices, _ = masked_indices(boundary_mask.array, temporary_store=temporary_store)
484
402
  self._boundary_edge_indices = boundary_edge_indices.detach()
485
403
 
@@ -542,7 +460,7 @@ class Trimesh2D(Geometry):
542
460
  other_v = wp.max(v0, v1)
543
461
 
544
462
  # Check if other_v has been seen
545
- seen_idx = Trimesh2D._find(other_v, edge_ends, edge_beg, edge_cur)
463
+ seen_idx = Trimesh._find(other_v, edge_ends, edge_beg, edge_cur)
546
464
 
547
465
  if seen_idx == -1:
548
466
  edge_ends[edge_cur] = other_v
@@ -586,7 +504,138 @@ class Trimesh2D(Geometry):
586
504
  boundary_mask[edge_index] = 0
587
505
 
588
506
  @wp.kernel
589
- def _flip_edge_normals(
507
+ def _compute_tri_bounds(
508
+ tri_vertex_indices: wp.array2d(dtype=int),
509
+ positions: wp.array(dtype=wp.vec2),
510
+ lowers: wp.array(dtype=wp.vec3),
511
+ uppers: wp.array(dtype=wp.vec3),
512
+ ):
513
+ t = wp.tid()
514
+ p0 = _bvh_vec(positions[tri_vertex_indices[t, 0]])
515
+ p1 = _bvh_vec(positions[tri_vertex_indices[t, 1]])
516
+ p2 = _bvh_vec(positions[tri_vertex_indices[t, 2]])
517
+
518
+ lowers[t] = wp.vec3(
519
+ wp.min(wp.min(p0[0], p1[0]), p2[0]),
520
+ wp.min(wp.min(p0[1], p1[1]), p2[1]),
521
+ wp.min(wp.min(p0[2], p1[2]), p2[2]),
522
+ )
523
+ uppers[t] = wp.vec3(
524
+ wp.max(wp.max(p0[0], p1[0]), p2[0]),
525
+ wp.max(wp.max(p0[1], p1[1]), p2[1]),
526
+ wp.max(wp.max(p0[2], p1[2]), p2[2]),
527
+ )
528
+
529
+
530
+ @wp.struct
531
+ class Trimesh2DCellArg:
532
+ topology: TrimeshCellArg
533
+ positions: wp.array(dtype=wp.vec2)
534
+
535
+
536
+ @wp.struct
537
+ class Trimesh2DSideArg:
538
+ topology: TrimeshSideArg
539
+ positions: wp.array(dtype=wp.vec2)
540
+
541
+
542
+ class Trimesh2D(Trimesh):
543
+ """2D Triangular mesh geometry"""
544
+
545
+ dimension = 2
546
+ CellArg = Trimesh2DCellArg
547
+ SideArg = Trimesh2DSideArg
548
+
549
+ @wp.func
550
+ def cell_position(args: CellArg, s: Sample):
551
+ tri_idx = args.topology.tri_vertex_indices[s.element_index]
552
+ return (
553
+ s.element_coords[0] * args.positions[tri_idx[0]]
554
+ + s.element_coords[1] * args.positions[tri_idx[1]]
555
+ + s.element_coords[2] * args.positions[tri_idx[2]]
556
+ )
557
+
558
+ @wp.func
559
+ def cell_deformation_gradient(args: CellArg, s: Sample):
560
+ tri_idx = args.topology.tri_vertex_indices[s.element_index]
561
+ p0 = args.positions[tri_idx[0]]
562
+ p1 = args.positions[tri_idx[1]]
563
+ p2 = args.positions[tri_idx[2]]
564
+ return wp.mat22(p1 - p0, p2 - p0)
565
+
566
+ @wp.func
567
+ def cell_lookup(args: CellArg, pos: wp.vec2):
568
+ closest_dist, closest_tri, closest_coords = Trimesh._bvh_lookup(args.topology, args.positions, pos)
569
+
570
+ return make_free_sample(closest_tri, closest_coords)
571
+
572
+ @wp.func
573
+ def cell_lookup(args: CellArg, pos: wp.vec2, guess: Sample):
574
+ closest_dist, closest_tri, closest_coords = Trimesh._bvh_lookup(args.topology, args.positions, pos)
575
+
576
+ if closest_tri == NULL_ELEMENT_INDEX:
577
+ closest_dist, closest_tri, closest_coords = Trimesh._cell_neighbor_lookup(
578
+ args.topology, args.positions, pos, guess.element_index
579
+ )
580
+
581
+ return make_free_sample(closest_tri, closest_coords)
582
+
583
+ @wp.func
584
+ def side_position(args: SideArg, s: Sample):
585
+ edge_idx = args.topology.edge_vertex_indices[s.element_index]
586
+ return (1.0 - s.element_coords[0]) * args.positions[edge_idx[0]] + s.element_coords[0] * args.positions[
587
+ edge_idx[1]
588
+ ]
589
+
590
+ @wp.func
591
+ def side_deformation_gradient(args: SideArg, s: Sample):
592
+ edge_idx = args.topology.edge_vertex_indices[s.element_index]
593
+ v0 = args.positions[edge_idx[0]]
594
+ v1 = args.positions[edge_idx[1]]
595
+ return v1 - v0
596
+
597
+ @wp.func
598
+ def side_normal(args: SideArg, s: Sample):
599
+ edge_idx = args.topology.edge_vertex_indices[s.element_index]
600
+ v0 = args.positions[edge_idx[0]]
601
+ v1 = args.positions[edge_idx[1]]
602
+ e = v1 - v0
603
+
604
+ return wp.normalize(wp.vec2(-e[1], e[0]))
605
+
606
+ @wp.func
607
+ def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
608
+ return arg.topology.edge_tri_indices[side_index][0]
609
+
610
+ @wp.func
611
+ def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
612
+ return arg.topology.edge_tri_indices[side_index][1]
613
+
614
+ @wp.func
615
+ def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
616
+ inner_cell_index = Trimesh2D.side_inner_cell_index(args, side_index)
617
+ return Trimesh._edge_to_tri_coords(args.topology, side_index, inner_cell_index, side_coords)
618
+
619
+ @wp.func
620
+ def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
621
+ outer_cell_index = Trimesh2D.side_outer_cell_index(args, side_index)
622
+ return Trimesh._edge_to_tri_coords(args.topology, side_index, outer_cell_index, side_coords)
623
+
624
+ @wp.func
625
+ def side_from_cell_coords(
626
+ args: SideArg,
627
+ side_index: ElementIndex,
628
+ tri_index: ElementIndex,
629
+ tri_coords: Coords,
630
+ ):
631
+ return Trimesh._tri_to_edge_coords(args.topology, side_index, tri_index, tri_coords)
632
+
633
+ @wp.func
634
+ def side_to_cell_arg(side_arg: SideArg):
635
+ return Trimesh2DCellArg(side_arg.topology.cell_arg, side_arg.positions)
636
+
637
+ @wp.kernel
638
+ def _orient_edges(
590
639
  edge_vertex_indices: wp.array(dtype=wp.vec2i),
591
640
  edge_tri_indices: wp.array(dtype=wp.vec2i),
592
641
  tri_vertex_indices: wp.array2d(dtype=int),
@@ -612,17 +661,136 @@ class Trimesh2D(Geometry):
612
661
  if wp.dot(tri_centroid - edge_center, edge_normal) > 0.0:
613
662
  edge_vertex_indices[e] = wp.vec2i(edge_vidx[1], edge_vidx[0])
614
663
 
664
+
665
+ @wp.struct
666
+ class Trimesh3DCellArg:
667
+ topology: TrimeshCellArg
668
+ positions: wp.array(dtype=wp.vec3)
669
+
670
+
671
+ @wp.struct
672
+ class Trimesh3DSideArg:
673
+ topology: TrimeshSideArg
674
+ positions: wp.array(dtype=wp.vec3)
675
+
676
+
677
+ _mat32 = wp.mat(shape=(3, 2), dtype=float)
678
+
679
+
680
+ class Trimesh3D(Trimesh):
681
+ """3D Triangular mesh geometry"""
682
+
683
+ dimension = 3
684
+ CellArg = Trimesh3DCellArg
685
+ SideArg = Trimesh3DSideArg
686
+
687
+ @wp.func
688
+ def cell_position(args: CellArg, s: Sample):
689
+ tri_idx = args.topology.tri_vertex_indices[s.element_index]
690
+ return (
691
+ s.element_coords[0] * args.positions[tri_idx[0]]
692
+ + s.element_coords[1] * args.positions[tri_idx[1]]
693
+ + s.element_coords[2] * args.positions[tri_idx[2]]
694
+ )
695
+
696
+ @wp.func
697
+ def cell_deformation_gradient(args: CellArg, s: Sample):
698
+ tri_idx = args.topology.tri_vertex_indices[s.element_index]
699
+ p0 = args.positions[tri_idx[0]]
700
+ p1 = args.positions[tri_idx[1]]
701
+ p2 = args.positions[tri_idx[2]]
702
+ return _mat32(p1 - p0, p2 - p0)
703
+
704
+ @wp.func
705
+ def cell_lookup(args: CellArg, pos: wp.vec3):
706
+ closest_dist, closest_tri, closest_coords = Trimesh._bvh_lookup(args.topology, args.positions, pos)
707
+
708
+ return make_free_sample(closest_tri, closest_coords)
709
+
710
+ @wp.func
711
+ def cell_lookup(args: CellArg, pos: wp.vec3, guess: Sample):
712
+ closest_dist, closest_tri, closest_coords = Trimesh._bvh_lookup(args.topology, args.positions, pos)
713
+
714
+ if closest_tri == NULL_ELEMENT_INDEX:
715
+ closest_dist, closest_tri, closest_coords = Trimesh._cell_neighbor_lookup(
716
+ args.topology, args.positions, pos, guess.element_index
717
+ )
718
+
719
+ return make_free_sample(closest_tri, closest_coords)
720
+
721
+ @wp.func
722
+ def side_position(args: SideArg, s: Sample):
723
+ edge_idx = args.topology.edge_vertex_indices[s.element_index]
724
+ return (1.0 - s.element_coords[0]) * args.positions[edge_idx[0]] + s.element_coords[0] * args.positions[
725
+ edge_idx[1]
726
+ ]
727
+
728
+ @wp.func
729
+ def side_deformation_gradient(args: SideArg, s: Sample):
730
+ edge_idx = args.topology.edge_vertex_indices[s.element_index]
731
+ v0 = args.positions[edge_idx[0]]
732
+ v1 = args.positions[edge_idx[1]]
733
+ return v1 - v0
734
+
735
+ @wp.func
736
+ def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
737
+ return arg.topology.edge_tri_indices[side_index][0]
738
+
739
+ @wp.func
740
+ def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
741
+ return arg.topology.edge_tri_indices[side_index][1]
742
+
743
+ @wp.func
744
+ def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
745
+ inner_cell_index = Trimesh3D.side_inner_cell_index(args, side_index)
746
+ return Trimesh._edge_to_tri_coords(args.topology, side_index, inner_cell_index, side_coords)
747
+
748
+ @wp.func
749
+ def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
750
+ outer_cell_index = Trimesh3D.side_outer_cell_index(args, side_index)
751
+ return Trimesh._edge_to_tri_coords(args.topology, side_index, outer_cell_index, side_coords)
752
+
753
+ @wp.func
754
+ def side_from_cell_coords(
755
+ args: SideArg,
756
+ side_index: ElementIndex,
757
+ tri_index: ElementIndex,
758
+ tri_coords: Coords,
759
+ ):
760
+ return Trimesh._tri_to_edge_coords(args.topology, side_index, tri_index, tri_coords)
761
+
762
+ @wp.func
763
+ def side_to_cell_arg(side_arg: SideArg):
764
+ return Trimesh3DCellArg(side_arg.topology.cell_arg, side_arg.positions)
765
+
615
766
  @wp.kernel
616
- def _compute_tri_bounds(
767
+ def _orient_edges(
768
+ edge_vertex_indices: wp.array(dtype=wp.vec2i),
769
+ edge_tri_indices: wp.array(dtype=wp.vec2i),
617
770
  tri_vertex_indices: wp.array2d(dtype=int),
618
- positions: wp.array(dtype=wp.vec2),
619
- lowers: wp.array(dtype=wp.vec3),
620
- uppers: wp.array(dtype=wp.vec3),
771
+ positions: wp.array(dtype=wp.vec3),
621
772
  ):
622
- t = wp.tid()
623
- p0 = positions[tri_vertex_indices[t, 0]]
624
- p1 = positions[tri_vertex_indices[t, 1]]
625
- p2 = positions[tri_vertex_indices[t, 2]]
773
+ e = wp.tid()
774
+
775
+ tri = edge_tri_indices[e][0]
776
+
777
+ tri_vidx = tri_vertex_indices[tri]
778
+ edge_vidx = edge_vertex_indices[e]
779
+
780
+ t0 = positions[tri_vidx[0]]
781
+ t1 = positions[tri_vidx[1]]
782
+ t2 = positions[tri_vidx[2]]
783
+
784
+ tri_centroid = (t0 + t1 + t2) / 3.0
785
+ tri_normal = wp.cross(t1 - t0, t2 - t0)
626
786
 
627
- lowers[t] = wp.vec3(wp.min(wp.min(p0[0], p1[0]), p2[0]), wp.min(wp.min(p0[1], p1[1]), p2[1]), 0.0)
628
- uppers[t] = wp.vec3(wp.max(wp.max(p0[0], p1[0]), p2[0]), wp.max(wp.max(p0[1], p1[1]), p2[1]), 0.0)
787
+ v0 = positions[edge_vidx[0]]
788
+ v1 = positions[edge_vidx[1]]
789
+
790
+ edge_center = 0.5 * (v1 + v0)
791
+ edge_vec = v1 - v0
792
+ edge_normal = wp.cross(edge_vec, tri_normal)
793
+
794
+ # if edge normal points toward first triangle centroid, flip indices
795
+ if wp.dot(tri_centroid - edge_center, edge_normal) > 0.0:
796
+ edge_vertex_indices[e] = wp.vec2i(edge_vidx[1], edge_vidx[0])