warp-lang 1.4.2__py3-none-win_amd64.whl → 1.5.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (158)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1783 -2
  8. warp/codegen.py +177 -45
  9. warp/config.py +2 -2
  10. warp/context.py +321 -73
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +2 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -5
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +600 -0
  82. warp/native/cuda_util.cpp +14 -0
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1857 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +137 -65
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/integrator_euler.py +4 -2
  114. warp/sim/integrator_featherstone.py +115 -44
  115. warp/sim/integrator_vbd.py +6 -0
  116. warp/sim/model.py +88 -15
  117. warp/stubs.py +569 -4
  118. warp/tape.py +12 -7
  119. warp/tests/assets/pixel.npy +0 -0
  120. warp/tests/aux_test_instancing_gc.py +18 -0
  121. warp/tests/test_array.py +39 -0
  122. warp/tests/test_codegen.py +81 -1
  123. warp/tests/test_codegen_instancing.py +30 -0
  124. warp/tests/test_collision.py +110 -0
  125. warp/tests/test_coloring.py +241 -0
  126. warp/tests/test_context.py +34 -0
  127. warp/tests/test_examples.py +18 -4
  128. warp/tests/test_fem.py +453 -113
  129. warp/tests/test_func.py +13 -0
  130. warp/tests/test_generics.py +52 -0
  131. warp/tests/test_iter.py +68 -0
  132. warp/tests/test_mat_scalar_ops.py +1 -1
  133. warp/tests/test_mesh_query_point.py +1 -1
  134. warp/tests/test_module_hashing.py +23 -0
  135. warp/tests/test_paddle.py +27 -87
  136. warp/tests/test_print.py +56 -1
  137. warp/tests/test_spatial.py +1 -1
  138. warp/tests/test_tile.py +700 -0
  139. warp/tests/test_tile_mathdx.py +144 -0
  140. warp/tests/test_tile_mlp.py +383 -0
  141. warp/tests/test_tile_reduce.py +374 -0
  142. warp/tests/test_tile_shared_memory.py +190 -0
  143. warp/tests/test_vbd.py +12 -20
  144. warp/tests/test_volume.py +43 -0
  145. warp/tests/unittest_suites.py +19 -2
  146. warp/tests/unittest_utils.py +4 -0
  147. warp/types.py +338 -72
  148. warp/utils.py +22 -1
  149. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
  150. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/RECORD +153 -126
  151. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
  152. warp/fem/field/test.py +0 -180
  153. warp/fem/field/trial.py +0 -183
  154. warp/fem/space/collocated_function_space.py +0 -102
  155. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  156. warp/fem/space/trimesh_2d_function_space.py +0 -153
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
@@ -7,21 +7,21 @@ from warp.fem.polynomial import is_closed
7
7
  from warp.fem.types import ElementIndex
8
8
 
9
9
  from .shape import (
10
- CubeSerendipityShapeFunctions,
10
+ CubeShapeFunction,
11
11
  CubeTripolynomialShapeFunctions,
12
- ShapeFunction,
13
12
  )
14
13
  from .topology import SpaceTopology, forward_base_topology
15
14
 
16
15
 
17
16
  class Grid3DSpaceTopology(SpaceTopology):
18
- def __init__(self, grid: Grid3D, shape: ShapeFunction):
19
- if not is_closed(shape.family):
20
- raise ValueError("A closed polynomial family is required to define a continuous function space")
21
-
22
- super().__init__(grid, shape.NODES_PER_ELEMENT)
17
+ def __init__(self, grid: Grid3D, shape: CubeShapeFunction):
23
18
  self._shape = shape
24
- self._grid = grid
19
+ super().__init__(grid, shape.NODES_PER_ELEMENT)
20
+ self.element_node_index = self._make_element_node_index()
21
+
22
+ @property
23
+ def name(self):
24
+ return f"{self.geometry.name}_{self._shape.name}"
25
25
 
26
26
  @wp.func
27
27
  def _vertex_coords(vidx_in_cell: int):
@@ -38,10 +38,107 @@ class Grid3DSpaceTopology(SpaceTopology):
38
38
  corner = Grid3D.get_cell(res, cell_index) + Grid3DSpaceTopology._vertex_coords(vidx_in_cell)
39
39
  return Grid3D._from_3d_index(strides, corner)
40
40
 
41
+ def node_count(self) -> int:
42
+ return (
43
+ self.geometry.vertex_count() * self._shape.VERTEX_NODE_COUNT
44
+ + self.geometry.edge_count() * self._shape.EDGE_NODE_COUNT
45
+ + self.geometry.side_count() * self._shape.FACE_NODE_COUNT
46
+ + self.geometry.cell_count() * self._shape.INTERIOR_NODE_COUNT
47
+ )
48
+
49
+ def _make_element_node_index(self):
50
+ VERTEX_NODE_COUNT = self._shape.VERTEX_NODE_COUNT
51
+ EDGE_NODE_COUNT = self._shape.EDGE_NODE_COUNT
52
+ FACE_NODE_COUNT = self._shape.FACE_NODE_COUNT
53
+ INTERIOR_NODE_COUNT = self._shape.INTERIOR_NODE_COUNT
54
+
55
+ @cache.dynamic_func(suffix=self.name)
56
+ def element_node_index(
57
+ cell_arg: Grid3D.CellArg,
58
+ topo_arg: Grid3DSpaceTopology.TopologyArg,
59
+ element_index: ElementIndex,
60
+ node_index_in_elt: int,
61
+ ):
62
+ res = cell_arg.res
63
+ cell = Grid3D.get_cell(res, element_index)
64
+
65
+ node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)
66
+
67
+ if wp.static(VERTEX_NODE_COUNT > 0):
68
+ if node_type == CubeShapeFunction.VERTEX:
69
+ return (
70
+ Grid3DSpaceTopology._vertex_index(cell_arg, element_index, type_instance) * VERTEX_NODE_COUNT
71
+ + type_index
72
+ )
73
+
74
+ res = cell_arg.res
75
+ vertex_count = (res[0] + 1) * (res[1] + 1) * (res[2] + 1)
76
+ global_offset = vertex_count * VERTEX_NODE_COUNT
77
+
78
+ if wp.static(EDGE_NODE_COUNT > 0):
79
+ if node_type == CubeShapeFunction.EDGE:
80
+ axis = CubeShapeFunction._edge_axis(type_instance)
81
+ node_all = CubeShapeFunction._edge_coords(type_instance, type_index)
82
+
83
+ res = cell_arg.res
84
+
85
+ edge_index = 0
86
+ if axis > 0:
87
+ edge_index += (res[1] + 1) * (res[2] + 1) * res[0]
88
+ if axis > 1:
89
+ edge_index += (res[0] + 1) * (res[2] + 1) * res[1]
90
+
91
+ res_loc = Grid3D._world_to_local(axis, res)
92
+ cell_loc = Grid3D._world_to_local(axis, cell)
93
+
94
+ edge_index += (res_loc[1] + 1) * (res_loc[2] + 1) * cell_loc[0]
95
+ edge_index += (res_loc[2] + 1) * (cell_loc[1] + node_all[1])
96
+ edge_index += cell_loc[2] + node_all[2]
41
97
 
42
- class GridTripolynomialSpaceTopology(Grid3DSpaceTopology):
98
+ return global_offset + EDGE_NODE_COUNT * edge_index + type_index
99
+
100
+ edge_count = (
101
+ (res[0] + 1) * (res[1] + 1) * (res[2])
102
+ + (res[0]) * (res[1] + 1) * (res[2] + 1)
103
+ + (res[0] + 1) * (res[1]) * (res[2] + 1)
104
+ )
105
+ global_offset += edge_count * EDGE_NODE_COUNT
106
+
107
+ if wp.static(FACE_NODE_COUNT > 0):
108
+ if node_type == CubeShapeFunction.FACE:
109
+ axis = CubeShapeFunction._face_axis(type_instance)
110
+ face_offset = CubeShapeFunction._face_offset(type_instance)
111
+
112
+ face_index = 0
113
+ if axis > 0:
114
+ face_index += (res[0] + 1) * res[1] * res[2]
115
+ if axis > 1:
116
+ face_index += (res[1] + 1) * res[2] * res[0]
117
+
118
+ res_loc = Grid3D._world_to_local(axis, res)
119
+ cell_loc = Grid3D._world_to_local(axis, cell)
120
+
121
+ face_index += res_loc[1] * res_loc[2] * (cell_loc[0] + face_offset)
122
+ face_index += res_loc[2] * cell_loc[1]
123
+ face_index += cell_loc[2]
124
+
125
+ return global_offset + FACE_NODE_COUNT * face_index + type_index
126
+
127
+ face_count = (
128
+ (res[0] + 1) * res[1] * res[2] + res[0] * (res[1] + 1) * res[2] + res[0] * res[1] * (res[2] + 1)
129
+ )
130
+ global_offset += face_count * FACE_NODE_COUNT
131
+
132
+ # interior
133
+ return global_offset + element_index * INTERIOR_NODE_COUNT + type_index
134
+
135
+ return element_node_index
136
+
137
+
138
+ class GridTripolynomialSpaceTopology(SpaceTopology):
43
139
  def __init__(self, grid: Grid3D, shape: CubeTripolynomialShapeFunctions):
44
- super().__init__(grid, shape)
140
+ super().__init__(grid, shape.NODES_PER_ELEMENT)
141
+ self._shape = shape
45
142
 
46
143
  self.element_node_index = self._make_element_node_index()
47
144
 
@@ -58,7 +155,7 @@ class GridTripolynomialSpaceTopology(Grid3DSpaceTopology):
58
155
  @cache.dynamic_func(suffix=self.name)
59
156
  def element_node_index(
60
157
  cell_arg: Grid3D.CellArg,
61
- topo_arg: Grid3DSpaceTopology.TopologyArg,
158
+ topo_arg: self.TopologyArg,
62
159
  element_index: ElementIndex,
63
160
  node_index_in_elt: int,
64
161
  ):
@@ -105,63 +202,11 @@ class GridTripolynomialSpaceTopology(Grid3DSpaceTopology):
105
202
  return np.meshgrid(X, Y, Z, indexing="ij")
106
203
 
107
204
 
108
- class Grid3DSerendipitySpaceTopology(Grid3DSpaceTopology):
109
- def __init__(self, grid: Grid3D, shape: CubeSerendipityShapeFunctions):
110
- super().__init__(grid, shape)
111
-
112
- self.element_node_index = self._make_element_node_index()
113
-
114
- def node_count(self) -> int:
115
- return self.geometry.vertex_count() + (self._shape.ORDER - 1) * self.geometry.edge_count()
116
-
117
- def _make_element_node_index(self):
118
- ORDER = self._shape.ORDER
119
-
120
- @cache.dynamic_func(suffix=self.name)
121
- def element_node_index(
122
- cell_arg: Grid3D.CellArg,
123
- topo_arg: Grid3DSpaceTopology.TopologyArg,
124
- element_index: ElementIndex,
125
- node_index_in_elt: int,
126
- ):
127
- res = cell_arg.res
128
- cell = Grid3D.get_cell(res, element_index)
129
-
130
- node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)
131
-
132
- if node_type == CubeSerendipityShapeFunctions.VERTEX:
133
- return Grid3DSpaceTopology._vertex_index(cell_arg, element_index, type_index)
134
-
135
- axis = CubeSerendipityShapeFunctions._edge_axis(node_type)
136
- node_all = CubeSerendipityShapeFunctions._edge_coords(type_index)
137
-
138
- res = cell_arg.res
139
-
140
- edge_index = 0
141
- if axis > 0:
142
- edge_index += (res[1] + 1) * (res[2] + 1) * res[0]
143
- if axis > 1:
144
- edge_index += (res[0] + 1) * (res[2] + 1) * res[1]
145
-
146
- res_loc = Grid3D._world_to_local(axis, res)
147
- cell_loc = Grid3D._world_to_local(axis, cell)
148
-
149
- edge_index += (res_loc[1] + 1) * (res_loc[2] + 1) * cell_loc[0]
150
- edge_index += (res_loc[2] + 1) * (cell_loc[1] + node_all[1])
151
- edge_index += cell_loc[2] + node_all[2]
152
-
153
- vertex_count = (res[0] + 1) * (res[1] + 1) * (res[2] + 1)
154
-
155
- return vertex_count + (ORDER - 1) * edge_index + (node_all[0] - 1)
156
-
157
- return element_node_index
158
-
159
-
160
- def make_grid_3d_space_topology(grid: Grid3D, shape: ShapeFunction):
161
- if isinstance(shape, CubeSerendipityShapeFunctions):
162
- return forward_base_topology(Grid3DSerendipitySpaceTopology, grid, shape)
163
-
164
- if isinstance(shape, CubeTripolynomialShapeFunctions):
205
+ def make_grid_3d_space_topology(grid: Grid3D, shape: CubeShapeFunction):
206
+ if isinstance(shape, CubeTripolynomialShapeFunctions) and is_closed(shape.family):
165
207
  return forward_base_topology(GridTripolynomialSpaceTopology, grid, shape)
166
208
 
209
+ if isinstance(shape, CubeShapeFunction):
210
+ return forward_base_topology(Grid3DSpaceTopology, grid, shape)
211
+
167
212
  raise ValueError(f"Unsupported shape function {shape.name}")
@@ -6,14 +6,9 @@ from warp.fem.geometry.hexmesh import (
6
6
  FACE_ORIENTATION,
7
7
  FACE_TRANSLATION,
8
8
  )
9
- from warp.fem.polynomial import is_closed
10
9
  from warp.fem.types import ElementIndex
11
10
 
12
- from .shape import (
13
- CubeSerendipityShapeFunctions,
14
- CubeTripolynomialShapeFunctions,
15
- ShapeFunction,
16
- )
11
+ from .shape import CubeShapeFunction
17
12
  from .topology import SpaceTopology, forward_base_topology
18
13
 
19
14
  _FACE_ORIENTATION_I = wp.constant(wp.mat(shape=(16, 2), dtype=int)(FACE_ORIENTATION))
@@ -42,31 +37,36 @@ class HexmeshSpaceTopology(SpaceTopology):
42
37
  def __init__(
43
38
  self,
44
39
  mesh: Hexmesh,
45
- shape: ShapeFunction,
46
- need_hex_edge_indices: bool = True,
47
- need_hex_face_indices: bool = True,
40
+ shape: CubeShapeFunction,
48
41
  ):
49
- if not is_closed(shape.family):
50
- raise ValueError("A closed polynomial family is required to define a continuous function space")
51
-
42
+ self._shape = shape
52
43
  super().__init__(mesh, shape.NODES_PER_ELEMENT)
53
44
  self._mesh = mesh
54
- self.shape = shape
55
45
 
56
- if need_hex_edge_indices:
46
+ need_edge_indices = shape.EDGE_NODE_COUNT > 0
47
+ need_face_indices = shape.FACE_NODE_COUNT > 0
48
+
49
+ if need_edge_indices:
57
50
  self._hex_edge_indices = self._mesh.hex_edge_indices
58
51
  self._edge_count = self._mesh.edge_count()
59
52
  else:
60
53
  self._hex_edge_indices = wp.empty(shape=(0, 0), dtype=int)
61
54
  self._edge_count = 0
62
55
 
63
- if need_hex_face_indices:
56
+ if need_face_indices:
64
57
  self._compute_hex_face_indices()
65
58
  else:
66
59
  self._hex_face_indices = wp.empty(shape=(0, 0), dtype=wp.vec2i)
67
60
 
68
61
  self._compute_hex_face_indices()
69
62
 
63
+ self.element_node_index = self._make_element_node_index()
64
+ self.element_node_sign = self._make_element_node_sign()
65
+
66
+ @property
67
+ def name(self):
68
+ return f"{self.geometry.name}_{self._shape.name}"
69
+
70
70
  @cache.cached_arg_value
71
71
  def topo_arg_value(self, device):
72
72
  arg = HexmeshTopologyArg()
@@ -102,57 +102,50 @@ class HexmeshSpaceTopology(SpaceTopology):
102
102
  ):
103
103
  f = wp.tid()
104
104
 
105
+ # face indices from CubeShapeFunction always have positive orientation,
106
+ # while Hexmesh faces are oriented to point "outside"
107
+ # We need to flip orientation for faces at offset 0
108
+
105
109
  hx0 = face_hex_indices[f][0]
106
110
  local_face_0 = face_hex_face_ori[f][0]
107
111
  ori_0 = face_hex_face_ori[f][1]
108
112
 
109
- hex_face_indices[hx0, local_face_0] = wp.vec2i(f, ori_0)
113
+ local_face_offset_0 = CubeShapeFunction._face_offset(local_face_0)
114
+ flip_0 = ori_0 ^ (1 - local_face_offset_0)
115
+
116
+ hex_face_indices[hx0, local_face_0] = wp.vec2i(f, flip_0)
110
117
 
111
118
  hx1 = face_hex_indices[f][1]
112
119
  local_face_1 = face_hex_face_ori[f][2]
113
120
  ori_1 = face_hex_face_ori[f][3]
114
121
 
115
- hex_face_indices[hx1, local_face_1] = wp.vec2i(f, ori_1)
122
+ local_face_offset_1 = CubeShapeFunction._face_offset(local_face_1)
123
+ flip_1 = ori_1 ^ (1 - local_face_offset_1)
116
124
 
117
-
118
- class HexmeshTripolynomialSpaceTopology(HexmeshSpaceTopology):
119
- def __init__(self, mesh: Hexmesh, shape: CubeTripolynomialShapeFunctions):
120
- super().__init__(mesh, shape, need_hex_edge_indices=shape.ORDER >= 2, need_hex_face_indices=shape.ORDER >= 2)
121
-
122
- self.element_node_index = self._make_element_node_index()
125
+ hex_face_indices[hx1, local_face_1] = wp.vec2i(f, flip_1)
123
126
 
124
127
  def node_count(self) -> int:
125
- ORDER = self.shape.ORDER
126
- INTERIOR_NODES_PER_EDGE = max(0, ORDER - 1)
127
- INTERIOR_NODES_PER_FACE = INTERIOR_NODES_PER_EDGE**2
128
- INTERIOR_NODES_PER_CELL = INTERIOR_NODES_PER_EDGE**3
129
-
130
128
  return (
131
- self._mesh.vertex_count()
132
- + self._mesh.edge_count() * INTERIOR_NODES_PER_EDGE
133
- + self._mesh.side_count() * INTERIOR_NODES_PER_FACE
134
- + self._mesh.cell_count() * INTERIOR_NODES_PER_CELL
129
+ self._mesh.vertex_count() * self._shape.VERTEX_NODE_COUNT
130
+ + self._mesh.edge_count() * self._shape.EDGE_NODE_COUNT
131
+ + self._mesh.side_count() * self._shape.FACE_NODE_COUNT
132
+ + self._mesh.cell_count() * self._shape.INTERIOR_NODE_COUNT
135
133
  )
136
134
 
137
135
  @wp.func
138
- def _rotate_face_index(type_index: int, ori: int, size: int):
139
- i = type_index // size
140
- j = type_index - i * size
141
- coords = wp.vec2i(i, j)
142
-
136
+ def _rotate_face_coordinates(ori: int, offset: int, coords: wp.vec2i):
143
137
  fv = ori // 2
144
138
 
145
- # face indices from shape function always have positive orientation, drop `ori % 2`
146
- rot_i = wp.dot(_FACE_ORIENTATION_I[4 * fv], coords) + _FACE_TRANSLATION_I[fv, 0]
147
- rot_j = wp.dot(_FACE_ORIENTATION_I[4 * fv + 1], coords) + _FACE_TRANSLATION_I[fv, 1]
139
+ rot_i = wp.dot(_FACE_ORIENTATION_I[2 * ori], coords)
140
+ rot_j = wp.dot(_FACE_ORIENTATION_I[2 * ori + 1], coords)
148
141
 
149
- return rot_i * size + rot_j
142
+ return wp.vec2i(rot_i, rot_j) + _FACE_TRANSLATION_I[fv]
150
143
 
151
144
  def _make_element_node_index(self):
152
- ORDER = self.shape.ORDER
153
- INTERIOR_NODES_PER_EDGE = wp.constant(max(0, ORDER - 1))
154
- INTERIOR_NODES_PER_FACE = wp.constant(INTERIOR_NODES_PER_EDGE**2)
155
- INTERIOR_NODES_PER_CELL = wp.constant(INTERIOR_NODES_PER_EDGE**3)
145
+ VERTEX_NODE_COUNT = self._shape.VERTEX_NODE_COUNT
146
+ EDGE_NODE_COUNT = self._shape.EDGE_NODE_COUNT
147
+ FACE_NODE_COUNT = self._shape.FACE_NODE_COUNT
148
+ INTERIOR_NODE_COUNT = self._shape.INTERIOR_NODE_COUNT
156
149
 
157
150
  @cache.dynamic_func(suffix=self.name)
158
151
  def element_node_index(
@@ -161,94 +154,89 @@ class HexmeshTripolynomialSpaceTopology(HexmeshSpaceTopology):
161
154
  element_index: ElementIndex,
162
155
  node_index_in_elt: int,
163
156
  ):
164
- node_type, type_instance, type_index = self.shape.node_type_and_type_index(node_index_in_elt)
165
-
166
- if node_type == CubeTripolynomialShapeFunctions.VERTEX:
167
- return geo_arg.hex_vertex_indices[element_index, _CUBE_TO_HEX_VERTEX[type_instance]]
157
+ node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)
168
158
 
169
- offset = topo_arg.vertex_count
159
+ if wp.static(VERTEX_NODE_COUNT > 0):
160
+ if node_type == CubeShapeFunction.VERTEX:
161
+ return (
162
+ geo_arg.hex_vertex_indices[element_index, _CUBE_TO_HEX_VERTEX[type_instance]]
163
+ * VERTEX_NODE_COUNT
164
+ + type_index
165
+ )
170
166
 
171
- if node_type == CubeTripolynomialShapeFunctions.EDGE:
172
- hex_edge = _CUBE_TO_HEX_EDGE[type_instance]
173
- edge_index = topo_arg.hex_edge_indices[element_index, hex_edge]
167
+ offset = topo_arg.vertex_count * VERTEX_NODE_COUNT
174
168
 
175
- v0 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 0]]
176
- v1 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 1]]
169
+ if wp.static(EDGE_NODE_COUNT > 0):
170
+ if node_type == CubeShapeFunction.EDGE:
171
+ hex_edge = _CUBE_TO_HEX_EDGE[type_instance]
172
+ edge_index = topo_arg.hex_edge_indices[element_index, hex_edge]
177
173
 
178
- if v0 > v1:
179
- type_index = ORDER - 1 - type_index
174
+ v0 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 0]]
175
+ v1 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 1]]
180
176
 
181
- return offset + INTERIOR_NODES_PER_EDGE * edge_index + type_index
177
+ if v0 > v1:
178
+ type_index = EDGE_NODE_COUNT - 1 - type_index
182
179
 
183
- offset += INTERIOR_NODES_PER_EDGE * topo_arg.edge_count
180
+ return offset + EDGE_NODE_COUNT * edge_index + type_index
184
181
 
185
- if node_type == CubeTripolynomialShapeFunctions.FACE:
186
- face_index_and_ori = topo_arg.hex_face_indices[element_index, type_instance]
187
- face_index = face_index_and_ori[0]
188
- face_orientation = face_index_and_ori[1]
182
+ offset += EDGE_NODE_COUNT * topo_arg.edge_count
189
183
 
190
- type_index = HexmeshTripolynomialSpaceTopology._rotate_face_index(
191
- type_index, face_orientation, ORDER - 1
192
- )
184
+ if wp.static(FACE_NODE_COUNT > 0):
185
+ if node_type == CubeShapeFunction.FACE:
186
+ face_index_and_ori = topo_arg.hex_face_indices[element_index, type_instance]
187
+ face_index = face_index_and_ori[0]
188
+ face_orientation = face_index_and_ori[1]
193
189
 
194
- return offset + INTERIOR_NODES_PER_FACE * face_index + type_index
190
+ face_offset = CubeShapeFunction._face_offset(type_instance)
195
191
 
196
- offset += INTERIOR_NODES_PER_FACE * topo_arg.face_count
192
+ if wp.static(FACE_NODE_COUNT > 1):
193
+ face_coords = self._shape._face_node_ij(type_index)
194
+ rot_face_coords = HexmeshSpaceTopology._rotate_face_coordinates(
195
+ face_orientation, face_offset, face_coords
196
+ )
197
+ type_index = self._shape._linear_face_node_index(type_index, rot_face_coords)
197
198
 
198
- return offset + INTERIOR_NODES_PER_CELL * element_index + type_index
199
+ return offset + FACE_NODE_COUNT * face_index + type_index
199
200
 
200
- return element_node_index
201
+ offset += FACE_NODE_COUNT * topo_arg.face_count
201
202
 
203
+ return offset + INTERIOR_NODE_COUNT * element_index + type_index
202
204
 
203
- class HexmeshSerendipitySpaceTopology(HexmeshSpaceTopology):
204
- def __init__(
205
- self,
206
- grid: Hexmesh,
207
- shape: CubeSerendipityShapeFunctions,
208
- ):
209
- super().__init__(grid, shape, need_hex_edge_indices=True, need_hex_face_indices=False)
210
-
211
- self.element_node_index = self._make_element_node_index()
212
-
213
- def node_count(self) -> int:
214
- return self.geometry.vertex_count() + (self.shape.ORDER - 1) * self.geometry.edge_count()
205
+ return element_node_index
215
206
 
216
- def _make_element_node_index(self):
217
- ORDER = self.shape.ORDER
207
+ def _make_element_node_sign(self):
208
+ EDGE_NODE_COUNT = self._shape.EDGE_NODE_COUNT
209
+ FACE_NODE_COUNT = self._shape.FACE_NODE_COUNT
218
210
 
219
211
  @cache.dynamic_func(suffix=self.name)
220
- def element_node_index(
221
- cell_arg: Hexmesh.CellArg,
222
- topo_arg: HexmeshSpaceTopology.TopologyArg,
212
+ def element_node_sign(
213
+ geo_arg: self.geometry.CellArg,
214
+ topo_arg: HexmeshTopologyArg,
223
215
  element_index: ElementIndex,
224
216
  node_index_in_elt: int,
225
217
  ):
226
- node_type, type_index = self.shape.node_type_and_type_index(node_index_in_elt)
227
-
228
- if node_type == CubeSerendipityShapeFunctions.VERTEX:
229
- return cell_arg.hex_vertex_indices[element_index, _CUBE_TO_HEX_VERTEX[type_index]]
230
-
231
- type_instance, index_in_edge = CubeSerendipityShapeFunctions._cube_edge_index(node_type, type_index)
232
- hex_edge = _CUBE_TO_HEX_EDGE[type_instance]
218
+ node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)
233
219
 
234
- edge_index = topo_arg.hex_edge_indices[element_index, hex_edge]
220
+ if wp.static(EDGE_NODE_COUNT > 0):
221
+ if node_type == CubeShapeFunction.EDGE:
222
+ hex_edge = _CUBE_TO_HEX_EDGE[type_instance]
223
+ v0 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 0]]
224
+ v1 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 1]]
225
+ return wp.select(v0 > v1, 1.0, -1.0)
235
226
 
236
- v0 = cell_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 0]]
237
- v1 = cell_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 1]]
227
+ if wp.static(FACE_NODE_COUNT > 0):
228
+ if node_type == CubeShapeFunction.FACE:
229
+ face_index_and_ori = topo_arg.hex_face_indices[element_index, type_instance]
230
+ flip = face_index_and_ori[1] & 1
231
+ return wp.select(flip == 0, -1.0, 1.0)
238
232
 
239
- if v0 > v1:
240
- index_in_edge = ORDER - 1 - index_in_edge
241
-
242
- return topo_arg.vertex_count + (ORDER - 1) * edge_index + index_in_edge
243
-
244
- return element_node_index
233
+ return 1.0
245
234
 
235
+ return element_node_sign
246
236
 
247
- def make_hexmesh_space_topology(mesh: Hexmesh, shape: ShapeFunction):
248
- if isinstance(shape, CubeSerendipityShapeFunctions):
249
- return forward_base_topology(HexmeshSerendipitySpaceTopology, mesh, shape)
250
237
 
251
- if isinstance(shape, CubeTripolynomialShapeFunctions):
252
- return forward_base_topology(HexmeshTripolynomialSpaceTopology, mesh, shape)
238
+ def make_hexmesh_space_topology(mesh: Hexmesh, shape: CubeShapeFunction):
239
+ if isinstance(shape, CubeShapeFunction):
240
+ return forward_base_topology(HexmeshSpaceTopology, mesh, shape)
253
241
 
254
242
  raise ValueError(f"Unsupported shape function {shape.name}")