warp-lang 1.4.2__py3-none-manylinux2014_x86_64.whl → 1.5.1__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (166)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1819 -7
  8. warp/codegen.py +197 -61
  9. warp/config.py +2 -2
  10. warp/context.py +379 -107
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +4 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -7
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +604 -0
  82. warp/native/cuda_util.cpp +68 -51
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1854 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +140 -67
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/import_urdf.py +8 -8
  114. warp/sim/integrator_euler.py +4 -2
  115. warp/sim/integrator_featherstone.py +115 -44
  116. warp/sim/integrator_vbd.py +6 -0
  117. warp/sim/model.py +109 -32
  118. warp/sparse.py +1 -1
  119. warp/stubs.py +569 -4
  120. warp/tape.py +12 -7
  121. warp/tests/assets/pixel.npy +0 -0
  122. warp/tests/aux_test_instancing_gc.py +18 -0
  123. warp/tests/test_array.py +39 -0
  124. warp/tests/test_codegen.py +81 -1
  125. warp/tests/test_codegen_instancing.py +30 -0
  126. warp/tests/test_collision.py +110 -0
  127. warp/tests/test_coloring.py +251 -0
  128. warp/tests/test_context.py +34 -0
  129. warp/tests/test_examples.py +21 -5
  130. warp/tests/test_fem.py +453 -113
  131. warp/tests/test_func.py +34 -4
  132. warp/tests/test_generics.py +52 -0
  133. warp/tests/test_iter.py +68 -0
  134. warp/tests/test_lerp.py +13 -87
  135. warp/tests/test_mat_scalar_ops.py +1 -1
  136. warp/tests/test_matmul.py +6 -9
  137. warp/tests/test_matmul_lite.py +6 -11
  138. warp/tests/test_mesh_query_point.py +1 -1
  139. warp/tests/test_module_hashing.py +23 -0
  140. warp/tests/test_overwrite.py +45 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +56 -1
  143. warp/tests/test_smoothstep.py +17 -83
  144. warp/tests/test_spatial.py +1 -1
  145. warp/tests/test_static.py +3 -3
  146. warp/tests/test_tile.py +744 -0
  147. warp/tests/test_tile_mathdx.py +144 -0
  148. warp/tests/test_tile_mlp.py +383 -0
  149. warp/tests/test_tile_reduce.py +374 -0
  150. warp/tests/test_tile_shared_memory.py +190 -0
  151. warp/tests/test_vbd.py +12 -20
  152. warp/tests/test_volume.py +43 -0
  153. warp/tests/unittest_suites.py +19 -2
  154. warp/tests/unittest_utils.py +4 -2
  155. warp/types.py +340 -74
  156. warp/utils.py +23 -3
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +161 -134
  159. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
  160. warp/fem/field/test.py +0 -180
  161. warp/fem/field/trial.py +0 -183
  162. warp/fem/space/collocated_function_space.py +0 -102
  163. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  164. warp/fem/space/trimesh_2d_function_space.py +0 -153
  165. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
  166. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
@@ -7,10 +7,33 @@ from warp.fem import cache
7
7
  from warp.fem.polynomial import Polynomial, is_closed, lagrange_scales, quadrature_1d
8
8
  from warp.fem.types import Coords
9
9
 
10
- from .triangle_shape_function import Triangle2DPolynomialShapeFunctions
10
+ from .shape_function import ShapeFunction
11
+ from .triangle_shape_function import TrianglePolynomialShapeFunctions
11
12
 
12
13
 
13
- class SquareBipolynomialShapeFunctions:
14
+ class SquareShapeFunction(ShapeFunction):
15
+ VERTEX = 0
16
+ EDGE_X = 1
17
+ EDGE_Y = 2
18
+ INTERIOR = 3
19
+
20
+ VERTEX_NODE_COUNT: int
21
+ """Number of shape function nodes per vertex"""
22
+
23
+ EDGE_NODE_COUNT: int
24
+ """Number of shape function nodes per triangle edge (excluding vertex nodes)"""
25
+
26
+ INTERIOR_NODE_COUNT: int
27
+ """Number of shape function nodes per triangle (excluding edge and vertex nodes)"""
28
+
29
+ @wp.func
30
+ def _vertex_coords_f(vidx_in_cell: int):
31
+ x = vidx_in_cell // 2
32
+ y = vidx_in_cell - 2 * x
33
+ return wp.vec2(float(x), float(y))
34
+
35
+
36
+ class SquareBipolynomialShapeFunctions(SquareShapeFunction):
14
37
  def __init__(self, degree: int, family: Polynomial):
15
38
  self.family = family
16
39
 
@@ -18,6 +41,15 @@ class SquareBipolynomialShapeFunctions:
18
41
  self.NODES_PER_ELEMENT = wp.constant((degree + 1) * (degree + 1))
19
42
  self.NODES_PER_SIDE = wp.constant(degree + 1)
20
43
 
44
+ if is_closed(self.family):
45
+ self.VERTEX_NODE_COUNT = wp.constant(1)
46
+ self.EDGE_NODE_COUNT = wp.constant(max(0, degree - 1))
47
+ self.INTERIOR_NODE_COUNT = wp.constant(max(0, degree - 1) ** 2)
48
+ else:
49
+ self.VERTEX_NODE_COUNT = wp.constant(0)
50
+ self.EDGE_NODE_COUNT = wp.constant(0)
51
+ self.INTERIOR_NODE_COUNT = self.NODES_PER_ELEMENT
52
+
21
53
  lobatto_coords, lobatto_weight = quadrature_1d(point_count=degree + 1, family=family)
22
54
  lagrange_scale = lagrange_scales(lobatto_coords)
23
55
 
@@ -27,20 +59,72 @@ class SquareBipolynomialShapeFunctions:
27
59
  self.LAGRANGE_SCALE = wp.constant(NodeVec(lagrange_scale))
28
60
  self.ORDER_PLUS_ONE = wp.constant(self.ORDER + 1)
29
61
 
62
+ self._node_ij = self._make_node_ij()
63
+ self.node_type_and_type_index = self._make_node_type_and_type_index()
64
+
30
65
  @property
31
66
  def name(self) -> str:
32
67
  return f"Square_Q{self.ORDER}_{self.family}"
33
68
 
34
- def make_node_coords_in_element(self):
69
+ def _make_node_ij(self):
70
+ ORDER_PLUS_ONE = self.ORDER_PLUS_ONE
71
+
72
+ def node_ij(node_index_in_elt: int):
73
+ node_i = node_index_in_elt // ORDER_PLUS_ONE
74
+ node_j = node_index_in_elt - ORDER_PLUS_ONE * node_i
75
+ return node_i, node_j
76
+
77
+ return cache.get_func(node_ij, self.name)
78
+
79
+ def _make_node_type_and_type_index(self):
35
80
  ORDER = self.ORDER
81
+
82
+ @cache.dynamic_func(suffix=self.name)
83
+ def node_type_and_type_index_open(
84
+ node_index_in_elt: int,
85
+ ):
86
+ return SquareShapeFunction.INTERIOR, 0, node_index_in_elt
87
+
88
+ @cache.dynamic_func(suffix=self.name)
89
+ def node_type_and_type_index(
90
+ node_index_in_elt: int,
91
+ ):
92
+ i, j = self._node_ij(node_index_in_elt)
93
+
94
+ zi = int(i == 0)
95
+ zj = int(j == 0)
96
+
97
+ mi = int(i == ORDER)
98
+ mj = int(j == ORDER)
99
+
100
+ if zi + mi == 1:
101
+ if zj + mj == 1:
102
+ # vertex
103
+ type_instance = mi * 2 + mj
104
+ return SquareShapeFunction.VERTEX, type_instance, 0
105
+ # y edge
106
+ type_index = j - 1
107
+ type_instance = mi
108
+ return SquareShapeFunction.EDGE_Y, type_instance, type_index
109
+ elif zj + mj == 1:
110
+ # x edge
111
+ type_index = i - 1
112
+ type_instance = mj
113
+ return SquareShapeFunction.EDGE_X, type_instance, type_index
114
+
115
+ type_index = (i - 1) * (ORDER - 1) + (j - 1)
116
+ return SquareShapeFunction.INTERIOR, 0, type_index
117
+
118
+ return node_type_and_type_index if is_closed(self.family) else node_type_and_type_index_open
119
+
120
+ def make_node_coords_in_element(self):
36
121
  LOBATTO_COORDS = self.LOBATTO_COORDS
37
122
 
38
123
  @cache.dynamic_func(suffix=self.name)
39
124
  def node_coords_in_element(
40
125
  node_index_in_elt: int,
41
126
  ):
42
- node_i = node_index_in_elt // (ORDER + 1)
43
- node_j = node_index_in_elt - (ORDER + 1) * node_i
127
+ node_i, node_j = self._node_ij(node_index_in_elt)
44
128
  return Coords(LOBATTO_COORDS[node_i], LOBATTO_COORDS[node_j], 0.0)
45
129
 
46
130
  return node_coords_in_element
@@ -52,8 +136,7 @@ class SquareBipolynomialShapeFunctions:
52
136
  def node_quadrature_weight(
53
137
  node_index_in_elt: int,
54
138
  ):
55
- node_i = node_index_in_elt // (ORDER + 1)
56
- node_j = node_index_in_elt - (ORDER + 1) * node_i
139
+ node_i, node_j = self._node_ij(node_index_in_elt)
57
140
  return LOBATTO_WEIGHT[node_i] * LOBATTO_WEIGHT[node_j]
58
141
 
59
142
  def node_quadrature_weight_linear(
@@ -66,12 +149,6 @@ class SquareBipolynomialShapeFunctions:
66
149
 
67
150
  return cache.get_func(node_quadrature_weight, self.name)
68
151
 
69
- @wp.func
70
- def _vertex_coords_f(vidx_in_cell: int):
71
- x = vidx_in_cell // 2
72
- y = vidx_in_cell - 2 * x
73
- return wp.vec2(float(x), float(y))
74
-
75
152
  def make_trace_node_quadrature_weight(self):
76
153
  ORDER = self.ORDER
77
154
  LOBATTO_WEIGHT = self.LOBATTO_WEIGHT
@@ -82,11 +159,10 @@ class SquareBipolynomialShapeFunctions:
82
159
  # We're either on a side interior or at a vertex
83
160
  # I.e., either both indices are at extrema, or only one is
84
161
  # Pick the interior one if possible, if both are at extrema pick any one
85
- node_i = node_index_in_elt // (ORDER + 1)
162
+ node_i, node_j = self._node_ij(node_index_in_elt)
86
163
  if node_i > 0 and node_i < ORDER:
87
164
  return LOBATTO_WEIGHT[node_i]
88
165
 
89
- node_j = node_index_in_elt - (ORDER + 1) * node_i
90
166
  return LOBATTO_WEIGHT[node_j]
91
167
 
92
168
  def trace_node_quadrature_weight_linear(
@@ -116,8 +192,7 @@ class SquareBipolynomialShapeFunctions:
116
192
  coords: Coords,
117
193
  node_index_in_elt: int,
118
194
  ):
119
- node_i = node_index_in_elt // ORDER_PLUS_ONE
120
- node_j = node_index_in_elt - ORDER_PLUS_ONE * node_i
195
+ node_i, node_j = self._node_ij(node_index_in_elt)
121
196
 
122
197
  w = float(1.0)
123
198
  for k in range(ORDER_PLUS_ONE):
@@ -154,8 +229,7 @@ class SquareBipolynomialShapeFunctions:
154
229
  coords: Coords,
155
230
  node_index_in_elt: int,
156
231
  ):
157
- node_i = node_index_in_elt // ORDER_PLUS_ONE
158
- node_j = node_index_in_elt - ORDER_PLUS_ONE * node_i
232
+ node_i, node_j = self._node_ij(node_index_in_elt)
159
233
 
160
234
  prefix_x = float(1.0)
161
235
  prefix_y = float(1.0)
@@ -232,18 +306,13 @@ class SquareBipolynomialShapeFunctions:
232
306
  return np.concatenate(cells)[np.newaxis, :], np.array([cell_type], dtype=np.int8)
233
307
 
234
308
 
235
- class SquareSerendipityShapeFunctions:
309
+ class SquareSerendipityShapeFunctions(SquareShapeFunction):
236
310
  """
237
311
  Serendipity element ~ tensor product space without interior nodes
238
312
  Side shape functions are usual Lagrange shape functions times a linear function in the normal direction
239
313
  Corner shape functions are bilinear shape functions times a function of (x^{d-1} + y^{d-1})
240
314
  """
241
315
 
242
- # Node categories
243
- VERTEX = wp.constant(0)
244
- EDGE_X = wp.constant(1)
245
- EDGE_Y = wp.constant(2)
246
-
247
316
  def __init__(self, degree: int, family: Polynomial):
248
317
  if not is_closed(family):
249
318
  raise ValueError("A closed polynomial family is required to define serendipity elements")
@@ -257,6 +326,10 @@ class SquareSerendipityShapeFunctions:
257
326
  self.NODES_PER_ELEMENT = wp.constant(4 * degree)
258
327
  self.NODES_PER_SIDE = wp.constant(degree + 1)
259
328
 
329
+ self.VERTEX_NODE_COUNT = wp.constant(1)
330
+ self.EDGE_NODE_COUNT = wp.constant(degree - 1)
331
+ self.INTERIOR_NODE_COUNT = wp.constant(0)
332
+
260
333
  lobatto_coords, lobatto_weight = quadrature_1d(point_count=degree + 1, family=family)
261
334
  lagrange_scale = lagrange_scales(lobatto_coords)
262
335
 
@@ -279,39 +352,33 @@ class SquareSerendipityShapeFunctions:
279
352
  node_index_in_elt: int,
280
353
  ):
281
354
  if node_index_in_elt < 4:
282
- return SquareSerendipityShapeFunctions.VERTEX, node_index_in_elt
283
-
284
- type_index = (node_index_in_elt - 4) // 2
285
- side = node_index_in_elt - 4 - 2 * type_index
286
- return SquareSerendipityShapeFunctions.EDGE_X + side, type_index
355
+ return SquareSerendipityShapeFunctions.VERTEX, node_index_in_elt, 0
287
356
 
288
- return node_type_and_index
357
+ edge_index = (node_index_in_elt - 4) // 2
358
+ edge_axis = node_index_in_elt - 4 - 2 * edge_index
289
359
 
290
- @wp.func
291
- def side_offset_and_index(type_index: int):
292
- index_in_side = type_index // 2
293
- side_offset = type_index - 2 * index_in_side
360
+ index_in_side = edge_index // 2
361
+ side_offset = edge_index - 2 * index_in_side
362
+ return SquareSerendipityShapeFunctions.EDGE_X + edge_axis, side_offset, index_in_side
294
363
 
295
- return side_offset, index_in_side
364
+ return node_type_and_index
296
365
 
297
366
  def _get_node_lobatto_indices(self):
298
367
  ORDER = self.ORDER
299
368
 
300
369
  @cache.dynamic_func(suffix=self.name)
301
- def node_lobatto_indices(node_type: int, type_index: int):
370
+ def node_lobatto_indices(node_type: int, type_instance: int, type_index: int):
302
371
  if node_type == SquareSerendipityShapeFunctions.VERTEX:
303
- node_i = type_index // 2
304
- node_j = type_index - 2 * node_i
372
+ node_i = type_instance // 2
373
+ node_j = type_instance - 2 * node_i
305
374
  return node_i * ORDER, node_j * ORDER
306
375
 
307
- side_offset, index_in_side = SquareSerendipityShapeFunctions.side_offset_and_index(type_index)
308
-
309
376
  if node_type == SquareSerendipityShapeFunctions.EDGE_X:
310
- node_i = 1 + index_in_side
311
- node_j = side_offset * ORDER
377
+ node_i = 1 + type_index
378
+ node_j = type_instance * ORDER
312
379
  else:
313
- node_j = 1 + index_in_side
314
- node_i = side_offset * ORDER
380
+ node_j = 1 + type_index
381
+ node_i = type_instance * ORDER
315
382
 
316
383
  return node_i, node_j
317
384
 
@@ -324,8 +391,8 @@ class SquareSerendipityShapeFunctions:
324
391
  def node_coords_in_element(
325
392
  node_index_in_elt: int,
326
393
  ):
327
- node_type, type_index = self.node_type_and_type_index(node_index_in_elt)
328
- node_i, node_j = self._node_lobatto_indices(node_type, type_index)
394
+ node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
395
+ node_i, node_j = self._node_lobatto_indices(node_type, type_instance, type_index)
329
396
  return Coords(LOBATTO_COORDS[node_i], LOBATTO_COORDS[node_j], 0.0)
330
397
 
331
398
  return node_coords_in_element
@@ -337,7 +404,7 @@ class SquareSerendipityShapeFunctions:
337
404
  def node_quadrature_weight(
338
405
  node_index_in_elt: int,
339
406
  ):
340
- node_type, type_index = self.node_type_and_type_index(node_index_in_elt)
407
+ node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
341
408
  if node_type == SquareSerendipityShapeFunctions.VERTEX:
342
409
  return 0.25 / float(ORDER * ORDER)
343
410
 
@@ -352,12 +419,11 @@ class SquareSerendipityShapeFunctions:
352
419
  def trace_node_quadrature_weight(
353
420
  node_index_in_elt: int,
354
421
  ):
355
- node_type, type_index = self.node_type_and_type_index(node_index_in_elt)
422
+ node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
356
423
  if node_type == SquareSerendipityShapeFunctions.VERTEX:
357
424
  return LOBATTO_WEIGHT[0]
358
425
 
359
- side_offset, index_in_side = SquareSerendipityShapeFunctions.side_offset_and_index(type_index)
360
- return LOBATTO_WEIGHT[1 + index_in_side]
426
+ return LOBATTO_WEIGHT[1 + type_index]
361
427
 
362
428
  return trace_node_quadrature_weight
363
429
 
@@ -376,9 +442,8 @@ class SquareSerendipityShapeFunctions:
376
442
  coords: Coords,
377
443
  node_index_in_elt: int,
378
444
  ):
379
- node_type, type_index = self.node_type_and_type_index(node_index_in_elt)
380
-
381
- node_i, node_j = self._node_lobatto_indices(node_type, type_index)
445
+ node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
446
+ node_i, node_j = self._node_lobatto_indices(node_type, type_instance, type_index)
382
447
 
383
448
  if node_type == SquareSerendipityShapeFunctions.VERTEX:
384
449
  cx = wp.select(node_i == 0, coords[0], 1.0 - coords[0])
@@ -429,9 +494,8 @@ class SquareSerendipityShapeFunctions:
429
494
  coords: Coords,
430
495
  node_index_in_elt: int,
431
496
  ):
432
- node_type, type_index = self.node_type_and_type_index(node_index_in_elt)
433
-
434
- node_i, node_j = self._node_lobatto_indices(node_type, type_index)
497
+ node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
498
+ node_i, node_j = self._node_lobatto_indices(node_type, type_instance, type_index)
435
499
 
436
500
  if node_type == SquareSerendipityShapeFunctions.VERTEX:
437
501
  cx = wp.select(node_i == 0, coords[0], 1.0 - coords[0])
@@ -533,7 +597,7 @@ class SquareSerendipityShapeFunctions:
533
597
  return tris, np.full(tris.shape[0], cell_type, dtype=np.int8)
534
598
 
535
599
 
536
- class SquareNonConformingPolynomialShapeFunctions:
600
+ class SquareNonConformingPolynomialShapeFunctions(ShapeFunction):
537
601
  # embeds the largest equilateral triangle centered at (0.5, 0.5) into the reference square
538
602
  _tri_height = 0.75
539
603
  _tri_side = 2.0 / math.sqrt(3.0) * _tri_height
@@ -542,7 +606,7 @@ class SquareNonConformingPolynomialShapeFunctions:
542
606
  _TRI_OFFSET = wp.constant(wp.vec2(0.5 - 0.5 * _tri_side, 0.5 - _tri_height / 3.0))
543
607
 
544
608
  def __init__(self, degree: int):
545
- self._tri_shape = Triangle2DPolynomialShapeFunctions(degree=degree)
609
+ self._tri_shape = TrianglePolynomialShapeFunctions(degree=degree)
546
610
  self.ORDER = self._tri_shape.ORDER
547
611
  self.NODES_PER_ELEMENT = self._tri_shape.NODES_PER_ELEMENT
548
612
 
@@ -580,7 +644,7 @@ class SquareNonConformingPolynomialShapeFunctions:
580
644
  node_index_in_elt: int,
581
645
  ):
582
646
  node_type, type_index = self._tri_shape.node_type_and_type_index(node_index_in_elt)
583
- if node_type == Triangle2DPolynomialShapeFunctions.VERTEX:
647
+ if node_type == TrianglePolynomialShapeFunctions.VERTEX:
584
648
  return 0.18518521
585
649
  return 0.14814811
586
650
 
@@ -641,3 +705,207 @@ class SquareNonConformingPolynomialShapeFunctions:
641
705
  return wp.transpose(SQUARE_TO_TRI) * grad
642
706
 
643
707
  return element_inner_weight_gradient
708
+
709
+
710
+ class SquareNedelecFirstKindShapeFunctions(SquareShapeFunction):
711
+ value = ShapeFunction.Value.CovariantVector
712
+
713
+ def __init__(self, degree: int):
714
+ if degree != 1:
715
+ raise NotImplementedError("Only linear Nédélec implemented right now")
716
+
717
+ self.ORDER = wp.constant(degree)
718
+ self.NODES_PER_ELEMENT = wp.constant(4)
719
+ self.NODES_PER_SIDE = wp.constant(1)
720
+
721
+ self.VERTEX_NODE_COUNT = wp.constant(0)
722
+ self.EDGE_NODE_COUNT = wp.constant(1)
723
+ self.INTERIOR_NODE_COUNT = wp.constant(0)
724
+
725
+ self.node_type_and_type_index = self._get_node_type_and_type_index()
726
+
727
+ @property
728
+ def name(self) -> str:
729
+ return f"SquareN1_{self.ORDER}"
730
+
731
+ def _get_node_type_and_type_index(self):
732
+ @cache.dynamic_func(suffix=self.name)
733
+ def node_type_and_index(
734
+ node_index_in_elt: int,
735
+ ):
736
+ axis = node_index_in_elt // 2
737
+ offset = node_index_in_elt - 2 * axis
738
+ return SquareShapeFunction.EDGE_X + axis, offset, 0
739
+
740
+ return node_type_and_index
741
+
742
+ def make_node_coords_in_element(self):
743
+ @cache.dynamic_func(suffix=self.name)
744
+ def node_coords_in_element(
745
+ node_index_in_elt: int,
746
+ ):
747
+ node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
748
+ axis = node_type - SquareShapeFunction.EDGE_X
749
+
750
+ coords = Coords()
751
+ coords[axis] = 0.5
752
+ coords[1 - axis] = float(type_instance)
753
+
754
+ return node_coords_in_element
755
+
756
+ def make_node_quadrature_weight(self):
757
+ NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
758
+
759
+ @cache.dynamic_func(suffix=self.name)
760
+ def node_quadrature_weight(node_index_in_element: int):
761
+ return 1.0 / float(NODES_PER_ELEMENT)
762
+
763
+ return node_quadrature_weight
764
+
765
+ def make_trace_node_quadrature_weight(self):
766
+ NODES_PER_SIDE = self.NODES_PER_SIDE
767
+
768
+ @cache.dynamic_func(suffix=self.name)
769
+ def trace_node_quadrature_weight(node_index_in_element: int):
770
+ return 1.0 / float(NODES_PER_SIDE)
771
+
772
+ return trace_node_quadrature_weight
773
+
774
+ def make_element_inner_weight(self):
775
+ @cache.dynamic_func(suffix=self.name)
776
+ def element_inner_weight(
777
+ coords: Coords,
778
+ node_index_in_elt: int,
779
+ ):
780
+ node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
781
+
782
+ axis = node_type - SquareShapeFunction.EDGE_X
783
+ a = float(2 * type_instance - 1)
784
+ b = float(1 - type_instance)
785
+
786
+ w = wp.vec2(0.0)
787
+ w[axis] = b + a * coords[1 - axis]
788
+
789
+ return w
790
+
791
+ return element_inner_weight
792
+
793
+ def make_element_inner_weight_gradient(self):
794
+ @cache.dynamic_func(suffix=self.name)
795
+ def element_inner_weight_gradient(
796
+ coords: Coords,
797
+ node_index_in_elt: int,
798
+ ):
799
+ node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
800
+
801
+ axis = node_type - SquareShapeFunction.EDGE_X
802
+ a = float(2 * type_instance - 1)
803
+
804
+ grad = wp.mat22(0.0)
805
+ grad[axis, 1 - axis] = a
806
+
807
+ return grad
808
+
809
+ return element_inner_weight_gradient
810
+
811
+
812
+ class SquareRaviartThomasShapeFunctions(SquareShapeFunction):
813
+ value = ShapeFunction.Value.ContravariantVector
814
+
815
+ def __init__(self, degree: int):
816
+ if degree != 1:
817
+ raise NotImplementedError("Only linear Nédélec implemented right now")
818
+
819
+ self.ORDER = wp.constant(degree)
820
+ self.NODES_PER_ELEMENT = wp.constant(4)
821
+ self.NODES_PER_SIDE = wp.constant(1)
822
+
823
+ self.VERTEX_NODE_COUNT = wp.constant(0)
824
+ self.EDGE_NODE_COUNT = wp.constant(1)
825
+ self.INTERIOR_NODE_COUNT = wp.constant(0)
826
+
827
+ self.node_type_and_type_index = self._get_node_type_and_type_index()
828
+
829
+ @property
830
+ def name(self) -> str:
831
+ return f"SquareRT_{self.ORDER}"
832
+
833
+ def _get_node_type_and_type_index(self):
834
+ @cache.dynamic_func(suffix=self.name)
835
+ def node_type_and_index(
836
+ node_index_in_elt: int,
837
+ ):
838
+ axis = node_index_in_elt // 2
839
+ offset = node_index_in_elt - 2 * axis
840
+ return SquareShapeFunction.EDGE_X + axis, offset, 0
841
+
842
+ return node_type_and_index
843
+
844
+ def make_node_coords_in_element(self):
845
+ @cache.dynamic_func(suffix=self.name)
846
+ def node_coords_in_element(
847
+ node_index_in_elt: int,
848
+ ):
849
+ node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
850
+ axis = node_type - SquareShapeFunction.EDGE_X
851
+
852
+ coords = Coords()
853
+ coords[axis] = 0.5
854
+ coords[1 - axis] = float(type_instance)
855
+
856
+ return node_coords_in_element
857
+
858
+ def make_node_quadrature_weight(self):
859
+ NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
860
+
861
+ @cache.dynamic_func(suffix=self.name)
862
+ def node_quadrature_weight(node_index_in_element: int):
863
+ return 1.0 / float(NODES_PER_ELEMENT)
864
+
865
+ return node_quadrature_weight
866
+
867
+ def make_trace_node_quadrature_weight(self):
868
+ NODES_PER_SIDE = self.NODES_PER_SIDE
869
+
870
+ @cache.dynamic_func(suffix=self.name)
871
+ def trace_node_quadrature_weight(node_index_in_element: int):
872
+ return 1.0 / float(NODES_PER_SIDE)
873
+
874
+ return trace_node_quadrature_weight
875
+
876
+ def make_element_inner_weight(self):
877
+ @cache.dynamic_func(suffix=self.name)
878
+ def element_inner_weight(
879
+ coords: Coords,
880
+ node_index_in_elt: int,
881
+ ):
882
+ node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
883
+
884
+ axis = node_type - SquareShapeFunction.EDGE_X
885
+ a = float(2 * type_instance - 1)
886
+ b = float(1 - type_instance)
887
+
888
+ w = wp.vec2(0.0)
889
+ w[1 - axis] = b + a * coords[1 - axis]
890
+
891
+ return w
892
+
893
+ return element_inner_weight
894
+
895
+ def make_element_inner_weight_gradient(self):
896
+ @cache.dynamic_func(suffix=self.name)
897
+ def element_inner_weight_gradient(
898
+ coords: Coords,
899
+ node_index_in_elt: int,
900
+ ):
901
+ node_type, type_instance, type_index = self.node_type_and_type_index(node_index_in_elt)
902
+
903
+ axis = node_type - SquareShapeFunction.EDGE_X
904
+ a = float(2 * type_instance - 1)
905
+
906
+ grad = wp.mat22(0.0)
907
+ grad[1 - axis, 1 - axis] = a
908
+
909
+ return grad
910
+
911
+ return element_inner_weight_gradient