warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (271)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.dll +0 -0
  57. warp/bin/warp.dll +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/bin/warp-clang.so +0 -0
  257. warp/bin/warp.so +0 -0
  258. warp/fem/field/discrete_field.py +0 -80
  259. warp/fem/space/nodal_function_space.py +0 -233
  260. warp/tests/test_all.py +0 -223
  261. warp/tests/test_array_scan.py +0 -60
  262. warp/tests/test_base.py +0 -208
  263. warp/tests/test_unresolved_func.py +0 -7
  264. warp/tests/test_unresolved_symbol.py +0 -7
  265. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  266. warp_lang-1.0.0b2.dist-info/RECORD +0 -380
  267. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  268. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  269. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  270. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  271. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,352 @@
1
+ import warp as wp
2
+
3
+ from warp.fem.types import ElementIndex, Coords
4
+ from warp.fem.polynomial import Polynomial, is_closed
5
+ from warp.fem.geometry import Hexmesh
6
+ from warp.fem import cache
7
+
8
+ from .topology import SpaceTopology, DiscontinuousSpaceTopologyMixin, forward_base_topology
9
+ from .basis_space import ShapeBasisSpace, TraceBasisSpace
10
+
11
+ from .shape import ShapeFunction, ConstantShapeFunction
12
+ from .shape import (
13
+ CubeTripolynomialShapeFunctions,
14
+ CubeSerendipityShapeFunctions,
15
+ CubeNonConformingPolynomialShapeFunctions,
16
+ )
17
+
18
+ from warp.fem.geometry.hexmesh import (
19
+ EDGE_VERTEX_INDICES,
20
+ FACE_ORIENTATION,
21
+ FACE_TRANSLATION,
22
+ )
23
+
24
# Integer-typed copies of the hexmesh face orientation / translation tables,
# registered with wp.constant so they can be indexed inside warp functions.
# Orientation: 16 rows x 2 columns (read below as 2x2 blocks, rows 2*ori and 2*ori + 1).
_FACE_ORIENTATION_I = wp.constant(
    wp.mat(shape=(16, 2), dtype=int)(FACE_ORIENTATION),
)
# Translation: 4 rows x 2 columns (indexed below by ori // 2).
_FACE_TRANSLATION_I = wp.constant(
    wp.mat(shape=(4, 2), dtype=int)(FACE_TRANSLATION),
)

# Permutation from shape-function vertex numbering to the Hexmesh local
# vertex slots in hex_vertex_indices (entry k = mesh-local slot of shape vertex k).
_CUBE_VERTEX_INDICES = wp.constant(
    wp.vec(length=8, dtype=int)([0, 4, 3, 7, 1, 5, 2, 6]),
)
28
+
29
+
30
@wp.struct
class HexmeshTopologyArg:
    # Device-side topology data passed to the element_node_index functions below.

    # Global edge index of each cell-local edge, indexed [hex, local_edge].
    # May be an empty (0, 0) placeholder when the topology does not need edges.
    hex_edge_indices: wp.array2d(dtype=int)
    # (global face index, face orientation) of each cell-local face, indexed [hex, local_face].
    # May be an empty (0, 0) placeholder when the topology does not need faces.
    hex_face_indices: wp.array2d(dtype=wp.vec2i)

    # Entity counts, used to offset edge / face / cell interior node blocks
    vertex_count: int
    edge_count: int
    face_count: int
38
+
39
+
40
class HexmeshSpaceTopology(SpaceTopology):
    """Base space topology for function spaces defined on a :class:`Hexmesh`.

    Collects the mesh connectivity tables that derived topologies need to map
    cell-local nodes to global node indices:

    - per-hex global edge indices (``hex_edge_indices``),
    - per-hex ``(global face index, orientation)`` pairs (``hex_face_indices``).

    Either table can be skipped (replaced by an empty placeholder array) when
    the derived topology has no nodes on the corresponding entities.
    """

    TopologyArg = HexmeshTopologyArg

    def __init__(
        self,
        mesh: Hexmesh,
        shape: ShapeFunction,
        need_hex_edge_indices: bool = True,
        need_hex_face_indices: bool = True,
    ):
        """
        Args:
            mesh: Hexahedral mesh the space is defined on.
            shape: Element shape function; its ``NODES_PER_ELEMENT`` is forwarded to the base topology.
            need_hex_edge_indices: If False, store an empty edge table and report ``edge_count == 0``.
            need_hex_face_indices: If False, skip the face-table kernel and store an empty placeholder.
        """
        super().__init__(mesh, shape.NODES_PER_ELEMENT)
        self._mesh = mesh
        self._shape = shape

        if need_hex_edge_indices:
            self._hex_edge_indices = self._mesh.hex_edge_indices
            self._edge_count = self._mesh.edge_count()
        else:
            self._hex_edge_indices = wp.empty(shape=(0, 0), dtype=int)
            self._edge_count = 0

        # Compute the face table only when requested. Topologies without face
        # nodes (e.g. serendipity) pass False and never read it, so an
        # unconditional recomputation here would waste a kernel launch and
        # overwrite the placeholder.
        if need_hex_face_indices:
            self._compute_hex_face_indices()
        else:
            self._hex_face_indices = wp.empty(shape=(0, 0), dtype=wp.vec2i)

    @cache.cached_arg_value
    def topo_arg_value(self, device):
        """Build the :class:`HexmeshTopologyArg` for ``device``.

        Connectivity arrays are copied to ``device`` via ``.to(device)``; the
        ``cache.cached_arg_value`` decorator presumably memoizes the result per
        device (TODO confirm in warp.fem.cache).
        """
        arg = HexmeshTopologyArg()
        arg.hex_edge_indices = self._hex_edge_indices.to(device)
        arg.hex_face_indices = self._hex_face_indices.to(device)

        arg.vertex_count = self._mesh.vertex_count()
        arg.face_count = self._mesh.side_count()
        arg.edge_count = self._edge_count
        return arg

    def _compute_hex_face_indices(self):
        """Fill ``self._hex_face_indices`` (shape ``(cell_count, 6)``) on the mesh device."""
        self._hex_face_indices = wp.empty(
            dtype=wp.vec2i, device=self._mesh.hex_vertex_indices.device, shape=(self._mesh.cell_count(), 6)
        )

        wp.launch(
            kernel=HexmeshSpaceTopology._compute_hex_face_indices_kernel,
            dim=self._mesh.side_count(),
            device=self._mesh.hex_vertex_indices.device,
            inputs=[
                self._mesh.face_hex_indices,
                self._mesh._face_hex_face_orientation,
                self._hex_face_indices,
            ],
        )

    @wp.kernel
    def _compute_hex_face_indices_kernel(
        face_hex_indices: wp.array(dtype=wp.vec2i),
        face_hex_face_ori: wp.array(dtype=wp.vec4i),
        hex_face_indices: wp.array2d(dtype=wp.vec2i),
    ):
        # One thread per mesh face. face_hex_indices[f] lists the two hexes of face f;
        # face_hex_face_ori[f] holds (local_face_0, ori_0, local_face_1, ori_1).
        # Scatter (f, orientation) into the per-hex table of each listed hex.
        f = wp.tid()

        hx0 = face_hex_indices[f][0]
        local_face_0 = face_hex_face_ori[f][0]
        ori_0 = face_hex_face_ori[f][1]

        hex_face_indices[hx0, local_face_0] = wp.vec2i(f, ori_0)

        hx1 = face_hex_indices[f][1]
        local_face_1 = face_hex_face_ori[f][2]
        ori_1 = face_hex_face_ori[f][3]

        hex_face_indices[hx1, local_face_1] = wp.vec2i(f, ori_1)
114
+
115
+
116
class HexmeshDiscontinuousSpaceTopology(
    DiscontinuousSpaceTopologyMixin,
    SpaceTopology,
):
    """Discontinuous space topology on a :class:`Hexmesh`.

    The node numbering comes from ``DiscontinuousSpaceTopologyMixin``;
    presumably each cell owns its own ``NODES_PER_ELEMENT`` nodes (TODO confirm in topology.py).
    """

    def __init__(self, mesh: Hexmesh, shape: ShapeFunction):
        nodes_per_element = shape.NODES_PER_ELEMENT
        super().__init__(mesh, nodes_per_element)
122
+
123
+
124
class HexmeshBasisSpace(ShapeBasisSpace):
    """Common base class for shape-function basis spaces on a :class:`Hexmesh`."""

    def __init__(self, topology: HexmeshSpaceTopology, shape: ShapeFunction):
        super().__init__(topology, shape)

        # Keep a typed handle on the mesh underlying the topology
        hex_mesh: Hexmesh = topology.geometry
        self._mesh = hex_mesh
129
+
130
+
131
class HexmeshPiecewiseConstantBasis(HexmeshBasisSpace):
    """Cell-wise constant basis on a :class:`Hexmesh`, built from a ``ConstantShapeFunction``."""

    def __init__(self, mesh: Hexmesh):
        const_shape = ConstantShapeFunction(mesh.reference_cell(), space_dimension=3)
        super().__init__(HexmeshDiscontinuousSpaceTopology(mesh, const_shape), const_shape)

    class Trace(TraceBasisSpace):
        """Trace of the piecewise-constant basis on mesh sides."""

        @wp.func
        def _node_coords_in_element(
            side_arg: Hexmesh.SideArg,
            basis_arg: HexmeshBasisSpace.BasisArg,
            element_index: ElementIndex,
            node_index_in_element: int,
        ):
            # Every trace node sits at side-local coordinates (0.5, 0.5, 0.0)
            return Coords(0.5, 0.5, 0.0)

        def make_node_coords_in_element(self):
            node_coords_fn = self._node_coords_in_element
            return node_coords_fn

    def trace(self):
        trace_space = HexmeshPiecewiseConstantBasis.Trace(self)
        return trace_space
152
+
153
+
154
class HexmeshTripolynomialSpaceTopology(HexmeshSpaceTopology):
    """Topology for continuous tripolynomial spaces on a :class:`Hexmesh`.

    Global nodes are numbered as: all mesh vertices, then ``ORDER - 1`` interior
    nodes per edge, then ``(ORDER - 1)**2`` interior nodes per face, then
    ``(ORDER - 1)**3`` interior nodes per cell.
    """

    def __init__(self, mesh: Hexmesh, shape: CubeTripolynomialShapeFunctions):
        # Edge / face tables are only needed when edges / faces carry interior nodes
        super().__init__(mesh, shape, need_hex_edge_indices=shape.ORDER >= 2, need_hex_face_indices=shape.ORDER >= 2)

        self.element_node_index = self._make_element_node_index()

    def node_count(self) -> int:
        """Total number of global nodes (vertices + edge, face and cell interiors)."""
        ORDER = self._shape.ORDER
        INTERIOR_NODES_PER_EDGE = max(0, ORDER - 1)
        INTERIOR_NODES_PER_FACE = INTERIOR_NODES_PER_EDGE**2
        INTERIOR_NODES_PER_CELL = INTERIOR_NODES_PER_EDGE**3

        return (
            self._mesh.vertex_count()
            + self._mesh.edge_count() * INTERIOR_NODES_PER_EDGE
            + self._mesh.side_count() * INTERIOR_NODES_PER_FACE
            + self._mesh.cell_count() * INTERIOR_NODES_PER_CELL
        )

    @wp.func
    def _rotate_face_index(type_index: int, ori: int, size: int):
        # Map a cell-local face-interior node index (i * size + j) to the face's
        # global frame: (i, j) is transformed by the 2x2 block at rows
        # 2*ori, 2*ori + 1 of _FACE_ORIENTATION_I, then translated by
        # _FACE_TRANSLATION_I[ori // 2].
        # NOTE(review): for size > 2 the translation presumably needs scaling by
        # (size - 1) to keep indices in [0, size); confirm against FACE_TRANSLATION.
        i = type_index // size
        j = type_index - i * size
        coords = wp.vec2i(i, j)

        fv = ori // 2

        rot_i = wp.dot(_FACE_ORIENTATION_I[2 * ori], coords) + _FACE_TRANSLATION_I[fv, 0]
        rot_j = wp.dot(_FACE_ORIENTATION_I[2 * ori + 1], coords) + _FACE_TRANSLATION_I[fv, 1]

        return rot_i * size + rot_j

    def _make_element_node_index(self):
        """Build the warp function mapping (cell, local node) to a global node index."""
        ORDER = self._shape.ORDER
        INTERIOR_NODES_PER_EDGE = wp.constant(max(0, ORDER - 1))
        INTERIOR_NODES_PER_FACE = wp.constant(INTERIOR_NODES_PER_EDGE**2)
        INTERIOR_NODES_PER_CELL = wp.constant(INTERIOR_NODES_PER_EDGE**3)

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            geo_arg: Hexmesh.CellArg,
            topo_arg: HexmeshTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            # Vertex nodes map directly onto mesh vertices
            if node_type == CubeTripolynomialShapeFunctions.VERTEX:
                return geo_arg.hex_vertex_indices[element_index, _CUBE_VERTEX_INDICES[type_instance]]

            offset = topo_arg.vertex_count

            if node_type == CubeTripolynomialShapeFunctions.EDGE:
                edge_index = topo_arg.hex_edge_indices[element_index, type_instance]

                v0 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[type_instance, 0]]
                v1 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[type_instance, 1]]

                # Each edge owns INTERIOR_NODES_PER_EDGE consecutive slots
                # (type_index in [0, INTERIOR_NODES_PER_EDGE - 1]). Reverse within
                # that range so that all cells sharing the edge enumerate its nodes
                # starting from the lower global vertex index; flipping with
                # ORDER - 1 would step into the next edge's block.
                if v0 > v1:
                    type_index = INTERIOR_NODES_PER_EDGE - 1 - type_index

                return offset + INTERIOR_NODES_PER_EDGE * edge_index + type_index

            offset += INTERIOR_NODES_PER_EDGE * topo_arg.edge_count

            if node_type == CubeTripolynomialShapeFunctions.FACE:
                face_index_and_ori = topo_arg.hex_face_indices[element_index, type_instance]
                face_index = face_index_and_ori[0]
                face_orientation = face_index_and_ori[1]

                # Express the node in the face's global frame so neighbors agree
                type_index = HexmeshTripolynomialSpaceTopology._rotate_face_index(
                    type_index, face_orientation, ORDER - 1
                )

                return offset + INTERIOR_NODES_PER_FACE * face_index + type_index

            offset += INTERIOR_NODES_PER_FACE * topo_arg.face_count

            # Cell-interior nodes are owned by a single cell: no reorientation
            return offset + INTERIOR_NODES_PER_CELL * element_index + type_index

        return element_node_index
235
+
236
+
237
class HexmeshTripolynomialBasisSpace(HexmeshBasisSpace):
    """Continuous tripolynomial basis of a given degree on a :class:`Hexmesh`."""

    def __init__(
        self,
        mesh: Hexmesh,
        degree: int,
        family: Polynomial,
    ):
        """
        Args:
            mesh: Hexahedral mesh.
            degree: Polynomial degree of the shape functions.
            family: Polynomial family for node placement; ``None`` selects
                ``Polynomial.LOBATTO_GAUSS_LEGENDRE``.

        Raises:
            ValueError: If ``family`` is not closed, as reported by ``is_closed``.
        """
        family = Polynomial.LOBATTO_GAUSS_LEGENDRE if family is None else family

        # Continuity requires nodes on element boundaries, i.e. a closed family
        if not is_closed(family):
            raise ValueError("A closed polynomial family is required to define a continuous function space")

        tri_shape = CubeTripolynomialShapeFunctions(degree, family=family)
        super().__init__(
            forward_base_topology(HexmeshTripolynomialSpaceTopology, mesh, tri_shape),
            tri_shape,
        )
254
+
255
+
256
class HexmeshDGTripolynomialBasisSpace(HexmeshBasisSpace):
    """Discontinuous (per-cell) tripolynomial basis of a given degree on a :class:`Hexmesh`."""

    def __init__(
        self,
        mesh: Hexmesh,
        degree: int,
        family: Polynomial,
    ):
        """
        Args:
            mesh: Hexahedral mesh.
            degree: Polynomial degree of the shape functions.
            family: Polynomial family for node placement; ``None`` selects
                ``Polynomial.LOBATTO_GAUSS_LEGENDRE``. No closedness check is performed.
        """
        family = Polynomial.LOBATTO_GAUSS_LEGENDRE if family is None else family

        tri_shape = CubeTripolynomialShapeFunctions(degree, family=family)
        super().__init__(HexmeshDiscontinuousSpaceTopology(mesh, tri_shape), tri_shape)
270
+
271
+
272
class HexmeshSerendipitySpaceTopology(HexmeshSpaceTopology):
    """Topology for continuous serendipity spaces on a :class:`Hexmesh`.

    Global nodes are all mesh vertices followed by ``ORDER - 1`` nodes per edge;
    there are no face or cell-interior nodes, so the face table is not built.
    """

    def __init__(self, grid: Hexmesh, shape: CubeSerendipityShapeFunctions):
        super().__init__(grid, shape, need_hex_edge_indices=True, need_hex_face_indices=False)

        self.element_node_index = self._make_element_node_index()

    def node_count(self) -> int:
        """Number of global nodes: vertices plus ``ORDER - 1`` per edge."""
        return self.geometry.vertex_count() + (self._shape.ORDER - 1) * self.geometry.edge_count()

    def _make_element_node_index(self):
        """Build the warp function mapping (cell, local node) to a global node index."""
        ORDER = self._shape.ORDER

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            cell_arg: Hexmesh.CellArg,
            topo_arg: HexmeshSpaceTopology.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            # Vertex nodes map directly onto mesh vertices
            if node_type == CubeSerendipityShapeFunctions.VERTEX:
                return cell_arg.hex_vertex_indices[element_index, _CUBE_VERTEX_INDICES[type_index]]

            type_instance, index_in_edge = CubeSerendipityShapeFunctions._cube_edge_index(node_type, type_index)

            edge_index = topo_arg.hex_edge_indices[element_index, type_instance]

            v0 = cell_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[type_instance, 0]]
            v1 = cell_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[type_instance, 1]]

            # Each edge owns ORDER - 1 consecutive slots (index_in_edge in
            # [0, ORDER - 2]). Reverse within that range so that all cells sharing
            # the edge enumerate its nodes starting from the lower global vertex
            # index; flipping with ORDER - 1 would step into the next edge's block.
            if v0 > v1:
                index_in_edge = ORDER - 2 - index_in_edge

            return topo_arg.vertex_count + (ORDER - 1) * edge_index + index_in_edge

        return element_node_index
309
+
310
+
311
class HexmeshSerendipityBasisSpace(HexmeshBasisSpace):
    """Continuous serendipity basis of a given degree on a :class:`Hexmesh`."""

    def __init__(
        self,
        mesh: Hexmesh,
        degree: int,
        family: Polynomial,
    ):
        """
        Args:
            mesh: Hexahedral mesh.
            degree: Polynomial degree of the shape functions.
            family: Polynomial family for node placement; ``None`` selects
                ``Polynomial.LOBATTO_GAUSS_LEGENDRE``.
        """
        family = Polynomial.LOBATTO_GAUSS_LEGENDRE if family is None else family

        serendipity_shape = CubeSerendipityShapeFunctions(degree, family=family)
        super().__init__(
            forward_base_topology(HexmeshSerendipitySpaceTopology, mesh, shape=serendipity_shape),
            serendipity_shape,
        )
325
+
326
+
327
class HexmeshDGSerendipityBasisSpace(HexmeshBasisSpace):
    """Discontinuous (per-cell) serendipity basis of a given degree on a :class:`Hexmesh`."""

    def __init__(
        self,
        mesh: Hexmesh,
        degree: int,
        family: Polynomial,
    ):
        """
        Args:
            mesh: Hexahedral mesh.
            degree: Polynomial degree of the shape functions.
            family: Polynomial family for node placement; ``None`` selects
                ``Polynomial.LOBATTO_GAUSS_LEGENDRE``.
        """
        family = Polynomial.LOBATTO_GAUSS_LEGENDRE if family is None else family

        serendipity_shape = CubeSerendipityShapeFunctions(degree, family=family)
        super().__init__(
            HexmeshDiscontinuousSpaceTopology(mesh, shape=serendipity_shape),
            serendipity_shape,
        )
341
+
342
+
343
class HexmeshPolynomialBasisSpace(HexmeshBasisSpace):
    """Discontinuous basis on a :class:`Hexmesh` using ``CubeNonConformingPolynomialShapeFunctions``."""

    def __init__(
        self,
        mesh: Hexmesh,
        degree: int,
    ):
        nonconforming_shape = CubeNonConformingPolynomialShapeFunctions(degree)
        super().__init__(HexmeshDiscontinuousSpaceTopology(mesh, nonconforming_shape), nonconforming_shape)