warp-lang 1.0.0b2__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (269)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.so +0 -0
  57. warp/bin/warp.so +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/fem/field/discrete_field.py +0 -80
  257. warp/fem/space/nodal_function_space.py +0 -233
  258. warp/tests/test_all.py +0 -223
  259. warp/tests/test_array_scan.py +0 -60
  260. warp/tests/test_base.py +0 -208
  261. warp/tests/test_unresolved_func.py +0 -7
  262. warp/tests/test_unresolved_symbol.py +0 -7
  263. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  264. warp_lang-1.0.0b2.dist-info/RECORD +0 -378
  265. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  266. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  267. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  268. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  269. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -1,220 +1,89 @@
1
1
  import warp as wp
2
- import numpy as np
3
2
 
4
-
5
- from warp.fem.types import ElementIndex, Coords, OUTSIDE, vec2i, vec3i, vec4i
3
+ from warp.fem.types import ElementIndex, Coords
6
4
  from warp.fem.geometry import Tetmesh
5
+ from warp.fem import cache
7
6
 
8
- from .dof_mapper import DofMapper
9
- from .nodal_function_space import NodalFunctionSpace, NodalFunctionSpaceTrace
10
-
7
+ from .topology import SpaceTopology, DiscontinuousSpaceTopologyMixin, forward_base_topology
8
+ from .basis_space import ShapeBasisSpace, TraceBasisSpace
11
9
 
12
- class TetmeshFunctionSpace(NodalFunctionSpace):
13
- DIMENSION = wp.constant(3)
10
+ from .shape import ShapeFunction, ConstantShapeFunction
11
+ from .shape import TetrahedronPolynomialShapeFunctions, TetrahedronNonConformingPolynomialShapeFunctions
14
12
 
15
- @wp.struct
16
- class SpaceArg:
17
- geo_arg: Tetmesh.SideArg
18
13
 
19
- reference_transforms: wp.array(dtype=wp.mat33f)
20
- tet_edge_indices: wp.array2d(dtype=int)
21
- tet_face_indices: wp.array2d(dtype=int)
14
+ @wp.struct
15
+ class TetmeshTopologyArg:
16
+ tet_edge_indices: wp.array2d(dtype=int)
17
+ tet_face_indices: wp.array2d(dtype=int)
18
+ face_vertex_indices: wp.array(dtype=wp.vec3i)
22
19
 
23
- vertex_count: int
24
- edge_count: int
25
- face_count: int
20
+ vertex_count: int
21
+ edge_count: int
22
+ face_count: int
26
23
 
27
- def __init__(self, mesh: Tetmesh, dtype: type = float, dof_mapper: DofMapper = None):
28
- super().__init__(dtype, dof_mapper)
29
- self._mesh = mesh
30
24
 
31
- self._reference_transforms: wp.array = None
32
- self._tet_face_indices: wp.array = None
25
+ class TetmeshSpaceTopology(SpaceTopology):
26
+ TopologyArg = TetmeshTopologyArg
33
27
 
34
- self._compute_reference_transforms()
35
- self._compute_tet_face_indices()
36
- self._compute_tet_edge_indices()
28
+ def __init__(
29
+ self,
30
+ mesh: Tetmesh,
31
+ shape: ShapeFunction,
32
+ need_tet_edge_indices: bool = True,
33
+ need_tet_face_indices: bool = True,
34
+ ):
35
+ super().__init__(mesh, shape.NODES_PER_ELEMENT)
36
+ self._mesh = mesh
37
+ self._shape = shape
37
38
 
38
- @property
39
- def geometry(self) -> Tetmesh:
40
- return self._mesh
39
+ if need_tet_edge_indices:
40
+ self._tet_edge_indices = self._mesh.tet_edge_indices
41
+ self._edge_count = self._mesh.edge_count()
42
+ else:
43
+ self._tet_edge_indices = wp.empty(shape=(0, 0), dtype=int)
44
+ self._edge_count = 0
41
45
 
42
- def edge_count(self):
43
- return self._edge_vertex_indices.shape[0]
46
+ if need_tet_face_indices:
47
+ self._compute_tet_face_indices()
48
+ else:
49
+ self._tet_face_indices = wp.empty(shape=(0, 0), dtype=int)
44
50
 
45
- def space_arg_value(self, device):
46
- arg = self.SpaceArg()
47
- arg.geo_arg = self.geometry.side_arg_value(device)
48
- arg.reference_transforms = self._reference_transforms.to(device)
51
+ @cache.cached_arg_value
52
+ def topo_arg_value(self, device):
53
+ arg = TetmeshTopologyArg()
49
54
  arg.tet_face_indices = self._tet_face_indices.to(device)
50
55
  arg.tet_edge_indices = self._tet_edge_indices.to(device)
56
+ arg.face_vertex_indices = self._mesh.face_vertex_indices.to(device)
51
57
 
52
58
  arg.vertex_count = self._mesh.vertex_count()
53
59
  arg.face_count = self._mesh.side_count()
54
- arg.edge_count = self.edge_count()
60
+ arg.edge_count = self._edge_count
55
61
  return arg
56
62
 
57
- class Trace(NodalFunctionSpaceTrace):
58
- def __init__(self, space: NodalFunctionSpace):
59
- super().__init__(space)
60
- self.ORDER = space.ORDER
61
-
62
- @wp.func
63
- def _inner_cell_index(args: SpaceArg, side_index: ElementIndex):
64
- return Tetmesh.side_inner_cell_index(args.geo_arg, side_index)
65
-
66
- @wp.func
67
- def _outer_cell_index(args: SpaceArg, side_index: ElementIndex):
68
- return Tetmesh.side_outer_cell_index(args.geo_arg, side_index)
69
-
70
- @wp.func
71
- def _inner_cell_coords(args: SpaceArg, side_index: ElementIndex, side_coords: Coords):
72
- tet_index = Tetmesh.side_inner_cell_index(args.geo_arg, side_index)
73
- return Tetmesh.face_to_tet_coords(args.geo_arg, side_index, tet_index, side_coords)
74
-
75
- @wp.func
76
- def _outer_cell_coords(args: SpaceArg, side_index: ElementIndex, side_coords: Coords):
77
- tet_index = Tetmesh.side_outer_cell_index(args.geo_arg, side_index)
78
- return Tetmesh.face_to_tet_coords(args.geo_arg, side_index, tet_index, side_coords)
79
-
80
- @wp.func
81
- def _cell_to_side_coords(
82
- args: SpaceArg,
83
- side_index: ElementIndex,
84
- element_index: ElementIndex,
85
- element_coords: Coords,
86
- ):
87
- return Tetmesh.tet_to_face_coords(args.geo_arg, side_index, element_index, element_coords)
88
-
89
- def _compute_reference_transforms(self):
90
- self._reference_transforms = wp.empty(
91
- dtype=wp.mat33f, device=self._mesh.positions.device, shape=(self._mesh.cell_count())
92
- )
93
-
94
- wp.launch(
95
- kernel=TetmeshFunctionSpace._compute_reference_transforms_kernel,
96
- dim=self._reference_transforms.shape,
97
- device=self._reference_transforms.device,
98
- inputs=[self._mesh.tet_vertex_indices, self._mesh.positions, self._reference_transforms],
99
- )
100
-
101
63
  def _compute_tet_face_indices(self):
102
64
  self._tet_face_indices = wp.empty(
103
65
  dtype=int, device=self._mesh.tet_vertex_indices.device, shape=(self._mesh.cell_count(), 4)
104
66
  )
105
67
 
106
68
  wp.launch(
107
- kernel=TetmeshFunctionSpace._compute_tet_face_indices_kernel,
69
+ kernel=TetmeshSpaceTopology._compute_tet_face_indices_kernel,
108
70
  dim=self._mesh._face_tet_indices.shape,
109
71
  device=self._mesh.tet_vertex_indices.device,
110
72
  inputs=[
111
- self._mesh._face_tet_indices,
112
- self._mesh._face_vertex_indices,
73
+ self._mesh.face_tet_indices,
74
+ self._mesh.face_vertex_indices,
113
75
  self._mesh.tet_vertex_indices,
114
76
  self._tet_face_indices,
115
77
  ],
116
78
  )
117
79
 
118
- def _compute_tet_edge_indices(self):
119
- from warp.fem.utils import _get_pinned_temp_count_buffer
120
- from warp.utils import array_scan
121
-
122
- device = self._mesh.tet_vertex_indices.device
123
-
124
- vertex_start_edge_count = wp.zeros(dtype=int, device=device, shape=self._mesh.vertex_count())
125
- vertex_start_edge_offsets = wp.empty_like(vertex_start_edge_count)
126
-
127
- vertex_edge_ends = wp.empty(dtype=int, device=device, shape=(6 * self._mesh.cell_count()))
128
-
129
- # Count face edges starting at each vertex
130
- wp.launch(
131
- kernel=TetmeshFunctionSpace._count_starting_edges_kernel,
132
- device=device,
133
- dim=self._mesh.cell_count(),
134
- inputs=[self._mesh.tet_vertex_indices, vertex_start_edge_count],
135
- )
136
-
137
- array_scan(in_array=vertex_start_edge_count, out_array=vertex_start_edge_offsets, inclusive=False)
138
-
139
- # Count number of unique edges (deduplicate across faces)
140
- vertex_unique_edge_count = vertex_start_edge_count
141
- wp.launch(
142
- kernel=TetmeshFunctionSpace._count_unique_starting_edges_kernel,
143
- device=device,
144
- dim=self._mesh.vertex_count(),
145
- inputs=[
146
- self._mesh._vertex_tet_offsets,
147
- self._mesh._vertex_tet_indices,
148
- self._mesh.tet_vertex_indices,
149
- vertex_start_edge_offsets,
150
- vertex_unique_edge_count,
151
- vertex_edge_ends,
152
- ],
153
- )
154
-
155
- vertex_unique_edge_offsets = wp.empty_like(vertex_start_edge_offsets)
156
- array_scan(in_array=vertex_start_edge_count, out_array=vertex_unique_edge_offsets, inclusive=False)
157
-
158
- # Get back edge count to host
159
- if device.is_cuda:
160
- edge_count = _get_pinned_temp_count_buffer(device)
161
- # Last vertex will not own any edge, so its count will be zero; just fetching last prefix count is ok
162
- wp.copy(dest=edge_count, src=vertex_unique_edge_offsets, src_offset=self._mesh.vertex_count() - 1, count=1)
163
- wp.synchronize_stream(wp.get_stream())
164
- edge_count = int(edge_count.numpy()[0])
165
- else:
166
- edge_count = int(vertex_unique_edge_offsets.numpy()[self._mesh.vertex_count() - 1])
167
-
168
- self._edge_vertex_indices = wp.empty(shape=(edge_count,), dtype=vec2i, device=device)
169
- self._tet_edge_indices = wp.empty(
170
- dtype=int, device=self._mesh.tet_vertex_indices.device, shape=(self._mesh.cell_count(), 6)
171
- )
172
-
173
- # Compress edge data
174
- wp.launch(
175
- kernel=TetmeshFunctionSpace._compress_edges_kernel,
176
- device=device,
177
- dim=self._mesh.vertex_count(),
178
- inputs=[
179
- self._mesh._vertex_tet_offsets,
180
- self._mesh._vertex_tet_indices,
181
- self._mesh.tet_vertex_indices,
182
- vertex_start_edge_offsets,
183
- vertex_unique_edge_offsets,
184
- vertex_unique_edge_count,
185
- vertex_edge_ends,
186
- self._edge_vertex_indices,
187
- self._tet_edge_indices,
188
- ],
189
- )
190
-
191
- @wp.kernel
192
- def _compute_reference_transforms_kernel(
193
- tet_vertex_indices: wp.array2d(dtype=int),
194
- positions: wp.array(dtype=wp.vec3f),
195
- transforms: wp.array(dtype=wp.mat33f),
196
- ):
197
- t = wp.tid()
198
-
199
- p0 = positions[tet_vertex_indices[t, 0]]
200
- p1 = positions[tet_vertex_indices[t, 1]]
201
- p2 = positions[tet_vertex_indices[t, 2]]
202
- p3 = positions[tet_vertex_indices[t, 3]]
203
-
204
- e1 = p1 - p0
205
- e2 = p2 - p0
206
- e3 = p3 - p0
207
-
208
- mat = wp.mat33(e1, e2, e3)
209
- transforms[t] = wp.transpose(wp.inverse(mat))
210
-
211
80
  @wp.func
212
81
  def _find_face_index_in_tet(
213
- face_vtx: vec3i,
214
- tet_vtx: vec4i,
82
+ face_vtx: wp.vec3i,
83
+ tet_vtx: wp.vec4i,
215
84
  ):
216
85
  for k in range(3):
217
- tvk = vec3i(tet_vtx[k], tet_vtx[(k + 1) % 4], tet_vtx[(k + 2) % 4])
86
+ tvk = wp.vec3i(tet_vtx[k], tet_vtx[(k + 1) % 4], tet_vtx[(k + 2) % 4])
218
87
 
219
88
  # Use fact that face always start with min vertex
220
89
  min_t = wp.min(tvk)
@@ -230,8 +99,8 @@ class TetmeshFunctionSpace(NodalFunctionSpace):
230
99
 
231
100
  @wp.kernel
232
101
  def _compute_tet_face_indices_kernel(
233
- face_tet_indices: wp.array(dtype=vec2i),
234
- face_vertex_indices: wp.array(dtype=vec3i),
102
+ face_tet_indices: wp.array(dtype=wp.vec2i),
103
+ face_vertex_indices: wp.array(dtype=wp.vec3i),
235
104
  tet_vertex_indices: wp.array2d(dtype=int),
236
105
  tet_face_indices: wp.array2d(dtype=int),
237
106
  ):
@@ -241,821 +110,106 @@ class TetmeshFunctionSpace(NodalFunctionSpace):
241
110
  face_tets = face_tet_indices[e]
242
111
 
243
112
  t0 = face_tets[0]
244
- t0_vtx = vec4i(
113
+ t0_vtx = wp.vec4i(
245
114
  tet_vertex_indices[t0, 0], tet_vertex_indices[t0, 1], tet_vertex_indices[t0, 2], tet_vertex_indices[t0, 3]
246
115
  )
247
- t0_face = TetmeshFunctionSpace._find_face_index_in_tet(face_vtx, t0_vtx)
116
+ t0_face = TetmeshSpaceTopology._find_face_index_in_tet(face_vtx, t0_vtx)
248
117
  tet_face_indices[t0, t0_face] = e
249
118
 
250
119
  t1 = face_tets[1]
251
120
  if t1 != t0:
252
- t1_vtx = vec4i(
121
+ t1_vtx = wp.vec4i(
253
122
  tet_vertex_indices[t1, 0],
254
123
  tet_vertex_indices[t1, 1],
255
124
  tet_vertex_indices[t1, 2],
256
125
  tet_vertex_indices[t1, 3],
257
126
  )
258
- t1_face = TetmeshFunctionSpace._find_face_index_in_tet(face_vtx, t1_vtx)
127
+ t1_face = TetmeshSpaceTopology._find_face_index_in_tet(face_vtx, t1_vtx)
259
128
  tet_face_indices[t1, t1_face] = e
260
129
 
261
- @wp.kernel
262
- def _count_starting_edges_kernel(
263
- tri_vertex_indices: wp.array2d(dtype=int), vertex_start_edge_count: wp.array(dtype=int)
264
- ):
265
- t = wp.tid()
266
- for k in range(3):
267
- v0 = tri_vertex_indices[t, k]
268
- v1 = tri_vertex_indices[t, (k + 1) % 3]
269
-
270
- if v0 < v1:
271
- wp.atomic_add(vertex_start_edge_count, v0, 1)
272
- else:
273
- wp.atomic_add(vertex_start_edge_count, v1, 1)
274
-
275
- for k in range(3):
276
- v0 = tri_vertex_indices[t, k]
277
- v1 = tri_vertex_indices[t, 3]
278
-
279
- if v0 < v1:
280
- wp.atomic_add(vertex_start_edge_count, v0, 1)
281
- else:
282
- wp.atomic_add(vertex_start_edge_count, v1, 1)
283
-
284
- @wp.func
285
- def _find(
286
- needle: int,
287
- values: wp.array(dtype=int),
288
- beg: int,
289
- end: int,
290
- ):
291
- for i in range(beg, end):
292
- if values[i] == needle:
293
- return i
294
-
295
- return -1
296
-
297
- @wp.kernel
298
- def _count_unique_starting_edges_kernel(
299
- vertex_tet_offsets: wp.array(dtype=int),
300
- vertex_tet_indices: wp.array(dtype=int),
301
- tet_vertex_indices: wp.array2d(dtype=int),
302
- vertex_start_edge_offsets: wp.array(dtype=int),
303
- vertex_start_edge_count: wp.array(dtype=int),
304
- edge_ends: wp.array(dtype=int),
305
- ):
306
- v = wp.tid()
307
-
308
- edge_beg = vertex_start_edge_offsets[v]
309
-
310
- tet_beg = vertex_tet_offsets[v]
311
- tet_end = vertex_tet_offsets[v + 1]
312
-
313
- edge_cur = edge_beg
314
-
315
- for tet in range(tet_beg, tet_end):
316
- t = vertex_tet_indices[tet]
317
-
318
- for k in range(3):
319
- v0 = tet_vertex_indices[t, k]
320
- v1 = tet_vertex_indices[t, (k + 1) % 3]
321
-
322
- if v == wp.min(v0, v1):
323
- other_v = wp.max(v0, v1)
324
- if TetmeshFunctionSpace._find(other_v, edge_ends, edge_beg, edge_cur) == -1:
325
- edge_ends[edge_cur] = other_v
326
- edge_cur += 1
327
-
328
- for k in range(3):
329
- v0 = tet_vertex_indices[t, k]
330
- v1 = tet_vertex_indices[t, 3]
331
-
332
- if v == wp.min(v0, v1):
333
- other_v = wp.max(v0, v1)
334
- if TetmeshFunctionSpace._find(other_v, edge_ends, edge_beg, edge_cur) == -1:
335
- edge_ends[edge_cur] = other_v
336
- edge_cur += 1
337
-
338
- vertex_start_edge_count[v] = edge_cur - edge_beg
339
-
340
- @wp.kernel
341
- def _compress_edges_kernel(
342
- vertex_tet_offsets: wp.array(dtype=int),
343
- vertex_tet_indices: wp.array(dtype=int),
344
- tet_vertex_indices: wp.array2d(dtype=int),
345
- vertex_start_edge_offsets: wp.array(dtype=int),
346
- vertex_unique_edge_offsets: wp.array(dtype=int),
347
- vertex_unique_edge_count: wp.array(dtype=int),
348
- uncompressed_edge_ends: wp.array(dtype=int),
349
- edge_vertex_indices: wp.array(dtype=vec2i),
350
- tet_edge_indices: wp.array2d(dtype=int),
351
- ):
352
- v = wp.tid()
353
-
354
- uncompressed_beg = vertex_start_edge_offsets[v]
355
-
356
- unique_beg = vertex_unique_edge_offsets[v]
357
- unique_count = vertex_unique_edge_count[v]
358
-
359
- for e in range(unique_count):
360
- src_index = uncompressed_beg + e
361
- edge_index = unique_beg + e
362
-
363
- edge_vertex_indices[edge_index] = vec2i(v, uncompressed_edge_ends[src_index])
364
-
365
- tet_beg = vertex_tet_offsets[v]
366
- tet_end = vertex_tet_offsets[v + 1]
367
-
368
- for tet in range(tet_beg, tet_end):
369
- t = vertex_tet_indices[tet]
370
-
371
- for k in range(3):
372
- v0 = tet_vertex_indices[t, k]
373
- v1 = tet_vertex_indices[t, (k + 1) % 3]
374
-
375
- if v == wp.min(v0, v1):
376
- other_v = wp.max(v0, v1)
377
- edge_id = (
378
- TetmeshFunctionSpace._find(
379
- other_v, uncompressed_edge_ends, uncompressed_beg, uncompressed_beg + unique_count
380
- )
381
- - uncompressed_beg
382
- + unique_beg
383
- )
384
- tet_edge_indices[t][k] = edge_id
385
-
386
- for k in range(3):
387
- v0 = tet_vertex_indices[t, k]
388
- v1 = tet_vertex_indices[t, 3]
389
-
390
- if v == wp.min(v0, v1):
391
- other_v = wp.max(v0, v1)
392
- edge_id = (
393
- TetmeshFunctionSpace._find(
394
- other_v, uncompressed_edge_ends, uncompressed_beg, uncompressed_beg + unique_count
395
- )
396
- - uncompressed_beg
397
- + unique_beg
398
- )
399
- tet_edge_indices[t][k + 3] = edge_id
400
-
401
-
402
- class TetmeshPiecewiseConstantSpace(TetmeshFunctionSpace):
403
- ORDER = wp.constant(0)
404
- NODES_PER_ELEMENT = wp.constant(1)
405
-
406
- def __init__(self, grid: Tetmesh, dtype: type = float, dof_mapper: DofMapper = None):
407
- super().__init__(grid, dtype, dof_mapper)
408
-
409
- self.element_outer_weight = self.element_inner_weight
410
- self.element_outer_weight_gradient = self.element_inner_weight_gradient
411
-
412
- def node_count(self) -> int:
413
- return self._mesh.cell_count()
414
-
415
- def node_positions(self):
416
- vtx_pos = self._mesh.positions.numpy()
417
- tet_vtx = self._mesh.tet_vertex_indices.numpy()
418
-
419
- tet_pos = vtx_pos[tet_vtx]
420
- centers = tet_pos.sum(axis=1) / 4.0
421
-
422
- return centers[:, 0], centers[:, 1], centers[:, 2]
423
-
424
- @wp.func
425
- def element_node_index(
426
- args: TetmeshFunctionSpace.SpaceArg,
427
- element_index: ElementIndex,
428
- node_index_in_elt: int,
429
- ):
430
- return element_index
431
-
432
- @wp.func
433
- def node_coords_in_element(
434
- args: TetmeshFunctionSpace.SpaceArg,
435
- element_index: ElementIndex,
436
- node_index_in_elt: int,
437
- ):
438
- if node_index_in_elt == 0:
439
- return Coords(1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0)
440
-
441
- return Coords(OUTSIDE)
442
-
443
- @wp.func
444
- def node_quadrature_weight(
445
- args: TetmeshFunctionSpace.SpaceArg,
446
- element_index: ElementIndex,
447
- node_index_in_elt: int,
448
- ):
449
- return 1.0
450
-
451
- @wp.func
452
- def element_inner_weight(
453
- args: TetmeshFunctionSpace.SpaceArg,
454
- element_index: ElementIndex,
455
- coords: Coords,
456
- node_index_in_elt: int,
457
- ):
458
- if node_index_in_elt == 0:
459
- return 1.0
460
- return 0.0
461
130
 
462
- @wp.func
463
- def element_inner_weight_gradient(
464
- args: TetmeshFunctionSpace.SpaceArg,
465
- element_index: ElementIndex,
466
- coords: Coords,
467
- node_index_in_elt: int,
468
- ):
469
- return wp.vec3(0.0)
131
+ class TetmeshDiscontinuousSpaceTopology(
132
+ DiscontinuousSpaceTopologyMixin,
133
+ SpaceTopology,
134
+ ):
135
+ def __init__(self, mesh: Tetmesh, shape: ShapeFunction):
136
+ super().__init__(mesh, shape.NODES_PER_ELEMENT)
470
137
 
471
- class Trace(TetmeshFunctionSpace.Trace):
472
- NODES_PER_ELEMENT = wp.constant(2)
473
- ORDER = wp.constant(0)
474
138
 
475
- def __init__(self, space: "TetmeshPiecewiseConstantSpace"):
476
- super().__init__(space)
139
+ class TetmeshBasisSpace(ShapeBasisSpace):
140
+ def __init__(self, topology: TetmeshSpaceTopology, shape: ShapeFunction):
141
+ super().__init__(topology, shape)
477
142
 
478
- self.element_node_index = self._make_element_node_index(space)
143
+ self._mesh: Tetmesh = topology.geometry
479
144
 
480
- self.element_inner_weight = self._make_element_inner_weight(space)
481
- self.element_inner_weight_gradient = self._make_element_inner_weight_gradient(space)
482
145
 
483
- self.element_outer_weight = self._make_element_outer_weight(space)
484
- self.element_outer_weight_gradient = self._make_element_outer_weight_gradient(space)
146
+ class TetmeshPiecewiseConstantBasis(TetmeshBasisSpace):
147
+ def __init__(self, mesh: Tetmesh):
148
+ shape = ConstantShapeFunction(mesh.reference_cell(), space_dimension=3)
149
+ topology = TetmeshDiscontinuousSpaceTopology(mesh, shape)
150
+ super().__init__(shape=shape, topology=topology)
485
151
 
152
+ class Trace(TraceBasisSpace):
486
153
  @wp.func
487
- def node_coords_in_element(
488
- args: TetmeshFunctionSpace.SpaceArg,
154
+ def _node_coords_in_element(
155
+ side_arg: Tetmesh.SideArg,
156
+ basis_arg: TetmeshBasisSpace.BasisArg,
489
157
  element_index: ElementIndex,
490
158
  node_index_in_element: int,
491
159
  ):
492
- if node_index_in_element == 0:
493
- return Coords(1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0)
494
- elif node_index_in_element == 1:
495
- return Coords(1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0)
496
-
497
- return Coords(OUTSIDE)
160
+ return Coords(1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0)
498
161
 
499
- @wp.func
500
- def node_quadrature_weight(
501
- args: TetmeshFunctionSpace.SpaceArg,
502
- element_index: ElementIndex,
503
- node_index_in_elt: int,
504
- ):
505
- return 1.0
162
+ def make_node_coords_in_element(self):
163
+ return self._node_coords_in_element
506
164
 
507
165
  def trace(self):
508
- return TetmeshPiecewiseConstantSpace.Trace(self)
509
-
510
-
511
- def _tet_node_index(tx: int, ty: int, tz: int, degree: int):
512
- from .trimesh_2d_function_space import _triangle_node_index
513
-
514
- VERTEX_NODE_COUNT = 4
515
- EDGE_INTERIOR_NODE_COUNT = degree - 1
516
- VERTEX_EDGE_NODE_COUNT = VERTEX_NODE_COUNT + 6 * EDGE_INTERIOR_NODE_COUNT
517
- FACE_INTERIOR_NODE_COUNT = (degree - 1) * (degree - 2) // 2
518
- VERTEX_EDGE_FACE_NODE_COUNT = VERTEX_EDGE_NODE_COUNT + 4 * FACE_INTERIOR_NODE_COUNT
519
-
520
- # Index in similar order to e.g. VTK
521
- # First vertices, then edges (counterclokwise), then faces, then interior points (recursively)
522
-
523
- if tx == 0:
524
- if ty == 0:
525
- if tz == 0:
526
- return 0
527
- elif tz == degree:
528
- return 3
529
- else:
530
- # 0-3 edge
531
- edge_index = 3
532
- return VERTEX_NODE_COUNT + EDGE_INTERIOR_NODE_COUNT * edge_index + (tz - 1)
533
- elif tz == 0:
534
- if ty == degree:
535
- return 2
536
- else:
537
- # 2-0 edge
538
- edge_index = 2
539
- return VERTEX_NODE_COUNT + EDGE_INTERIOR_NODE_COUNT * edge_index + (EDGE_INTERIOR_NODE_COUNT - ty)
540
- elif tz + ty == degree:
541
- # 2-3 edge
542
- edge_index = 5
543
- return VERTEX_NODE_COUNT + EDGE_INTERIOR_NODE_COUNT * edge_index + (tz - 1)
544
- else:
545
- # 2-3-0 face
546
- face_index = 2
547
- return (
548
- VERTEX_EDGE_NODE_COUNT
549
- + FACE_INTERIOR_NODE_COUNT * face_index
550
- + _triangle_node_index(degree - 2 - ty, tz - 1, degree - 2)
551
- )
552
- elif ty == 0:
553
- if tz == 0:
554
- if tx == degree:
555
- return 1
556
- else:
557
- # 0-1 edge
558
- edge_index = 0
559
- return VERTEX_NODE_COUNT + EDGE_INTERIOR_NODE_COUNT * edge_index + (tx - 1)
560
- elif tz + tx == degree:
561
- # 1-3 edge
562
- edge_index = 4
563
- return VERTEX_NODE_COUNT + EDGE_INTERIOR_NODE_COUNT * edge_index + (tz - 1)
564
- else:
565
- # 3-0-1 face
566
- face_index = 3
567
- return (
568
- VERTEX_EDGE_NODE_COUNT
569
- + FACE_INTERIOR_NODE_COUNT * face_index
570
- + _triangle_node_index(tx - 1, tz - 1, degree - 2)
571
- )
572
- elif tz == 0:
573
- if tx + ty == degree:
574
- # 1-2 edge
575
- edge_index = 1
576
- return VERTEX_NODE_COUNT + EDGE_INTERIOR_NODE_COUNT * edge_index + (ty - 1)
577
- else:
578
- # 0-1-2 face
579
- face_index = 0
580
- return (
581
- VERTEX_EDGE_NODE_COUNT
582
- + FACE_INTERIOR_NODE_COUNT * face_index
583
- + _triangle_node_index(tx - 1, ty - 1, degree - 2)
584
- )
585
- elif tx + ty + tz == degree:
586
- # 1-2-3 face
587
- face_index = 1
588
- return (
589
- VERTEX_EDGE_NODE_COUNT
590
- + FACE_INTERIOR_NODE_COUNT * face_index
591
- + _triangle_node_index(tx - 1, tz - 1, degree - 2)
592
- )
593
-
594
- return VERTEX_EDGE_FACE_NODE_COUNT + _tet_node_index(tx - 1, ty - 1, tz - 1, degree - 3)
595
-
596
-
597
- class TetmeshPolynomialShapeFunctions:
598
- INVALID = wp.constant(-1)
599
- VERTEX = wp.constant(0)
600
- EDGE = wp.constant(1)
601
- FACE = wp.constant(2)
602
- INTERIOR = wp.constant(3)
603
-
604
- def __init__(self, degree: int):
605
- self.ORDER = wp.constant(degree)
606
-
607
- self.NODES_PER_ELEMENT = wp.constant((degree + 1) * (degree + 2) * (degree + 3) // 6)
608
- self.NODES_PER_SIDE = wp.constant((degree + 1) * (degree + 2) // 2)
609
-
610
- tet_coords = np.empty((self.NODES_PER_ELEMENT, 3), dtype=int)
611
-
612
- for tx in range(degree + 1):
613
- for ty in range(degree + 1 - tx):
614
- for tz in range(degree + 1 - tx - ty):
615
- index = _tet_node_index(tx, ty, tz, degree)
616
- tet_coords[index] = [tx, ty, tz]
617
-
618
- CoordTypeVec = wp.mat(dtype=int, shape=(self.NODES_PER_ELEMENT, 3))
619
- self.NODE_TET_COORDS = wp.constant(CoordTypeVec(tet_coords))
620
-
621
- self.node_type_and_type_index = self._get_node_type_and_type_index()
622
- self._node_tet_coordinates = self._get_node_tet_coordinates()
623
-
624
- @property
625
- def name(self) -> str:
626
- return f"{self.ORDER}"
166
+ return TetmeshPiecewiseConstantBasis.Trace(self)
627
167
 
628
- def _get_node_tet_coordinates(self):
629
- NODE_TET_COORDS = self.NODE_TET_COORDS
630
168
 
631
- def node_tet_coordinates(
632
- node_index_in_elt: int,
633
- ):
634
- return vec3i(
635
- NODE_TET_COORDS[node_index_in_elt, 0],
636
- NODE_TET_COORDS[node_index_in_elt, 1],
637
- NODE_TET_COORDS[node_index_in_elt, 2],
638
- )
639
-
640
- from warp.fem import cache
641
-
642
- return cache.get_func(node_tet_coordinates, self.name)
643
-
644
- def _get_node_type_and_type_index(self):
645
- ORDER = self.ORDER
646
- NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
647
-
648
- def node_type_and_index(
649
- node_index_in_elt: int,
650
- ):
651
- if node_index_in_elt < 0 or node_index_in_elt >= NODES_PER_ELEMENT:
652
- return TetmeshPolynomialShapeFunctions.INVALID, TetmeshPolynomialShapeFunctions.INVALID
653
-
654
- if node_index_in_elt < 4:
655
- return TetmeshPolynomialShapeFunctions.VERTEX, node_index_in_elt
656
-
657
- if node_index_in_elt < (6 * ORDER - 2):
658
- return TetmeshPolynomialShapeFunctions.EDGE, (node_index_in_elt - 4)
659
-
660
- if node_index_in_elt < (2 * ORDER * ORDER + 2):
661
- return TetmeshPolynomialShapeFunctions.FACE, (node_index_in_elt - (6 * ORDER - 2))
662
-
663
- return TetmeshPolynomialShapeFunctions.INTERIOR, (node_index_in_elt - (2 * ORDER * ORDER + 2))
664
-
665
- from warp.fem import cache
666
-
667
- return cache.get_func(node_type_and_index, self.name)
668
-
669
- def make_node_coords_in_element(self):
670
- ORDER = self.ORDER
671
-
672
- def node_coords_in_element(
673
- args: TetmeshFunctionSpace.SpaceArg,
674
- element_index: ElementIndex,
675
- node_index_in_elt: int,
676
- ):
677
- tet_coords = self._node_tet_coordinates(node_index_in_elt)
678
- cx = float(tet_coords[0]) / float(ORDER)
679
- cy = float(tet_coords[1]) / float(ORDER)
680
- cz = float(tet_coords[2]) / float(ORDER)
681
- return Coords(cx, cy, cz)
682
-
683
- from warp.fem import cache
684
-
685
- return cache.get_func(node_coords_in_element, self.name)
686
-
687
- def make_node_quadrature_weight(self):
688
- ORDER = self.ORDER
689
-
690
- def node_uniform_quadrature_weight(
691
- args: TetmeshFunctionSpace.SpaceArg,
692
- element_index: ElementIndex,
693
- node_index_in_elt: int,
694
- ):
695
- node_type, type_index = self.node_type_and_type_index(node_index_in_elt)
696
-
697
- base_weight = 1.0 / float(4 * ORDER * ORDER * ORDER)
698
- if node_type == TetmeshPolynomialShapeFunctions.VERTEX:
699
- return base_weight
700
- if node_type == TetmeshPolynomialShapeFunctions.EDGE:
701
- return 2.0 * base_weight
702
- if node_type == TetmeshPolynomialShapeFunctions.FACE:
703
- return 4.0 * base_weight
704
- return 8.0 * base_weight
705
-
706
- def node_linear_quadrature_weight(
707
- args: TetmeshFunctionSpace.SpaceArg,
708
- element_index: ElementIndex,
709
- node_index_in_elt: int,
710
- ):
711
- return 1.0 / 4.0
712
-
713
- from warp.fem import cache
714
-
715
- if ORDER == 1:
716
- return cache.get_func(node_linear_quadrature_weight, self.name)
717
- return cache.get_func(node_uniform_quadrature_weight, self.name)
718
-
719
- def make_trace_node_quadrature_weight(self):
720
- ORDER = self.ORDER
721
- NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
722
-
723
- def trace_uniform_node_quadrature_weight(
724
- args: TetmeshFunctionSpace.SpaceArg,
725
- element_index: ElementIndex,
726
- node_index_in_elt: int,
727
- ):
728
- if node_index_in_elt >= NODES_PER_ELEMENT:
729
- node_index_in_cell = node_index_in_elt - NODES_PER_ELEMENT
730
- else:
731
- node_index_in_cell = node_index_in_elt
732
-
733
- # We're either on a side interior or at a vertex
734
- node_type, type_index = self.node_type_and_type_index(node_index_in_cell)
735
-
736
- base_weight = 1.0 / float(3 * ORDER * ORDER)
737
- if node_type == TetmeshPolynomialShapeFunctions.VERTEX:
738
- return base_weight
739
- if node_type == TetmeshPolynomialShapeFunctions.EDGE:
740
- return 2.0 * base_weight
741
-
742
- return 4.0 * base_weight
743
-
744
- def trace_linear_node_quadrature_weight(
745
- args: TetmeshFunctionSpace.SpaceArg,
746
- element_index: ElementIndex,
747
- node_index_in_elt: int,
748
- ):
749
- return 1.0 / 3.0
750
-
751
- from warp.fem import cache
752
-
753
- if ORDER == 1:
754
- return cache.get_func(trace_linear_node_quadrature_weight, self.name)
755
-
756
- return cache.get_func(trace_uniform_node_quadrature_weight, self.name)
757
-
758
- def make_element_inner_weight(self):
759
- ORDER = self.ORDER
760
-
761
- def element_inner_weight_linear(
762
- args: TetmeshFunctionSpace.SpaceArg,
763
- element_index: ElementIndex,
764
- coords: Coords,
765
- node_index_in_elt: int,
766
- ):
767
- if node_index_in_elt < 0 or node_index_in_elt >= 4:
768
- return 0.0
769
-
770
- tet_coords = wp.vec4(1.0 - coords[0] - coords[1] - coords[2], coords[0], coords[1], coords[2])
771
- return tet_coords[node_index_in_elt]
772
-
773
- def element_inner_weight_quadratic(
774
- args: TetmeshFunctionSpace.SpaceArg,
775
- element_index: ElementIndex,
776
- coords: Coords,
777
- node_index_in_elt: int,
778
- ):
779
- node_type, type_index = self.node_type_and_type_index(node_index_in_elt)
780
-
781
- tet_coords = wp.vec4(1.0 - coords[0] - coords[1] - coords[2], coords[0], coords[1], coords[2])
782
-
783
- if node_type == TetmeshPolynomialShapeFunctions.VERTEX:
784
- # Vertex
785
- return tet_coords[type_index] * (2.0 * tet_coords[type_index] - 1.0)
786
-
787
- elif node_type == TetmeshPolynomialShapeFunctions.EDGE:
788
- # Edge
789
- if type_index < 3:
790
- c1 = type_index
791
- c2 = (type_index + 1) % 3
792
- else:
793
- c1 = type_index - 3
794
- c2 = 3
795
- return 4.0 * tet_coords[c1] * tet_coords[c2]
796
-
797
- return 0.0
798
-
799
- def element_inner_weight_cubic(
800
- args: TetmeshFunctionSpace.SpaceArg,
801
- element_index: ElementIndex,
802
- coords: Coords,
803
- node_index_in_elt: int,
804
- ):
805
- node_type, type_index = self.node_type_and_type_index(node_index_in_elt)
806
-
807
- tet_coords = wp.vec4(1.0 - coords[0] - coords[1] - coords[2], coords[0], coords[1], coords[2])
808
-
809
- if node_type == TetmeshPolynomialShapeFunctions.VERTEX:
810
- # Vertex
811
- return (
812
- 0.5
813
- * tet_coords[type_index]
814
- * (3.0 * tet_coords[type_index] - 1.0)
815
- * (3.0 * tet_coords[type_index] - 2.0)
816
- )
817
-
818
- elif node_type == TetmeshPolynomialShapeFunctions.EDGE:
819
- # Edge
820
- edge = type_index // 2
821
- edge_node = type_index - 2 * edge
822
-
823
- if edge < 3:
824
- c1 = (edge + edge_node) % 3
825
- c2 = (edge + 1 - edge_node) % 3
826
- elif edge_node == 0:
827
- c1 = edge - 3
828
- c2 = 3
829
- else:
830
- c1 = 3
831
- c2 = edge - 3
832
-
833
- return 4.5 * tet_coords[c1] * tet_coords[c2] * (3.0 * tet_coords[c1] - 1.0)
834
-
835
- elif node_type == TetmeshPolynomialShapeFunctions.FACE:
836
- # Interior
837
- c1 = type_index
838
- c2 = (c1 + 1) % 4
839
- c3 = (c1 + 2) % 4
840
- return 27.0 * tet_coords[c1] * tet_coords[c2] * tet_coords[c3]
841
-
842
- return 0.0
843
-
844
- from warp.fem import cache
845
-
846
- if ORDER == 1:
847
- return cache.get_func(element_inner_weight_linear, self.name)
848
- elif ORDER == 2:
849
- return cache.get_func(element_inner_weight_quadratic, self.name)
850
- elif ORDER == 3:
851
- return cache.get_func(element_inner_weight_cubic, self.name)
852
-
853
- return None
854
-
855
- def make_element_inner_weight_gradient(self):
856
- ORDER = self.ORDER
857
-
858
- def element_inner_weight_gradient_linear(
859
- args: TetmeshFunctionSpace.SpaceArg,
860
- element_index: ElementIndex,
861
- coords: Coords,
862
- node_index_in_elt: int,
863
- ):
864
- if node_index_in_elt < 0 or node_index_in_elt >= 4:
865
- return wp.vec3(0.0)
866
-
867
- dw_dc = wp.vec4(0.0)
868
- dw_dc[node_index_in_elt] = 1.0
869
-
870
- dw_du = wp.vec3(dw_dc[1] - dw_dc[0], dw_dc[2] - dw_dc[0], dw_dc[3] - dw_dc[0])
871
-
872
- return args.reference_transforms[element_index] * dw_du
873
-
874
- def element_inner_weight_gradient_quadratic(
875
- args: TetmeshFunctionSpace.SpaceArg,
876
- element_index: ElementIndex,
877
- coords: Coords,
878
- node_index_in_elt: int,
879
- ):
880
- node_type, type_index = self.node_type_and_type_index(node_index_in_elt)
881
-
882
- tet_coords = wp.vec4(1.0 - coords[0] - coords[1] - coords[2], coords[0], coords[1], coords[2])
883
- dw_dc = wp.vec4(0.0)
884
-
885
- if node_type == TetmeshPolynomialShapeFunctions.VERTEX:
886
- # Vertex
887
- dw_dc[type_index] = 4.0 * tet_coords[type_index] - 1.0
888
-
889
- elif node_type == TetmeshPolynomialShapeFunctions.EDGE:
890
- # Edge
891
- if type_index < 3:
892
- c1 = type_index
893
- c2 = (type_index + 1) % 3
894
- else:
895
- c1 = type_index - 3
896
- c2 = 3
897
- dw_dc[c1] = 4.0 * tet_coords[c2]
898
- dw_dc[c2] = 4.0 * tet_coords[c1]
899
-
900
- dw_du = wp.vec3(dw_dc[1] - dw_dc[0], dw_dc[2] - dw_dc[0], dw_dc[3] - dw_dc[0])
901
- return args.reference_transforms[element_index] * dw_du
902
-
903
- def element_inner_weight_gradient_cubic(
904
- args: TetmeshFunctionSpace.SpaceArg,
905
- element_index: ElementIndex,
906
- coords: Coords,
907
- node_index_in_elt: int,
908
- ):
909
- node_type, type_index = self.node_type_and_type_index(node_index_in_elt)
910
-
911
- tet_coords = wp.vec4(1.0 - coords[0] - coords[1] - coords[2], coords[0], coords[1], coords[2])
912
-
913
- dw_dc = wp.vec4(0.0)
914
-
915
- if node_type == TetmeshPolynomialShapeFunctions.VERTEX:
916
- # Vertex
917
- dw_dc[type_index] = (
918
- 0.5 * 27.0 * tet_coords[type_index] * tet_coords[type_index] - 9.0 * tet_coords[type_index] + 1.0
919
- )
920
-
921
- elif node_type == TetmeshPolynomialShapeFunctions.EDGE:
922
- # Edge
923
- edge = type_index // 2
924
- edge_node = type_index - 2 * edge
925
-
926
- if edge < 3:
927
- c1 = (edge + edge_node) % 3
928
- c2 = (edge + 1 - edge_node) % 3
929
- elif edge_node == 0:
930
- c1 = edge - 3
931
- c2 = 3
932
- else:
933
- c1 = 3
934
- c2 = edge - 3
935
-
936
- dw_dc[c1] = 4.5 * tet_coords[c2] * (6.0 * tet_coords[c1] - 1.0)
937
- dw_dc[c2] = 4.5 * tet_coords[c1] * (3.0 * tet_coords[c1] - 1.0)
938
-
939
- elif node_type == TetmeshPolynomialShapeFunctions.FACE:
940
- # Interior
941
- c1 = type_index
942
- c2 = (c1 + 1) % 4
943
- c3 = (c1 + 2) % 4
944
-
945
- dw_dc[c1] = 27.0 * tet_coords[c2] * tet_coords[c3]
946
- dw_dc[c2] = 27.0 * tet_coords[c3] * tet_coords[c1]
947
- dw_dc[c3] = 27.0 * tet_coords[c1] * tet_coords[c2]
948
-
949
- dw_du = wp.vec3(dw_dc[1] - dw_dc[0], dw_dc[2] - dw_dc[0], dw_dc[3] - dw_dc[0])
950
- return args.reference_transforms[element_index] * dw_du
951
-
952
- from warp.fem import cache
953
-
954
- if ORDER == 1:
955
- return cache.get_func(element_inner_weight_gradient_linear, self.name)
956
- elif ORDER == 2:
957
- return cache.get_func(element_inner_weight_gradient_quadratic, self.name)
958
- elif ORDER == 3:
959
- return cache.get_func(element_inner_weight_gradient_cubic, self.name)
960
-
961
- return None
962
-
963
- @staticmethod
964
- def node_positions(space):
965
- if space.ORDER == 1:
966
- node_positions = space._mesh.positions.numpy()
967
- return node_positions[:, 0], node_positions[:, 1], node_positions[:, 2]
968
-
969
- NODES_PER_ELEMENT = space.NODES_PER_ELEMENT
970
-
971
- def fill_node_positions_fn(
972
- space_arg: space.SpaceArg,
973
- node_positions: wp.array(dtype=wp.vec3),
974
- ):
975
- element_index = wp.tid()
976
- tet_idx = space_arg.geo_arg.tet_vertex_indices[element_index]
977
- p0 = space_arg.geo_arg.positions[tet_idx[0]]
978
- p1 = space_arg.geo_arg.positions[tet_idx[1]]
979
- p2 = space_arg.geo_arg.positions[tet_idx[2]]
980
- p3 = space_arg.geo_arg.positions[tet_idx[3]]
981
-
982
- for n in range(NODES_PER_ELEMENT):
983
- node_index = space.element_node_index(space_arg, element_index, n)
984
- coords = space.node_coords_in_element(space_arg, element_index, n)
985
-
986
- pos = p0 + coords[0] * (p1 - p0) + coords[1] * (p2 - p0) + coords[2] * (p3 - p0)
169
+ class TetmeshPolynomialSpaceTopology(TetmeshSpaceTopology):
170
+ def __init__(self, mesh: Tetmesh, shape: TetrahedronPolynomialShapeFunctions):
171
+ super().__init__(mesh, shape, need_tet_edge_indices=shape.ORDER >= 2, need_tet_face_indices=shape.ORDER >= 3)
987
172
 
988
- node_positions[node_index] = pos
989
-
990
- from warp.fem import cache
173
+ self.element_node_index = self._make_element_node_index()
991
174
 
992
- fill_node_positions = cache.get_kernel(
993
- fill_node_positions_fn,
994
- suffix=space.name,
995
- )
175
+ def node_count(self) -> int:
176
+ ORDER = self._shape.ORDER
177
+ INTERIOR_NODES_PER_EDGE = max(0, ORDER - 1)
178
+ INTERIOR_NODES_PER_FACE = max(0, ORDER - 2) * max(0, ORDER - 1) // 2
179
+ INTERIOR_NODES_PER_CELL = max(0, ORDER - 3) * max(0, ORDER - 2) * max(0, ORDER - 1) // 6
996
180
 
997
- device = space._mesh.tet_vertex_indices.device
998
- node_positions = wp.empty(
999
- shape=space.node_count(),
1000
- dtype=wp.vec3,
1001
- device=device,
1002
- )
1003
- wp.launch(
1004
- dim=space._mesh.cell_count(),
1005
- kernel=fill_node_positions,
1006
- inputs=[
1007
- space.space_arg_value(device),
1008
- node_positions,
1009
- ],
1010
- device=device,
181
+ return (
182
+ self._mesh.vertex_count()
183
+ + self._mesh.edge_count() * INTERIOR_NODES_PER_EDGE
184
+ + self._mesh.side_count() * INTERIOR_NODES_PER_FACE
185
+ + self._mesh.cell_count() * INTERIOR_NODES_PER_CELL
1011
186
  )
1012
187
 
1013
- node_positions = node_positions.numpy()
1014
- return node_positions[:, 0], node_positions[:, 1], node_positions[:, 2]
1015
-
1016
-
1017
- class TetmeshPolynomialSpace(TetmeshFunctionSpace):
1018
- def __init__(self, grid: Tetmesh, degree: int, dtype: type = float, dof_mapper: DofMapper = None):
1019
- super().__init__(grid, dtype, dof_mapper)
1020
-
1021
- self._shape = TetmeshPolynomialShapeFunctions(degree)
1022
-
1023
- self.ORDER = self._shape.ORDER
1024
- self.NODES_PER_ELEMENT = self._shape.NODES_PER_ELEMENT
1025
-
1026
- self.element_node_index = self._make_element_node_index()
1027
- self.node_coords_in_element = self._shape.make_node_coords_in_element()
1028
- self.node_quadrature_weight = self._shape.make_node_quadrature_weight()
1029
- self.element_inner_weight = self._shape.make_element_inner_weight()
1030
- self.element_inner_weight_gradient = self._shape.make_element_inner_weight_gradient()
1031
-
1032
- self.element_outer_weight = self.element_inner_weight
1033
- self.element_outer_weight_gradient = self.element_inner_weight_gradient
1034
-
1035
188
  def _make_element_node_index(self):
1036
- INTERIOR_NODES_PER_EDGE = wp.constant(max(0, self.ORDER - 1))
1037
- INTERIOR_NODES_PER_FACE = wp.constant(max(0, self.ORDER - 2) * max(0, self.ORDER - 1) // 2)
1038
- INTERIOR_NODES_PER_CELL = wp.constant(
1039
- max(0, self.ORDER - 3) * max(0, self.ORDER - 2) * max(0, self.ORDER - 1) // 6
1040
- )
189
+ ORDER = self._shape.ORDER
190
+ INTERIOR_NODES_PER_EDGE = wp.constant(max(0, ORDER - 1))
191
+ INTERIOR_NODES_PER_FACE = wp.constant(max(0, ORDER - 2) * max(0, ORDER - 1) // 2)
192
+ INTERIOR_NODES_PER_CELL = wp.constant(max(0, ORDER - 3) * max(0, ORDER - 2) * max(0, ORDER - 1) // 6)
1041
193
 
194
+ @cache.dynamic_func(suffix=self.name)
1042
195
  def element_node_index(
1043
- args: TetmeshFunctionSpace.SpaceArg,
196
+ geo_arg: Tetmesh.CellArg,
197
+ topo_arg: TetmeshTopologyArg,
1044
198
  element_index: ElementIndex,
1045
199
  node_index_in_elt: int,
1046
200
  ):
1047
201
  node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)
1048
202
 
1049
- if node_type == TetmeshPolynomialShapeFunctions.VERTEX:
1050
- return args.geo_arg.tet_vertex_indices[element_index][type_index]
203
+ if node_type == TetrahedronPolynomialShapeFunctions.VERTEX:
204
+ return geo_arg.tet_vertex_indices[element_index][type_index]
1051
205
 
1052
- global_offset = args.vertex_count
206
+ global_offset = topo_arg.vertex_count
1053
207
 
1054
- if node_type == TetmeshPolynomialShapeFunctions.EDGE:
208
+ if node_type == TetrahedronPolynomialShapeFunctions.EDGE:
1055
209
  edge = type_index // INTERIOR_NODES_PER_EDGE
1056
210
  edge_node = type_index - INTERIOR_NODES_PER_EDGE * edge
1057
211
 
1058
- global_edge_index = args.tet_edge_indices[element_index][edge]
212
+ global_edge_index = topo_arg.tet_edge_indices[element_index][edge]
1059
213
 
1060
214
  # Test if we need to swap edge direction
1061
215
  if INTERIOR_NODES_PER_EDGE > 1:
@@ -1066,28 +220,25 @@ class TetmeshPolynomialSpace(TetmeshFunctionSpace):
1066
220
  c1 = edge - 3
1067
221
  c2 = 3
1068
222
 
1069
- if (
1070
- args.geo_arg.tet_vertex_indices[element_index][c1]
1071
- > args.geo_arg.tet_vertex_indices[element_index][c2]
1072
- ):
223
+ if geo_arg.tet_vertex_indices[element_index][c1] > geo_arg.tet_vertex_indices[element_index][c2]:
1073
224
  edge_node = INTERIOR_NODES_PER_EDGE - 1 - edge_node
1074
225
 
1075
226
  return global_offset + INTERIOR_NODES_PER_EDGE * global_edge_index + edge_node
1076
227
 
1077
- global_offset += INTERIOR_NODES_PER_EDGE * args.edge_count
228
+ global_offset += INTERIOR_NODES_PER_EDGE * topo_arg.edge_count
1078
229
 
1079
- if node_type == TetmeshPolynomialShapeFunctions.FACE:
230
+ if node_type == TetrahedronPolynomialShapeFunctions.FACE:
1080
231
  face = type_index // INTERIOR_NODES_PER_FACE
1081
232
  face_node = type_index - INTERIOR_NODES_PER_FACE * face
1082
233
 
1083
- global_face_index = args.tet_face_indices[element_index][face]
234
+ global_face_index = topo_arg.tet_face_indices[element_index][face]
1084
235
 
1085
236
  if INTERIOR_NODES_PER_FACE == 3:
1086
237
  # Hard code for P4 case, 3 nodes per face
1087
238
  # Higher orders would require rotating triangle coordinates, this is not supported yet
1088
239
 
1089
- vidx = args.geo_arg.tet_vertex_indices[element_index][(face + face_node) % 4]
1090
- fvi = args.geo_arg.face_vertex_indices[global_face_index]
240
+ vidx = geo_arg.tet_vertex_indices[element_index][(face + face_node) % 4]
241
+ fvi = topo_arg.face_vertex_indices[global_face_index]
1091
242
 
1092
243
  if vidx == fvi[0]:
1093
244
  face_node = 0
@@ -1098,102 +249,44 @@ class TetmeshPolynomialSpace(TetmeshFunctionSpace):
1098
249
 
1099
250
  return global_offset + INTERIOR_NODES_PER_FACE * global_face_index + face_node
1100
251
 
1101
- global_offset += INTERIOR_NODES_PER_FACE * args.face_count
252
+ global_offset += INTERIOR_NODES_PER_FACE * topo_arg.face_count
1102
253
 
1103
254
  return global_offset + INTERIOR_NODES_PER_CELL * element_index + type_index
1104
255
 
1105
- from warp.fem import cache
1106
-
1107
- return cache.get_func(element_node_index, self.name)
1108
-
1109
- def node_count(self) -> int:
1110
- INTERIOR_NODES_PER_EDGE = wp.constant(max(0, self.ORDER - 1))
1111
- INTERIOR_NODES_PER_FACE = wp.constant(max(0, self.ORDER - 2) * max(0, self.ORDER - 1) // 2)
1112
- INTERIOR_NODES_PER_CELL = wp.constant(
1113
- max(0, self.ORDER - 3) * max(0, self.ORDER - 2) * max(0, self.ORDER - 1) // 6
1114
- )
1115
-
1116
- return (
1117
- self._mesh.vertex_count()
1118
- + self.edge_count() * INTERIOR_NODES_PER_EDGE
1119
- + self._mesh.side_count() * INTERIOR_NODES_PER_FACE
1120
- + self._mesh.cell_count() * INTERIOR_NODES_PER_CELL
1121
- )
1122
-
1123
- def node_positions(self):
1124
- return TetmeshPolynomialShapeFunctions.node_positions(self)
1125
-
1126
- class Trace(TetmeshFunctionSpace.Trace):
1127
- NODES_PER_ELEMENT = wp.constant(2)
1128
- ORDER = wp.constant(0)
1129
-
1130
- def __init__(self, space: "TetmeshPolynomialSpace"):
1131
- super().__init__(space)
1132
-
1133
- self.element_node_index = self._make_element_node_index(space)
1134
- self.node_coords_in_element = self._make_node_coords_in_element(space)
1135
- self.node_quadrature_weight = space._shape.make_trace_node_quadrature_weight()
1136
-
1137
- self.element_inner_weight = self._make_element_inner_weight(space)
1138
- self.element_inner_weight_gradient = self._make_element_inner_weight_gradient(space)
256
+ return element_node_index
1139
257
 
1140
- self.element_outer_weight = self._make_element_outer_weight(space)
1141
- self.element_outer_weight_gradient = self._make_element_outer_weight_gradient(space)
1142
258
 
1143
- def trace(self):
1144
- return TetmeshPolynomialSpace.Trace(self)
1145
-
1146
-
1147
- class TetmeshDGPolynomialSpace(TetmeshFunctionSpace):
259
+ class TetmeshPolynomialBasisSpace(TetmeshBasisSpace):
1148
260
  def __init__(
1149
261
  self,
1150
262
  mesh: Tetmesh,
1151
263
  degree: int,
1152
- dtype: type = float,
1153
- dof_mapper: DofMapper = None,
1154
264
  ):
1155
- super().__init__(mesh, dtype, dof_mapper)
1156
-
1157
- self._shape = TetmeshPolynomialShapeFunctions(degree)
1158
-
1159
- self.ORDER = self._shape.ORDER
1160
- self.NODES_PER_ELEMENT = self._shape.NODES_PER_ELEMENT
1161
-
1162
- self.element_node_index = self._make_element_node_index()
1163
- self.node_coords_in_element = self._shape.make_node_coords_in_element()
1164
- self.node_quadrature_weight = self._shape.make_node_quadrature_weight()
1165
- self.element_inner_weight = self._shape.make_element_inner_weight()
1166
- self.element_inner_weight_gradient = self._shape.make_element_inner_weight_gradient()
1167
-
1168
- self.element_outer_weight = self.element_inner_weight
1169
- self.element_outer_weight_gradient = self.element_inner_weight_gradient
1170
-
1171
- def node_count(self) -> int:
1172
- return self._mesh.cell_count() * self.NODES_PER_ELEMENT
1173
-
1174
- def node_positions(self):
1175
- return TetmeshPolynomialShapeFunctions.node_positions(self)
265
+ shape = TetrahedronPolynomialShapeFunctions(degree)
266
+ topology = forward_base_topology(TetmeshPolynomialSpaceTopology, mesh, shape)
1176
267
 
1177
- def node_triangulation(self):
1178
- return TetmeshPolynomialShapeFunctions.node_triangulation(self)
268
+ super().__init__(topology, shape)
1179
269
 
1180
- def _make_element_node_index(self):
1181
- NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
1182
270
 
1183
- def element_node_index(
1184
- args: TetmeshFunctionSpace.SpaceArg,
1185
- element_index: ElementIndex,
1186
- node_index_in_elt: int,
1187
- ):
1188
- return element_index * NODES_PER_ELEMENT + node_index_in_elt
271
+ class TetmeshDGPolynomialBasisSpace(TetmeshBasisSpace):
272
+ def __init__(
273
+ self,
274
+ mesh: Tetmesh,
275
+ degree: int,
276
+ ):
277
+ shape = TetrahedronPolynomialShapeFunctions(degree)
278
+ topology = TetmeshDiscontinuousSpaceTopology(mesh, shape)
1189
279
 
1190
- from warp.fem import cache
280
+ super().__init__(topology, shape)
1191
281
 
1192
- return cache.get_func(element_node_index, f"{self.name}_{self.ORDER}")
1193
282
 
1194
- class Trace(TetmeshPolynomialSpace.Trace):
1195
- def __init__(self, space: "TetmeshDGPolynomialSpace"):
1196
- super().__init__(space)
283
+ class TetmeshNonConformingPolynomialBasisSpace(TetmeshBasisSpace):
284
+ def __init__(
285
+ self,
286
+ mesh: Tetmesh,
287
+ degree: int,
288
+ ):
289
+ shape = TetrahedronNonConformingPolynomialShapeFunctions(degree)
290
+ topology = TetmeshDiscontinuousSpaceTopology(mesh, shape)
1197
291
 
1198
- def trace(self):
1199
- return TetmeshDGPolynomialSpace.Trace(self)
292
+ super().__init__(topology, shape)