warp-lang: warp_lang-1.4.1-py3-none-manylinux2014_aarch64.whl → warp_lang-1.5.0-py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's advisory page for this release for more details.

Files changed (164)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1920 -111
  8. warp/codegen.py +186 -62
  9. warp/config.py +2 -2
  10. warp/context.py +322 -73
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/core/example_dem.py +2 -1
  17. warp/examples/core/example_mesh_intersect.py +3 -3
  18. warp/examples/fem/example_adaptive_grid.py +37 -10
  19. warp/examples/fem/example_apic_fluid.py +3 -2
  20. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  21. warp/examples/fem/example_deformed_geometry.py +1 -1
  22. warp/examples/fem/example_diffusion_3d.py +47 -4
  23. warp/examples/fem/example_distortion_energy.py +220 -0
  24. warp/examples/fem/example_magnetostatics.py +127 -85
  25. warp/examples/fem/example_nonconforming_contact.py +5 -5
  26. warp/examples/fem/example_stokes.py +3 -1
  27. warp/examples/fem/example_streamlines.py +12 -19
  28. warp/examples/fem/utils.py +38 -15
  29. warp/examples/optim/example_walker.py +2 -2
  30. warp/examples/sim/example_cloth.py +2 -25
  31. warp/examples/sim/example_jacobian_ik.py +6 -2
  32. warp/examples/sim/example_quadruped.py +2 -1
  33. warp/examples/tile/example_tile_convolution.py +58 -0
  34. warp/examples/tile/example_tile_fft.py +47 -0
  35. warp/examples/tile/example_tile_filtering.py +105 -0
  36. warp/examples/tile/example_tile_matmul.py +79 -0
  37. warp/examples/tile/example_tile_mlp.py +375 -0
  38. warp/fem/__init__.py +8 -0
  39. warp/fem/cache.py +16 -12
  40. warp/fem/dirichlet.py +1 -1
  41. warp/fem/domain.py +44 -1
  42. warp/fem/field/__init__.py +1 -2
  43. warp/fem/field/field.py +31 -19
  44. warp/fem/field/nodal_field.py +101 -49
  45. warp/fem/field/virtual.py +794 -0
  46. warp/fem/geometry/__init__.py +2 -2
  47. warp/fem/geometry/deformed_geometry.py +3 -105
  48. warp/fem/geometry/element.py +13 -0
  49. warp/fem/geometry/geometry.py +165 -5
  50. warp/fem/geometry/grid_2d.py +3 -6
  51. warp/fem/geometry/grid_3d.py +31 -28
  52. warp/fem/geometry/hexmesh.py +3 -46
  53. warp/fem/geometry/nanogrid.py +3 -2
  54. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  55. warp/fem/geometry/tetmesh.py +2 -43
  56. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  57. warp/fem/integrate.py +683 -261
  58. warp/fem/linalg.py +404 -0
  59. warp/fem/operator.py +101 -18
  60. warp/fem/polynomial.py +5 -5
  61. warp/fem/quadrature/quadrature.py +45 -21
  62. warp/fem/space/__init__.py +45 -11
  63. warp/fem/space/basis_function_space.py +451 -0
  64. warp/fem/space/basis_space.py +58 -11
  65. warp/fem/space/function_space.py +146 -5
  66. warp/fem/space/grid_2d_function_space.py +80 -66
  67. warp/fem/space/grid_3d_function_space.py +113 -68
  68. warp/fem/space/hexmesh_function_space.py +96 -108
  69. warp/fem/space/nanogrid_function_space.py +62 -110
  70. warp/fem/space/quadmesh_function_space.py +208 -0
  71. warp/fem/space/shape/__init__.py +45 -7
  72. warp/fem/space/shape/cube_shape_function.py +328 -54
  73. warp/fem/space/shape/shape_function.py +10 -1
  74. warp/fem/space/shape/square_shape_function.py +328 -60
  75. warp/fem/space/shape/tet_shape_function.py +269 -19
  76. warp/fem/space/shape/triangle_shape_function.py +238 -19
  77. warp/fem/space/tetmesh_function_space.py +69 -37
  78. warp/fem/space/topology.py +38 -0
  79. warp/fem/space/trimesh_function_space.py +179 -0
  80. warp/fem/utils.py +6 -331
  81. warp/jax_experimental.py +3 -1
  82. warp/native/array.h +55 -40
  83. warp/native/builtin.h +124 -43
  84. warp/native/bvh.h +4 -0
  85. warp/native/coloring.cpp +600 -0
  86. warp/native/cuda_util.cpp +14 -0
  87. warp/native/cuda_util.h +2 -1
  88. warp/native/fabric.h +8 -0
  89. warp/native/hashgrid.h +4 -0
  90. warp/native/marching.cu +8 -0
  91. warp/native/mat.h +14 -3
  92. warp/native/mathdx.cpp +59 -0
  93. warp/native/mesh.h +4 -0
  94. warp/native/range.h +13 -1
  95. warp/native/reduce.cpp +9 -1
  96. warp/native/reduce.cu +7 -0
  97. warp/native/runlength_encode.cpp +9 -1
  98. warp/native/runlength_encode.cu +7 -1
  99. warp/native/scan.cpp +8 -0
  100. warp/native/scan.cu +8 -0
  101. warp/native/scan.h +8 -1
  102. warp/native/sparse.cpp +8 -0
  103. warp/native/sparse.cu +8 -0
  104. warp/native/temp_buffer.h +7 -0
  105. warp/native/tile.h +1857 -0
  106. warp/native/tile_gemm.h +341 -0
  107. warp/native/tile_reduce.h +210 -0
  108. warp/native/volume_builder.cu +8 -0
  109. warp/native/volume_builder.h +8 -0
  110. warp/native/warp.cpp +10 -2
  111. warp/native/warp.cu +369 -15
  112. warp/native/warp.h +12 -2
  113. warp/optim/adam.py +39 -4
  114. warp/paddle.py +29 -12
  115. warp/render/render_opengl.py +137 -65
  116. warp/sim/graph_coloring.py +292 -0
  117. warp/sim/integrator_euler.py +4 -2
  118. warp/sim/integrator_featherstone.py +115 -44
  119. warp/sim/integrator_vbd.py +6 -0
  120. warp/sim/model.py +90 -17
  121. warp/stubs.py +651 -85
  122. warp/tape.py +12 -7
  123. warp/tests/assets/pixel.npy +0 -0
  124. warp/tests/aux_test_instancing_gc.py +18 -0
  125. warp/tests/test_array.py +207 -48
  126. warp/tests/test_closest_point_edge_edge.py +8 -8
  127. warp/tests/test_codegen.py +120 -1
  128. warp/tests/test_codegen_instancing.py +30 -0
  129. warp/tests/test_collision.py +110 -0
  130. warp/tests/test_coloring.py +241 -0
  131. warp/tests/test_context.py +34 -0
  132. warp/tests/test_examples.py +18 -4
  133. warp/tests/test_fabricarray.py +33 -0
  134. warp/tests/test_fem.py +453 -113
  135. warp/tests/test_func.py +48 -1
  136. warp/tests/test_generics.py +52 -0
  137. warp/tests/test_iter.py +68 -0
  138. warp/tests/test_mat_scalar_ops.py +1 -1
  139. warp/tests/test_mesh_query_point.py +5 -4
  140. warp/tests/test_module_hashing.py +23 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +191 -1
  143. warp/tests/test_spatial.py +1 -1
  144. warp/tests/test_tile.py +700 -0
  145. warp/tests/test_tile_mathdx.py +144 -0
  146. warp/tests/test_tile_mlp.py +383 -0
  147. warp/tests/test_tile_reduce.py +374 -0
  148. warp/tests/test_tile_shared_memory.py +190 -0
  149. warp/tests/test_vbd.py +12 -20
  150. warp/tests/test_volume.py +43 -0
  151. warp/tests/unittest_suites.py +23 -2
  152. warp/tests/unittest_utils.py +4 -0
  153. warp/types.py +339 -73
  154. warp/utils.py +22 -1
  155. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
  156. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/RECORD +159 -132
  157. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
  158. warp/fem/field/test.py +0 -180
  159. warp/fem/field/trial.py +0 -183
  160. warp/fem/space/collocated_function_space.py +0 -102
  161. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  162. warp/fem/space/trimesh_2d_function_space.py +0 -153
  163. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
  164. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
@@ -4,6 +4,8 @@ import warp as wp
4
4
  from warp.fem import cache
5
5
  from warp.fem.types import Coords
6
6
 
7
+ from .shape_function import ShapeFunction
8
+
7
9
 
8
10
  def _triangle_node_index(tx: int, ty: int, degree: int):
9
11
  VERTEX_NODE_COUNT = 3
@@ -34,17 +36,43 @@ def _triangle_node_index(tx: int, ty: int, degree: int):
34
36
  return vertex_edge_node_count + _triangle_node_index(tx - 1, ty - 1, degree - 3)
35
37
 
36
38
 
37
- class Triangle2DPolynomialShapeFunctions:
39
+ class TriangleShapeFunction(ShapeFunction):
38
40
  VERTEX = wp.constant(0)
39
41
  EDGE = wp.constant(1)
40
42
  INTERIOR = wp.constant(2)
41
43
 
44
+ VERTEX_NODE_COUNT: int
45
+ """Number of shape function nodes per vertex"""
46
+
47
+ EDGE_NODE_COUNT: int
48
+ """Number of shape function nodes per triangle edge (excluding vertex nodes)"""
49
+
50
+ INTERIOR_NODE_COUNT: int
51
+ """Number of shape function nodes per triangle (excluding edge and vertex nodes)"""
52
+
53
+ @staticmethod
54
+ def node_type_and_index(node_index_in_elt: int):
55
+ pass
56
+
57
+ @wp.func
58
+ def _vertex_coords(vidx: int):
59
+ return wp.vec2(
60
+ float(vidx == 1),
61
+ float(vidx == 2),
62
+ )
63
+
64
+
65
+ class TrianglePolynomialShapeFunctions(TriangleShapeFunction):
42
66
  def __init__(self, degree: int):
43
67
  self.ORDER = wp.constant(degree)
44
68
 
45
69
  self.NODES_PER_ELEMENT = wp.constant((degree + 1) * (degree + 2) // 2)
46
70
  self.NODES_PER_SIDE = wp.constant(degree + 1)
47
71
 
72
+ self.VERTEX_NODE_COUNT = wp.constant(1)
73
+ self.EDGE_NODE_COUNT = wp.constant(degree - 1)
74
+ self.INTERIOR_NODE_COUNT = wp.constant(max(0, degree - 2) * max(0, degree - 1) // 2)
75
+
48
76
  triangle_coords = np.empty((self.NODES_PER_ELEMENT, 2), dtype=int)
49
77
 
50
78
  for tx in range(degree + 1):
@@ -68,7 +96,7 @@ class Triangle2DPolynomialShapeFunctions:
68
96
  def node_triangle_coordinates(
69
97
  node_index_in_elt: int,
70
98
  ):
71
- return wp.vec2i(NODE_TRIANGLE_COORDS[node_index_in_elt, 0], NODE_TRIANGLE_COORDS[node_index_in_elt, 1])
99
+ return NODE_TRIANGLE_COORDS[node_index_in_elt]
72
100
 
73
101
  return cache.get_func(node_triangle_coordinates, self.name)
74
102
 
@@ -79,12 +107,12 @@ class Triangle2DPolynomialShapeFunctions:
79
107
  node_index_in_elt: int,
80
108
  ):
81
109
  if node_index_in_elt < 3:
82
- return Triangle2DPolynomialShapeFunctions.VERTEX, node_index_in_elt
110
+ return TrianglePolynomialShapeFunctions.VERTEX, node_index_in_elt
83
111
 
84
112
  if node_index_in_elt < 3 * ORDER:
85
- return Triangle2DPolynomialShapeFunctions.EDGE, (node_index_in_elt - 3)
113
+ return TrianglePolynomialShapeFunctions.EDGE, (node_index_in_elt - 3)
86
114
 
87
- return Triangle2DPolynomialShapeFunctions.INTERIOR, (node_index_in_elt - 3 * ORDER)
115
+ return TrianglePolynomialShapeFunctions.INTERIOR, (node_index_in_elt - 3 * ORDER)
88
116
 
89
117
  return cache.get_func(node_type_and_index, self.name)
90
118
 
@@ -125,9 +153,9 @@ class Triangle2DPolynomialShapeFunctions:
125
153
  def node_quadrature_weight(node_index_in_element: int):
126
154
  node_type, type_index = self.node_type_and_type_index(node_index_in_element)
127
155
 
128
- if node_type == Triangle2DPolynomialShapeFunctions.VERTEX:
156
+ if node_type == TrianglePolynomialShapeFunctions.VERTEX:
129
157
  return VERTEX_WEIGHT
130
- elif node_type == Triangle2DPolynomialShapeFunctions.EDGE:
158
+ elif node_type == TrianglePolynomialShapeFunctions.EDGE:
131
159
  return EDGE_WEIGHT
132
160
 
133
161
  return INTERIOR_WEIGHT
@@ -153,7 +181,7 @@ class Triangle2DPolynomialShapeFunctions:
153
181
  def trace_node_quadrature_weight(node_index_in_element: int):
154
182
  node_type, type_index = self.node_type_and_type_index(node_index_in_element)
155
183
 
156
- return wp.select(node_type == Triangle2DPolynomialShapeFunctions.VERTEX, EDGE_WEIGHT, VERTEX_WEIGHT)
184
+ return wp.select(node_type == TrianglePolynomialShapeFunctions.VERTEX, EDGE_WEIGHT, VERTEX_WEIGHT)
157
185
 
158
186
  return trace_node_quadrature_weight
159
187
 
@@ -172,7 +200,7 @@ class Triangle2DPolynomialShapeFunctions:
172
200
  ):
173
201
  node_type, type_index = self.node_type_and_type_index(node_index_in_elt)
174
202
 
175
- if node_type == Triangle2DPolynomialShapeFunctions.VERTEX:
203
+ if node_type == TrianglePolynomialShapeFunctions.VERTEX:
176
204
  # Vertex
177
205
  return coords[type_index] * (2.0 * coords[type_index] - 1.0)
178
206
 
@@ -187,11 +215,11 @@ class Triangle2DPolynomialShapeFunctions:
187
215
  ):
188
216
  node_type, type_index = self.node_type_and_type_index(node_index_in_elt)
189
217
 
190
- if node_type == Triangle2DPolynomialShapeFunctions.VERTEX:
218
+ if node_type == TrianglePolynomialShapeFunctions.VERTEX:
191
219
  # Vertex
192
220
  return 0.5 * coords[type_index] * (3.0 * coords[type_index] - 1.0) * (3.0 * coords[type_index] - 2.0)
193
221
 
194
- elif node_type == Triangle2DPolynomialShapeFunctions.EDGE:
222
+ elif node_type == TrianglePolynomialShapeFunctions.EDGE:
195
223
  # Edge
196
224
  edge = type_index // 2
197
225
  k = type_index - 2 * edge
@@ -233,7 +261,7 @@ class Triangle2DPolynomialShapeFunctions:
233
261
 
234
262
  dw_dc = wp.vec3(0.0)
235
263
 
236
- if node_type == Triangle2DPolynomialShapeFunctions.VERTEX:
264
+ if node_type == TrianglePolynomialShapeFunctions.VERTEX:
237
265
  # Vertex
238
266
  dw_dc[type_index] = 4.0 * coords[type_index] - 1.0
239
267
 
@@ -255,13 +283,13 @@ class Triangle2DPolynomialShapeFunctions:
255
283
 
256
284
  dw_dc = wp.vec3(0.0)
257
285
 
258
- if node_type == Triangle2DPolynomialShapeFunctions.VERTEX:
286
+ if node_type == TrianglePolynomialShapeFunctions.VERTEX:
259
287
  # Vertex
260
288
  dw_dc[type_index] = (
261
289
  0.5 * 27.0 * coords[type_index] * coords[type_index] - 9.0 * coords[type_index] + 1.0
262
290
  )
263
291
 
264
- elif node_type == Triangle2DPolynomialShapeFunctions.EDGE:
292
+ elif node_type == TrianglePolynomialShapeFunctions.EDGE:
265
293
  # Edge
266
294
  edge = type_index // 2
267
295
  k = type_index - 2 * edge
@@ -318,9 +346,9 @@ class Triangle2DPolynomialShapeFunctions:
318
346
  return cells[np.newaxis, :], np.array([cell_type], dtype=np.int8)
319
347
 
320
348
 
321
- class Triangle2DNonConformingPolynomialShapeFunctions:
349
+ class TriangleNonConformingPolynomialShapeFunctions(ShapeFunction):
322
350
  def __init__(self, degree: int):
323
- self._tri_shape = Triangle2DPolynomialShapeFunctions(degree=degree)
351
+ self._tri_shape = TrianglePolynomialShapeFunctions(degree=degree)
324
352
  self.ORDER = self._tri_shape.ORDER
325
353
  self.NODES_PER_ELEMENT = self._tri_shape.NODES_PER_ELEMENT
326
354
 
@@ -346,7 +374,7 @@ class Triangle2DNonConformingPolynomialShapeFunctions:
346
374
 
347
375
  @property
348
376
  def name(self) -> str:
349
- return f"Tri_P{self.ORDER}d"
377
+ return f"Tri_dP{self.ORDER}"
350
378
 
351
379
  def make_node_quadrature_weight(self):
352
380
  # Intrinsic quadrature -- precomputed integral of node shape functions
@@ -373,9 +401,9 @@ class Triangle2DNonConformingPolynomialShapeFunctions:
373
401
  def node_quadrature_weight(node_index_in_element: int):
374
402
  node_type, type_index = self._tri_shape.node_type_and_type_index(node_index_in_element)
375
403
 
376
- if node_type == Triangle2DPolynomialShapeFunctions.VERTEX:
404
+ if node_type == TrianglePolynomialShapeFunctions.VERTEX:
377
405
  return VERTEX_WEIGHT
378
- elif node_type == Triangle2DPolynomialShapeFunctions.EDGE:
406
+ elif node_type == TrianglePolynomialShapeFunctions.EDGE:
379
407
  return EDGE_WEIGHT
380
408
 
381
409
  return INTERIOR_WEIGHT
@@ -436,3 +464,194 @@ class Triangle2DNonConformingPolynomialShapeFunctions:
436
464
  return INV_TRI_SCALE * grad
437
465
 
438
466
  return element_inner_weight_gradient
467
+
468
+
469
+ class TriangleNedelecFirstKindShapeFunctions(TriangleShapeFunction):
470
+ value = ShapeFunction.Value.CovariantVector
471
+
472
+ def __init__(self, degree: int):
473
+ if degree != 1:
474
+ raise NotImplementedError("Only linear Nédélec implemented right now")
475
+
476
+ self.ORDER = wp.constant(degree)
477
+
478
+ self.NODES_PER_ELEMENT = wp.constant(3)
479
+ self.NODES_PER_SIDE = wp.constant(1)
480
+
481
+ self.VERTEX_NODE_COUNT = wp.constant(0)
482
+ self.EDGE_NODE_COUNT = wp.constant(1)
483
+ self.INTERIOR_NODE_COUNT = wp.constant(0)
484
+
485
+ self.node_type_and_type_index = self._get_node_type_and_type_index()
486
+
487
+ @property
488
+ def name(self) -> str:
489
+ return f"TriN1_{self.ORDER}"
490
+
491
+ def _get_node_type_and_type_index(self):
492
+ @cache.dynamic_func(suffix=self.name)
493
+ def node_type_and_index(
494
+ node_index_in_elt: int,
495
+ ):
496
+ return TriangleShapeFunction.EDGE, node_index_in_elt
497
+
498
+ return node_type_and_index
499
+
500
+ def make_node_coords_in_element(self):
501
+ @cache.dynamic_func(suffix=self.name)
502
+ def node_coords_in_element(
503
+ node_index_in_elt: int,
504
+ ):
505
+ coords = Coords(0.5)
506
+ coords[(node_index_in_elt + 2) % 3] = 0.0
507
+ return coords
508
+
509
+ return node_coords_in_element
510
+
511
+ def make_node_quadrature_weight(self):
512
+ NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
513
+
514
+ @cache.dynamic_func(suffix=self.name)
515
+ def node_quadrature_weight(node_index_in_element: int):
516
+ return 1.0 / float(NODES_PER_ELEMENT)
517
+
518
+ return node_quadrature_weight
519
+
520
+ def make_trace_node_quadrature_weight(self):
521
+ NODES_PER_SIDE = self.NODES_PER_SIDE
522
+
523
+ @cache.dynamic_func(suffix=self.name)
524
+ def trace_node_quadrature_weight(node_index_in_element: int):
525
+ return 1.0 / float(NODES_PER_SIDE)
526
+
527
+ return trace_node_quadrature_weight
528
+
529
+ @wp.func
530
+ def _vertex_coords(vidx: int):
531
+ return wp.vec2(
532
+ float(vidx == 1),
533
+ float(vidx == 2),
534
+ )
535
+
536
+ def make_element_inner_weight(self):
537
+ ORDER = self.ORDER
538
+
539
+ def element_inner_weight_linear(
540
+ coords: Coords,
541
+ node_index_in_elt: int,
542
+ ):
543
+ x = wp.vec2(coords[1], coords[2])
544
+ p = self._vertex_coords((node_index_in_elt + 2) % 3)
545
+
546
+ d = x - p
547
+ return wp.vec2(-d[1], d[0])
548
+
549
+ if ORDER == 1:
550
+ return cache.get_func(element_inner_weight_linear, self.name)
551
+
552
+ return None
553
+
554
+ def make_element_inner_weight_gradient(self):
555
+ ROT = wp.constant(wp.mat22f(0.0, -1.0, 1.0, 0.0))
556
+
557
+ def element_inner_weight_gradient_linear(
558
+ coords: Coords,
559
+ node_index_in_elt: int,
560
+ ):
561
+ return ROT
562
+
563
+ if self.ORDER == 1:
564
+ return cache.get_func(element_inner_weight_gradient_linear, self.name)
565
+
566
+ return None
567
+
568
+
569
+ class TriangleRaviartThomasShapeFunctions(TriangleShapeFunction):
570
+ value = ShapeFunction.Value.ContravariantVector
571
+
572
+ def __init__(self, degree: int):
573
+ if degree != 1:
574
+ raise NotImplementedError("Only linear Raviart-Thomas implemented right now")
575
+
576
+ self.ORDER = wp.constant(degree)
577
+
578
+ self.NODES_PER_ELEMENT = wp.constant(3)
579
+ self.NODES_PER_SIDE = wp.constant(1)
580
+
581
+ self.VERTEX_NODE_COUNT = wp.constant(0)
582
+ self.EDGE_NODE_COUNT = wp.constant(1)
583
+ self.INTERIOR_NODE_COUNT = wp.constant(0)
584
+
585
+ self.node_type_and_type_index = self._get_node_type_and_type_index()
586
+
587
+ @property
588
+ def name(self) -> str:
589
+ return f"TriRT_{self.ORDER}"
590
+
591
+ def _get_node_type_and_type_index(self):
592
+ @cache.dynamic_func(suffix=self.name)
593
+ def node_type_and_index(
594
+ node_index_in_elt: int,
595
+ ):
596
+ return TriangleShapeFunction.EDGE, node_index_in_elt
597
+
598
+ return node_type_and_index
599
+
600
+ def make_node_coords_in_element(self):
601
+ @cache.dynamic_func(suffix=self.name)
602
+ def node_coords_in_element(
603
+ node_index_in_elt: int,
604
+ ):
605
+ coords = Coords(0.5)
606
+ coords[(node_index_in_elt + 2) % 3] = 0.0
607
+ return coords
608
+
609
+ return node_coords_in_element
610
+
611
+ def make_node_quadrature_weight(self):
612
+ NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
613
+
614
+ @cache.dynamic_func(suffix=self.name)
615
+ def node_quadrature_weight(node_index_in_element: int):
616
+ return 1.0 / float(NODES_PER_ELEMENT)
617
+
618
+ return node_quadrature_weight
619
+
620
+ def make_trace_node_quadrature_weight(self):
621
+ NODES_PER_SIDE = self.NODES_PER_SIDE
622
+
623
+ @cache.dynamic_func(suffix=self.name)
624
+ def trace_node_quadrature_weight(node_index_in_element: int):
625
+ return 1.0 / float(NODES_PER_SIDE)
626
+
627
+ return trace_node_quadrature_weight
628
+
629
+ def make_element_inner_weight(self):
630
+ ORDER = self.ORDER
631
+
632
+ def element_inner_weight_linear(
633
+ coords: Coords,
634
+ node_index_in_elt: int,
635
+ ):
636
+ x = wp.vec2(coords[1], coords[2])
637
+ p = self._vertex_coords((node_index_in_elt + 2) % 3)
638
+
639
+ d = x - p
640
+ return d
641
+
642
+ if ORDER == 1:
643
+ return cache.get_func(element_inner_weight_linear, self.name)
644
+
645
+ return None
646
+
647
+ def make_element_inner_weight_gradient(self):
648
+ def element_inner_weight_gradient_linear(
649
+ coords: Coords,
650
+ node_index_in_elt: int,
651
+ ):
652
+ return wp.identity(n=2, dtype=float)
653
+
654
+ if self.ORDER == 1:
655
+ return cache.get_func(element_inner_weight_gradient_linear, self.name)
656
+
657
+ return None
@@ -6,6 +6,7 @@ from warp.fem.types import ElementIndex
6
6
  from .shape import (
7
7
  ShapeFunction,
8
8
  TetrahedronPolynomialShapeFunctions,
9
+ TetrahedronShapeFunction,
9
10
  )
10
11
  from .topology import SpaceTopology, forward_base_topology
11
12
 
@@ -15,6 +16,7 @@ class TetmeshTopologyArg:
15
16
  tet_edge_indices: wp.array2d(dtype=int)
16
17
  tet_face_indices: wp.array2d(dtype=int)
17
18
  face_vertex_indices: wp.array(dtype=wp.vec3i)
19
+ face_tet_indices: wp.array(dtype=wp.vec2i)
18
20
 
19
21
  vertex_count: int
20
22
  edge_count: int
@@ -27,13 +29,14 @@ class TetmeshSpaceTopology(SpaceTopology):
27
29
  def __init__(
28
30
  self,
29
31
  mesh: Tetmesh,
30
- shape: ShapeFunction,
31
- need_tet_edge_indices: bool = True,
32
- need_tet_face_indices: bool = True,
32
+ shape: TetrahedronShapeFunction,
33
33
  ):
34
+ self._shape = shape
34
35
  super().__init__(mesh, shape.NODES_PER_ELEMENT)
35
36
  self._mesh = mesh
36
- self._shape = shape
37
+
38
+ need_tet_edge_indices = self._shape.EDGE_NODE_COUNT > 0
39
+ need_tet_face_indices = self._shape.FACE_NODE_COUNT > 0
37
40
 
38
41
  if need_tet_edge_indices:
39
42
  self._tet_edge_indices = self._mesh.tet_edge_indices
@@ -47,12 +50,20 @@ class TetmeshSpaceTopology(SpaceTopology):
47
50
  else:
48
51
  self._tet_face_indices = wp.empty(shape=(0, 0), dtype=int)
49
52
 
53
+ self.element_node_index = self._make_element_node_index()
54
+ self.element_node_sign = self._make_element_node_sign()
55
+
56
+ @property
57
+ def name(self):
58
+ return f"{self.geometry.name}_{self._shape.name}"
59
+
50
60
  @cache.cached_arg_value
51
61
  def topo_arg_value(self, device):
52
62
  arg = TetmeshTopologyArg()
53
63
  arg.tet_face_indices = self._tet_face_indices.to(device)
54
64
  arg.tet_edge_indices = self._tet_edge_indices.to(device)
55
65
  arg.face_vertex_indices = self._mesh.face_vertex_indices.to(device)
66
+ arg.face_tet_indices = self._mesh.face_tet_indices.to(device)
56
67
 
57
68
  arg.vertex_count = self._mesh.vertex_count()
58
69
  arg.face_count = self._mesh.side_count()
@@ -126,31 +137,19 @@ class TetmeshSpaceTopology(SpaceTopology):
126
137
  t1_face = TetmeshSpaceTopology._find_face_index_in_tet(face_vtx, t1_vtx)
127
138
  tet_face_indices[t1, t1_face] = e
128
139
 
129
-
130
- class TetmeshPolynomialSpaceTopology(TetmeshSpaceTopology):
131
- def __init__(self, mesh: Tetmesh, shape: TetrahedronPolynomialShapeFunctions):
132
- super().__init__(mesh, shape, need_tet_edge_indices=shape.ORDER >= 2, need_tet_face_indices=shape.ORDER >= 3)
133
-
134
- self.element_node_index = self._make_element_node_index()
135
-
136
140
  def node_count(self) -> int:
137
- ORDER = self._shape.ORDER
138
- INTERIOR_NODES_PER_EDGE = max(0, ORDER - 1)
139
- INTERIOR_NODES_PER_FACE = max(0, ORDER - 2) * max(0, ORDER - 1) // 2
140
- INTERIOR_NODES_PER_CELL = max(0, ORDER - 3) * max(0, ORDER - 2) * max(0, ORDER - 1) // 6
141
-
142
141
  return (
143
- self._mesh.vertex_count()
144
- + self._mesh.edge_count() * INTERIOR_NODES_PER_EDGE
145
- + self._mesh.side_count() * INTERIOR_NODES_PER_FACE
146
- + self._mesh.cell_count() * INTERIOR_NODES_PER_CELL
142
+ self._mesh.vertex_count() * self._shape.VERTEX_NODE_COUNT
143
+ + self._mesh.edge_count() * self._shape.EDGE_NODE_COUNT
144
+ + self._mesh.side_count() * self._shape.FACE_NODE_COUNT
145
+ + self._mesh.cell_count() * self._shape.INTERIOR_NODE_COUNT
147
146
  )
148
147
 
149
148
  def _make_element_node_index(self):
150
- ORDER = self._shape.ORDER
151
- INTERIOR_NODES_PER_EDGE = wp.constant(max(0, ORDER - 1))
152
- INTERIOR_NODES_PER_FACE = wp.constant(max(0, ORDER - 2) * max(0, ORDER - 1) // 2)
153
- INTERIOR_NODES_PER_CELL = wp.constant(max(0, ORDER - 3) * max(0, ORDER - 2) * max(0, ORDER - 1) // 6)
149
+ VERTEX_NODE_COUNT = self._shape.VERTEX_NODE_COUNT
150
+ INTERIOR_NODES_PER_EDGE = self._shape.EDGE_NODE_COUNT
151
+ INTERIOR_NODES_PER_FACE = self._shape.FACE_NODE_COUNT
152
+ INTERIOR_NODES_PER_CELL = self._shape.INTERIOR_NODE_COUNT
154
153
 
155
154
  @cache.dynamic_func(suffix=self.name)
156
155
  def element_node_index(
@@ -162,9 +161,11 @@ class TetmeshPolynomialSpaceTopology(TetmeshSpaceTopology):
162
161
  node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)
163
162
 
164
163
  if node_type == TetrahedronPolynomialShapeFunctions.VERTEX:
165
- return geo_arg.tet_vertex_indices[element_index][type_index]
164
+ vertex = type_index // VERTEX_NODE_COUNT
165
+ vertex_node = type_index - VERTEX_NODE_COUNT * vertex
166
+ return geo_arg.tet_vertex_indices[element_index][vertex] * VERTEX_NODE_COUNT + vertex_node
166
167
 
167
- global_offset = topo_arg.vertex_count
168
+ global_offset = topo_arg.vertex_count * VERTEX_NODE_COUNT
168
169
 
169
170
  if node_type == TetrahedronPolynomialShapeFunctions.EDGE:
170
171
  edge = type_index // INTERIOR_NODES_PER_EDGE
@@ -173,14 +174,8 @@ class TetmeshPolynomialSpaceTopology(TetmeshSpaceTopology):
173
174
  global_edge_index = topo_arg.tet_edge_indices[element_index][edge]
174
175
 
175
176
  # Test if we need to swap edge direction
176
- if INTERIOR_NODES_PER_EDGE > 1:
177
- if edge < 3:
178
- c1 = edge
179
- c2 = (edge + 1) % 3
180
- else:
181
- c1 = edge - 3
182
- c2 = 3
183
-
177
+ if wp.static(INTERIOR_NODES_PER_EDGE > 1):
178
+ c1, c2 = TetrahedronShapeFunction.edge_vidx(edge)
184
179
  if geo_arg.tet_vertex_indices[element_index][c1] > geo_arg.tet_vertex_indices[element_index][c2]:
185
180
  edge_node = INTERIOR_NODES_PER_EDGE - 1 - edge_node
186
181
 
@@ -194,7 +189,7 @@ class TetmeshPolynomialSpaceTopology(TetmeshSpaceTopology):
194
189
 
195
190
  global_face_index = topo_arg.tet_face_indices[element_index][face]
196
191
 
197
- if INTERIOR_NODES_PER_FACE == 3:
192
+ if wp.static(INTERIOR_NODES_PER_FACE == 3):
198
193
  # Hard code for P4 case, 3 nodes per face
199
194
  # Higher orders would require rotating triangle coordinates, this is not supported yet
200
195
 
@@ -216,9 +211,46 @@ class TetmeshPolynomialSpaceTopology(TetmeshSpaceTopology):
216
211
 
217
212
  return element_node_index
218
213
 
214
+ def _make_element_node_sign(self):
215
+ INTERIOR_NODES_PER_EDGE = self._shape.EDGE_NODE_COUNT
216
+ INTERIOR_NODES_PER_FACE = self._shape.FACE_NODE_COUNT
217
+
218
+ @cache.dynamic_func(suffix=self.name)
219
+ def element_node_sign(
220
+ geo_arg: self.geometry.CellArg,
221
+ topo_arg: TetmeshTopologyArg,
222
+ element_index: ElementIndex,
223
+ node_index_in_elt: int,
224
+ ):
225
+ node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)
226
+
227
+ if wp.static(INTERIOR_NODES_PER_EDGE > 0):
228
+ if node_type == TetrahedronShapeFunction.EDGE:
229
+ edge = type_index // INTERIOR_NODES_PER_EDGE
230
+ c1, c2 = TetrahedronShapeFunction.edge_vidx(edge)
231
+
232
+ return wp.select(
233
+ geo_arg.tet_vertex_indices[element_index][c1] > geo_arg.tet_vertex_indices[element_index][c2],
234
+ 1.0,
235
+ -1.0,
236
+ )
237
+
238
+ if wp.static(INTERIOR_NODES_PER_FACE > 0):
239
+ if node_type == TetrahedronShapeFunction.FACE:
240
+ face = type_index // INTERIOR_NODES_PER_FACE
241
+
242
+ global_face_index = topo_arg.tet_face_indices[element_index][face]
243
+ inner = topo_arg.face_tet_indices[global_face_index][0]
244
+
245
+ return wp.select(inner == element_index, -1.0, 1.0)
246
+
247
+ return 1.0
248
+
249
+ return element_node_sign
250
+
219
251
 
220
252
  def make_tetmesh_space_topology(mesh: Tetmesh, shape: ShapeFunction):
221
- if isinstance(shape, TetrahedronPolynomialShapeFunctions):
222
- return forward_base_topology(TetmeshPolynomialSpaceTopology, mesh, shape)
253
+ if isinstance(shape, TetrahedronShapeFunction):
254
+ return forward_base_topology(TetmeshSpaceTopology, mesh, shape)
223
255
 
224
256
  raise ValueError(f"Unsupported shape function {shape.name}")
@@ -37,6 +37,7 @@ class SpaceTopology:
37
37
  self.ElementArg = geometry.CellArg
38
38
 
39
39
  self._make_constant_element_node_count()
40
+ self._make_constant_element_node_sign()
40
41
 
41
42
  @property
42
43
  def geometry(self) -> Geometry:
@@ -176,6 +177,18 @@ class SpaceTopology:
176
177
  self.element_node_count = constant_element_node_count
177
178
  self.side_neighbor_node_counts = constant_side_neighbor_node_counts
178
179
 
180
+ def _make_constant_element_node_sign(self):
181
+ @cache.dynamic_func(suffix=self.name)
182
+ def constant_element_node_sign(
183
+ geo_arg: self.geometry.CellArg,
184
+ topo_arg: self.TopologyArg,
185
+ element_index: ElementIndex,
186
+ node_index_in_element: int,
187
+ ):
188
+ return 1.0
189
+
190
+ self.element_node_sign = constant_element_node_sign
191
+
179
192
 
180
193
  class TraceSpaceTopology(SpaceTopology):
181
194
  """Auto-generated trace topology defining the node indices associated to the geometry sides"""
@@ -267,6 +280,21 @@ class TraceSpaceTopology(SpaceTopology):
267
280
 
268
281
  return trace_element_node_index
269
282
 
283
+ def _make_element_node_sign(self):
284
+ @cache.dynamic_func(suffix=self.name)
285
+ def trace_element_node_sign(
286
+ geo_side_arg: self.geometry.SideArg,
287
+ topo_arg: self._topo.TopologyArg,
288
+ element_index: ElementIndex,
289
+ node_index_in_elt: int,
290
+ ):
291
+ cell_index, index_in_cell = self.neighbor_cell_index(geo_side_arg, element_index, node_index_in_elt)
292
+
293
+ geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
294
+ return self._topo.element_node_sign(geo_cell_arg, topo_arg, cell_index, index_in_cell)
295
+
296
+ return trace_element_node_sign
297
+
270
298
  def full_space_topology(self) -> SpaceTopology:
271
299
  """Returns the full space topology from which this topology is derived"""
272
300
  return self._topo
@@ -351,8 +379,18 @@ class DeformedGeometrySpaceTopology(SpaceTopology):
351
379
  inner_count, outer_count = self.base.side_neighbor_node_counts(side_arg.base_arg, element_index)
352
380
  return inner_count, outer_count
353
381
 
382
+ @cache.dynamic_func(suffix=self.name)
383
+ def element_node_sign(
384
+ elt_arg: self.geometry.CellArg,
385
+ topo_arg: self.TopologyArg,
386
+ element_index: ElementIndex,
387
+ node_index_in_elt: int,
388
+ ):
389
+ return self.base.element_node_sign(elt_arg.elt_arg, topo_arg, element_index, node_index_in_elt)
390
+
354
391
  self.element_node_index = element_node_index
355
392
  self.element_node_count = element_node_count
393
+ self.element_node_sign = element_node_sign
356
394
  self.side_neighbor_node_counts = side_neighbor_node_counts
357
395
 
358
396