warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,306 @@
1
+ import warp as wp
2
+ import numpy as np
3
+
4
+ from warp.fem.types import ElementIndex, Coords
5
+ from warp.fem.polynomial import Polynomial, is_closed
6
+ from warp.fem.geometry import Grid3D
7
+ from warp.fem import cache
8
+
9
+ from .topology import SpaceTopology, DiscontinuousSpaceTopologyMixin, forward_base_topology
10
+ from .basis_space import ShapeBasisSpace, TraceBasisSpace
11
+
12
+ from .shape import ShapeFunction, ConstantShapeFunction
13
+ from .shape.cube_shape_function import (
14
+ CubeTripolynomialShapeFunctions,
15
+ CubeSerendipityShapeFunctions,
16
+ CubeNonConformingPolynomialShapeFunctions,
17
+ )
18
+
19
+
20
class Grid3DSpaceTopology(SpaceTopology):
    """Base space topology for function spaces defined over a ``Grid3D`` geometry.

    Stores the shape function and provides warp helpers mapping a cell-local
    vertex index to a global grid vertex index.
    """

    def __init__(self, grid: Grid3D, shape: ShapeFunction):
        super().__init__(grid, shape.NODES_PER_ELEMENT)
        self._shape = shape

    @wp.func
    def _vertex_coords(vidx_in_cell: int):
        # Decompose a cell-local vertex index (expected in [0, 8)) into (x, y, z) corner
        # offsets: x is the most significant bit, z the least significant one
        x = vidx_in_cell // 4
        y = (vidx_in_cell - 4 * x) // 2
        z = vidx_in_cell - 4 * x - 2 * y
        return wp.vec3i(x, y, z)

    @wp.func
    def _vertex_index(cell_arg: Grid3D.CellArg, cell_index: ElementIndex, vidx_in_cell: int):
        res = cell_arg.res
        # Strides along x and y of the vertex lattice, which has res + 1 vertices per axis;
        # presumably combined by _from_3d_index as strides[0] * x + strides[1] * y + z
        strides = wp.vec2i((res[1] + 1) * (res[2] + 1), res[2] + 1)

        # Global lattice coordinates of the requested corner of the cell
        corner = Grid3D.get_cell(res, cell_index) + Grid3DSpaceTopology._vertex_coords(vidx_in_cell)
        return Grid3D._from_3d_index(strides, corner)
39
+
40
+
41
class Grid3DDiscontinuousSpaceTopology(DiscontinuousSpaceTopologyMixin, Grid3DSpaceTopology):
    """Discontinuous variant of ``Grid3DSpaceTopology``.

    All behavior is inherited; ``DiscontinuousSpaceTopologyMixin`` comes first in the
    MRO so its overrides take precedence over the grid topology.
    """
46
+
47
+
48
class Grid3DBasisSpace(ShapeBasisSpace):
    """Shape-function basis space whose topology lives on a ``Grid3D``."""

    def __init__(self, topology: Grid3DSpaceTopology, shape: ShapeFunction):
        super().__init__(topology, shape)

        # Keep a typed handle on the underlying grid for node-position helpers
        grid: Grid3D = topology.geometry
        self._grid = grid
53
+
54
+
55
class Grid3DPiecewiseConstantBasis(Grid3DBasisSpace):
    """Piecewise-constant basis on a ``Grid3D``, with one node at the center of each cell."""

    def __init__(self, grid: Grid3D):
        shape = ConstantShapeFunction(grid.reference_cell(), space_dimension=3)
        topology = Grid3DDiscontinuousSpaceTopology(grid, shape)
        super().__init__(shape=shape, topology=topology)

        # Node positions can only be tabulated for an actual regular grid
        if isinstance(grid, Grid3D):
            self.node_grid = self._node_grid

    def _node_grid(self):
        """Return ``np.meshgrid`` arrays (``ij`` indexing) of the cell-center positions."""
        axes = [
            (np.arange(0, self.geometry.res[d], dtype=float) + 0.5) * self._grid.cell_size[d]
            + self._grid.bounds_lo[d]
            for d in range(3)
        ]
        return np.meshgrid(*axes, indexing="ij")

    class Trace(TraceBasisSpace):
        """Trace of the piecewise-constant basis on grid sides."""

        @wp.func
        def _node_coords_in_element(
            side_arg: Grid3D.SideArg,
            basis_arg: Grid3DBasisSpace.BasisArg,
            element_index: ElementIndex,
            node_index_in_element: int,
        ):
            # Single trace node, presumably the side center in side-local coordinates
            return Coords(0.5, 0.5, 0.0)

        def make_node_coords_in_element(self):
            return self._node_coords_in_element

    def trace(self):
        return Grid3DPiecewiseConstantBasis.Trace(self)
85
+
86
+
87
class GridTripolynomialSpaceTopology(Grid3DSpaceTopology):
    """Continuous tripolynomial topology on a ``Grid3D``.

    Nodes form a regular lattice with ``res[d] * ORDER + 1`` nodes along axis ``d``;
    lattice nodes on a shared cell boundary are shared by the adjacent cells.
    """

    def __init__(self, grid: Grid3D, shape: CubeTripolynomialShapeFunctions):
        super().__init__(grid, shape)

        self.element_node_index = self._make_element_node_index()

    def node_count(self) -> int:
        """Return the total number of nodes in the node lattice."""
        return (
            (self.geometry.res[0] * self._shape.ORDER + 1)
            * (self.geometry.res[1] * self._shape.ORDER + 1)
            * (self.geometry.res[2] * self._shape.ORDER + 1)
        )

    def _make_element_node_index(self):
        """Build the warp function mapping (cell, local node index) to a global node index."""
        ORDER = self._shape.ORDER

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            cell_arg: Grid3D.CellArg,
            topo_arg: Grid3DSpaceTopology.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            res = cell_arg.res
            cell = Grid3D.get_cell(res, element_index)

            # Lattice offsets of the node inside its cell (assumed in [0, ORDER] per axis)
            node_i, node_j, node_k = self._shape._node_ijk(node_index_in_elt)

            # Global lattice coordinates: each cell spans ORDER lattice steps per axis
            node_x = ORDER * cell[0] + node_i
            node_y = ORDER * cell[1] + node_j
            node_z = ORDER * cell[2] + node_k

            # Flatten with z varying fastest; the lattice has res[d] * ORDER + 1 nodes along axis d
            node_pitch_y = (res[2] * ORDER) + 1
            node_pitch_x = node_pitch_y * ((res[1] * ORDER) + 1)
            node_index = node_pitch_x * node_x + node_pitch_y * node_y + node_z

            return node_index

        return element_node_index
126
+
127
+
128
class GridTripolynomialBasisSpace(Grid3DBasisSpace):
    """Continuous tripolynomial basis on a ``Grid3D``.

    Args:
        grid: the grid geometry
        degree: polynomial degree along each axis
        family: polynomial family defining the 1D node positions; ``None`` selects
            ``Polynomial.LOBATTO_GAUSS_LEGENDRE``

    Raises:
        ValueError: if ``family`` is not a closed polynomial family
    """

    def __init__(
        self,
        grid: Grid3D,
        degree: int,
        family: Polynomial,
    ):
        family = Polynomial.LOBATTO_GAUSS_LEGENDRE if family is None else family

        # Continuity requires nodes on the cell boundaries
        if not is_closed(family):
            raise ValueError("A closed polynomial family is required to define a continuous function space")

        shape = CubeTripolynomialShapeFunctions(degree, family=family)
        topology = forward_base_topology(GridTripolynomialSpaceTopology, grid, shape)
        super().__init__(topology, shape)

        if isinstance(grid, Grid3D):
            self.node_grid = self._node_grid

    def _node_grid(self):
        """Return ``np.meshgrid`` arrays (``ij`` indexing) of the node positions."""
        res = self._grid.res

        # The last 1D node of a cell coincides with the first node of the next cell,
        # so it is dropped here and only re-added once at the upper boundary
        cell_coords = np.array(self._shape.LOBATTO_COORDS)[:-1]
        nodes_per_cell = len(cell_coords)

        axes = []
        for d in range(3):
            cell_count = res[d]
            coords = np.repeat(np.arange(0, cell_count, dtype=float), nodes_per_cell) + np.tile(
                cell_coords, reps=cell_count
            )
            coords = np.append(coords, cell_count)
            axes.append(coords * self._grid.cell_size[d] + self._grid.origin[d])

        return np.meshgrid(*axes, indexing="ij")
173
+
174
+
175
class GridDGTripolynomialBasisSpace(Grid3DBasisSpace):
    """Discontinuous tripolynomial basis on a ``Grid3D``; each cell owns its own nodes.

    Args:
        grid: the grid geometry
        degree: polynomial degree along each axis
        family: polynomial family defining the 1D node positions; ``None`` selects
            ``Polynomial.LOBATTO_GAUSS_LEGENDRE``
    """

    def __init__(
        self,
        grid: Grid3D,
        degree: int,
        family: Polynomial,
    ):
        family = Polynomial.LOBATTO_GAUSS_LEGENDRE if family is None else family

        shape = CubeTripolynomialShapeFunctions(degree, family=family)
        super().__init__(shape=shape, topology=Grid3DDiscontinuousSpaceTopology(grid, shape))

    def node_grid(self):
        """Return ``np.meshgrid`` arrays (``ij`` indexing) of the node positions.

        Every cell contributes all of its 1D nodes, so nodes on shared cell
        boundaries appear once per adjacent cell.
        """
        res = self._grid.res
        cell_coords = np.array(self._shape.LOBATTO_COORDS)
        nodes_per_cell = len(cell_coords)

        def axis_positions(d):
            cell_count = res[d]
            coords = np.repeat(np.arange(0, cell_count, dtype=float), nodes_per_cell) + np.tile(
                cell_coords, reps=cell_count
            )
            return coords * self._grid.cell_size[d] + self._grid.origin[d]

        return np.meshgrid(*[axis_positions(d) for d in range(3)], indexing="ij")
211
+
212
+
213
class Grid3DSerendipitySpaceTopology(Grid3DSpaceTopology):
    """Continuous serendipity topology on a ``Grid3D``.

    Nodes are the grid vertices plus ``ORDER - 1`` nodes on each grid edge. Global
    node indices list all vertex nodes first, then the edge nodes grouped per edge.
    """

    def __init__(self, grid: Grid3D, shape: CubeSerendipityShapeFunctions):
        super().__init__(grid, shape)

        self.element_node_index = self._make_element_node_index()

    def node_count(self) -> int:
        """Return the vertex count plus ``ORDER - 1`` nodes per grid edge."""
        return self.geometry.vertex_count() + (self._shape.ORDER - 1) * self.geometry.edge_count()

    def _make_element_node_index(self):
        """Build the warp function mapping (cell, local node index) to a global node index."""
        ORDER = self._shape.ORDER

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            cell_arg: Grid3D.CellArg,
            topo_arg: Grid3DSpaceTopology.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            res = cell_arg.res
            cell = Grid3D.get_cell(res, element_index)

            node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            # Vertex nodes reuse the global vertex numbering of the grid
            if node_type == CubeSerendipityShapeFunctions.VERTEX:
                return Grid3DSpaceTopology._vertex_index(cell_arg, element_index, type_index)

            # Edge node: axis of the edge, and node coordinates relative to the cell
            # (presumably [position along edge, offset, offset] -- confirm in the shape function)
            axis = CubeSerendipityShapeFunctions._edge_axis(node_type)
            node_all = CubeSerendipityShapeFunctions._edge_coords(type_index)

            # NOTE(review): redundant, `res` already holds cell_arg.res
            res = cell_arg.res

            # Edges are numbered axis by axis: skip all edges aligned with lower axes
            edge_index = 0
            if axis > 0:
                edge_index += (res[1] + 1) * (res[2] + 1) * res[0]
            if axis > 1:
                edge_index += (res[0] + 1) * (res[2] + 1) * res[1]

            # Index among the edges of this axis, computed in axis-local coordinates
            # (presumably permuted so that `axis` comes first)
            res_loc = Grid3D._world_to_local(axis, res)
            cell_loc = Grid3D._world_to_local(axis, cell)

            edge_index += (res_loc[1] + 1) * (res_loc[2] + 1) * cell_loc[0]
            edge_index += (res_loc[2] + 1) * (cell_loc[1] + node_all[1])
            edge_index += cell_loc[2] + node_all[2]

            vertex_count = (res[0] + 1) * (res[1] + 1) * (res[2] + 1)

            # Edge nodes follow all vertex nodes, ORDER - 1 consecutive indices per edge;
            # node_all[0] is presumably 1-based along the edge
            return vertex_count + (ORDER - 1) * edge_index + (node_all[0] - 1)

        return element_node_index
263
+
264
+
265
class Grid3DSerendipityBasisSpace(Grid3DBasisSpace):
    """Continuous serendipity basis on a ``Grid3D``.

    Args:
        grid: the grid geometry
        degree: polynomial degree of the serendipity space
        family: polynomial family for the edge node positions; ``None`` selects
            ``Polynomial.LOBATTO_GAUSS_LEGENDRE``
    """

    def __init__(
        self,
        grid: Grid3D,
        degree: int,
        family: Polynomial,
    ):
        family = Polynomial.LOBATTO_GAUSS_LEGENDRE if family is None else family

        shape = CubeSerendipityShapeFunctions(degree, family=family)
        super().__init__(
            topology=forward_base_topology(Grid3DSerendipitySpaceTopology, grid, shape=shape),
            shape=shape,
        )
279
+
280
+
281
class Grid3DDGSerendipityBasisSpace(Grid3DBasisSpace):
    """Discontinuous serendipity basis on a ``Grid3D``; each cell owns its own nodes.

    Args:
        grid: the grid geometry
        degree: polynomial degree of the serendipity space
        family: polynomial family for the edge node positions; ``None`` selects
            ``Polynomial.LOBATTO_GAUSS_LEGENDRE``
    """

    def __init__(
        self,
        grid: Grid3D,
        degree: int,
        family: Polynomial,
    ):
        family = Polynomial.LOBATTO_GAUSS_LEGENDRE if family is None else family

        shape = CubeSerendipityShapeFunctions(degree, family=family)
        super().__init__(
            topology=Grid3DDiscontinuousSpaceTopology(grid, shape=shape),
            shape=shape,
        )
295
+
296
+
297
class Grid3DDGPolynomialBasisSpace(Grid3DBasisSpace):
    """Discontinuous basis on a ``Grid3D`` built from non-conforming cube polynomials.

    Args:
        grid: the grid geometry
        degree: polynomial degree
    """

    def __init__(
        self,
        grid: Grid3D,
        degree: int,
    ):
        shape = CubeNonConformingPolynomialShapeFunctions(degree)
        super().__init__(
            topology=Grid3DDiscontinuousSpaceTopology(grid, shape=shape),
            shape=shape,
        )
@@ -0,0 +1,352 @@
1
+ import warp as wp
2
+
3
+ from warp.fem.types import ElementIndex, Coords
4
+ from warp.fem.polynomial import Polynomial, is_closed
5
+ from warp.fem.geometry import Hexmesh
6
+ from warp.fem import cache
7
+
8
+ from .topology import SpaceTopology, DiscontinuousSpaceTopologyMixin, forward_base_topology
9
+ from .basis_space import ShapeBasisSpace, TraceBasisSpace
10
+
11
+ from .shape import ShapeFunction, ConstantShapeFunction
12
+ from .shape import (
13
+ CubeTripolynomialShapeFunctions,
14
+ CubeSerendipityShapeFunctions,
15
+ CubeNonConformingPolynomialShapeFunctions,
16
+ )
17
+
18
+ from warp.fem.geometry.hexmesh import (
19
+ EDGE_VERTEX_INDICES,
20
+ FACE_ORIENTATION,
21
+ FACE_TRANSLATION,
22
+ )
23
+
24
# Hexmesh face orientation / translation tables converted to integer warp matrices
# so they can be indexed from warp functions. Rows 2 * ori and 2 * ori + 1 of
# _FACE_ORIENTATION_I form the 2x2 integer transform for orientation `ori`, and
# _FACE_TRANSLATION_I is indexed by `ori // 2` (as done in _rotate_face_index).
_FACE_ORIENTATION_I = wp.constant(wp.mat(shape=(16, 2), dtype=int)(FACE_ORIENTATION))
_FACE_TRANSLATION_I = wp.constant(wp.mat(shape=(4, 2), dtype=int)(FACE_TRANSLATION))

# Permutation of the 8 cube corners.
# NOTE(review): presumably maps the shape-function corner ordering to the hexmesh
# vertex ordering -- confirm at its use site (not visible in this chunk).
_CUBE_VERTEX_INDICES = wp.constant(wp.vec(length=8, dtype=int)([0, 4, 3, 7, 1, 5, 2, 6]))
28
+
29
+
30
# Warp-side topology data passed to hexmesh element_node_index functions
@wp.struct
class HexmeshTopologyArg:
    # Per-hexahedron global edge indices (empty (0, 0) array when not requested)
    hex_edge_indices: wp.array2d(dtype=int)
    # Per-hexahedron (global face index, face orientation) pairs: one row per cell,
    # one column per local face (empty (0, 0) array when not requested)
    hex_face_indices: wp.array2d(dtype=wp.vec2i)

    vertex_count: int  # number of mesh vertices
    edge_count: int  # number of mesh edges (0 when edge indices are not requested)
    face_count: int  # number of mesh faces (mesh sides)
38
+
39
+
40
class HexmeshSpaceTopology(SpaceTopology):
    """Space topology for function spaces defined over a ``Hexmesh``.

    Optionally exposes the mesh's per-hexahedron edge indices and computes the
    per-hexahedron (face index, orientation) table. Both are packed into a
    ``HexmeshTopologyArg`` for use from warp functions.

    Args:
        mesh: the hexahedral mesh geometry
        shape: shape function defining the number of nodes per element
        need_hex_edge_indices: if False, an empty edge-index array is stored and the
            edge count is reported as 0
        need_hex_face_indices: if False, the face-index kernel is skipped and an empty
            face-index array is stored
    """

    TopologyArg = HexmeshTopologyArg

    def __init__(
        self,
        mesh: Hexmesh,
        shape: ShapeFunction,
        need_hex_edge_indices: bool = True,
        need_hex_face_indices: bool = True,
    ):
        super().__init__(mesh, shape.NODES_PER_ELEMENT)
        self._mesh = mesh
        self._shape = shape

        if need_hex_edge_indices:
            self._hex_edge_indices = self._mesh.hex_edge_indices
            self._edge_count = self._mesh.edge_count()
        else:
            self._hex_edge_indices = wp.empty(shape=(0, 0), dtype=int)
            self._edge_count = 0

        # Face indices are only computed when requested; previously an unconditional
        # second call launched the kernel twice and overwrote the empty placeholder.
        if need_hex_face_indices:
            self._compute_hex_face_indices()
        else:
            self._hex_face_indices = wp.empty(shape=(0, 0), dtype=wp.vec2i)

    @cache.cached_arg_value
    def topo_arg_value(self, device):
        """Pack the topology data into a ``HexmeshTopologyArg`` on ``device``."""
        arg = HexmeshTopologyArg()
        arg.hex_edge_indices = self._hex_edge_indices.to(device)
        arg.hex_face_indices = self._hex_face_indices.to(device)

        arg.vertex_count = self._mesh.vertex_count()
        arg.face_count = self._mesh.side_count()
        arg.edge_count = self._edge_count
        return arg

    def _compute_hex_face_indices(self):
        """Fill ``self._hex_face_indices`` (cell_count x 6) from the mesh face-to-hex tables."""
        self._hex_face_indices = wp.empty(
            dtype=wp.vec2i, device=self._mesh.hex_vertex_indices.device, shape=(self._mesh.cell_count(), 6)
        )

        wp.launch(
            kernel=HexmeshSpaceTopology._compute_hex_face_indices_kernel,
            dim=self._mesh.side_count(),
            device=self._mesh.hex_vertex_indices.device,
            inputs=[
                self._mesh.face_hex_indices,
                self._mesh._face_hex_face_orientation,
                self._hex_face_indices,
            ],
        )

    @wp.kernel
    def _compute_hex_face_indices_kernel(
        face_hex_indices: wp.array(dtype=wp.vec2i),
        face_hex_face_ori: wp.array(dtype=wp.vec4i),
        hex_face_indices: wp.array2d(dtype=wp.vec2i),
    ):
        # One thread per mesh face: write (face index, orientation) into the local face
        # slot of both hexahedra listed for this face (for boundary faces both entries
        # presumably refer to the same cell)
        f = wp.tid()

        hx0 = face_hex_indices[f][0]
        local_face_0 = face_hex_face_ori[f][0]
        ori_0 = face_hex_face_ori[f][1]

        hex_face_indices[hx0, local_face_0] = wp.vec2i(f, ori_0)

        hx1 = face_hex_indices[f][1]
        local_face_1 = face_hex_face_ori[f][2]
        ori_1 = face_hex_face_ori[f][3]

        hex_face_indices[hx1, local_face_1] = wp.vec2i(f, ori_1)
114
+
115
+
116
class HexmeshDiscontinuousSpaceTopology(DiscontinuousSpaceTopologyMixin, SpaceTopology):
    """Discontinuous topology on a ``Hexmesh``; nodes are never shared between cells."""

    def __init__(self, mesh: Hexmesh, shape: ShapeFunction):
        nodes_per_element = shape.NODES_PER_ELEMENT
        super().__init__(mesh, nodes_per_element)
122
+
123
+
124
class HexmeshBasisSpace(ShapeBasisSpace):
    """Shape-function basis space whose topology lives on a ``Hexmesh``."""

    def __init__(self, topology: HexmeshSpaceTopology, shape: ShapeFunction):
        super().__init__(topology, shape)

        # Keep a typed handle on the underlying mesh
        mesh: Hexmesh = topology.geometry
        self._mesh = mesh
129
+
130
+
131
class HexmeshPiecewiseConstantBasis(HexmeshBasisSpace):
    """Piecewise-constant basis on a ``Hexmesh``, with a single node per hexahedron."""

    def __init__(self, mesh: Hexmesh):
        shape = ConstantShapeFunction(mesh.reference_cell(), space_dimension=3)
        super().__init__(topology=HexmeshDiscontinuousSpaceTopology(mesh, shape), shape=shape)

    class Trace(TraceBasisSpace):
        """Trace of the piecewise-constant basis on mesh sides."""

        @wp.func
        def _node_coords_in_element(
            side_arg: Hexmesh.SideArg,
            basis_arg: HexmeshBasisSpace.BasisArg,
            element_index: ElementIndex,
            node_index_in_element: int,
        ):
            # Single trace node, presumably the side center in side-local coordinates
            return Coords(0.5, 0.5, 0.0)

        def make_node_coords_in_element(self):
            return self._node_coords_in_element

    def trace(self):
        return HexmeshPiecewiseConstantBasis.Trace(self)
152
+
153
+
154
+ class HexmeshTripolynomialSpaceTopology(HexmeshSpaceTopology):
155
+ def __init__(self, mesh: Hexmesh, shape: CubeTripolynomialShapeFunctions):
156
+ super().__init__(mesh, shape, need_hex_edge_indices=shape.ORDER >= 2, need_hex_face_indices=shape.ORDER >= 2)
157
+
158
+ self.element_node_index = self._make_element_node_index()
159
+
160
+ def node_count(self) -> int:
161
+ ORDER = self._shape.ORDER
162
+ INTERIOR_NODES_PER_EDGE = max(0, ORDER - 1)
163
+ INTERIOR_NODES_PER_FACE = INTERIOR_NODES_PER_EDGE**2
164
+ INTERIOR_NODES_PER_CELL = INTERIOR_NODES_PER_EDGE**3
165
+
166
+ return (
167
+ self._mesh.vertex_count()
168
+ + self._mesh.edge_count() * INTERIOR_NODES_PER_EDGE
169
+ + self._mesh.side_count() * INTERIOR_NODES_PER_FACE
170
+ + self._mesh.cell_count() * INTERIOR_NODES_PER_CELL
171
+ )
172
+
173
+ @wp.func
174
+ def _rotate_face_index(type_index: int, ori: int, size: int):
175
+ i = type_index // size
176
+ j = type_index - i * size
177
+ coords = wp.vec2i(i, j)
178
+
179
+ fv = ori // 2
180
+
181
+ rot_i = wp.dot(_FACE_ORIENTATION_I[2 * ori], coords) + _FACE_TRANSLATION_I[fv, 0]
182
+ rot_j = wp.dot(_FACE_ORIENTATION_I[2 * ori + 1], coords) + _FACE_TRANSLATION_I[fv, 1]
183
+
184
+ return rot_i * size + rot_j
185
+
186
+ def _make_element_node_index(self):
187
+ ORDER = self._shape.ORDER
188
+ INTERIOR_NODES_PER_EDGE = wp.constant(max(0, ORDER - 1))
189
+ INTERIOR_NODES_PER_FACE = wp.constant(INTERIOR_NODES_PER_EDGE**2)
190
+ INTERIOR_NODES_PER_CELL = wp.constant(INTERIOR_NODES_PER_EDGE**3)
191
+
192
+ @cache.dynamic_func(suffix=self.name)
193
+ def element_node_index(
194
+ geo_arg: Hexmesh.CellArg,
195
+ topo_arg: HexmeshTopologyArg,
196
+ element_index: ElementIndex,
197
+ node_index_in_elt: int,
198
+ ):
199
+ node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)
200
+
201
+ if node_type == CubeTripolynomialShapeFunctions.VERTEX:
202
+ return geo_arg.hex_vertex_indices[element_index, _CUBE_VERTEX_INDICES[type_instance]]
203
+
204
+ offset = topo_arg.vertex_count
205
+
206
+ if node_type == CubeTripolynomialShapeFunctions.EDGE:
207
+ edge_index = topo_arg.hex_edge_indices[element_index, type_instance]
208
+
209
+ v0 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[type_instance, 0]]
210
+ v1 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[type_instance, 1]]
211
+
212
+ if v0 > v1:
213
+ type_index = ORDER - 1 - type_index
214
+
215
+ return offset + INTERIOR_NODES_PER_EDGE * edge_index + type_index
216
+
217
+ offset += INTERIOR_NODES_PER_EDGE * topo_arg.edge_count
218
+
219
+ if node_type == CubeTripolynomialShapeFunctions.FACE:
220
+ face_index_and_ori = topo_arg.hex_face_indices[element_index, type_instance]
221
+ face_index = face_index_and_ori[0]
222
+ face_orientation = face_index_and_ori[1]
223
+
224
+ type_index = HexmeshTripolynomialSpaceTopology._rotate_face_index(
225
+ type_index, face_orientation, ORDER - 1
226
+ )
227
+
228
+ return offset + INTERIOR_NODES_PER_FACE * face_index + type_index
229
+
230
+ offset += INTERIOR_NODES_PER_FACE * topo_arg.face_count
231
+
232
+ return offset + INTERIOR_NODES_PER_CELL * element_index + type_index
233
+
234
+ return element_node_index
235
+
236
+
237
+ class HexmeshTripolynomialBasisSpace(HexmeshBasisSpace):
238
+ def __init__(
239
+ self,
240
+ mesh: Hexmesh,
241
+ degree: int,
242
+ family: Polynomial,
243
+ ):
244
+ if family is None:
245
+ family = Polynomial.LOBATTO_GAUSS_LEGENDRE
246
+
247
+ if not is_closed(family):
248
+ raise ValueError("A closed polynomial family is required to define a continuous function space")
249
+
250
+ shape = CubeTripolynomialShapeFunctions(degree, family=family)
251
+ topology = forward_base_topology(HexmeshTripolynomialSpaceTopology, mesh, shape)
252
+
253
+ super().__init__(topology, shape)
254
+
255
+
256
+ class HexmeshDGTripolynomialBasisSpace(HexmeshBasisSpace):
257
+ def __init__(
258
+ self,
259
+ mesh: Hexmesh,
260
+ degree: int,
261
+ family: Polynomial,
262
+ ):
263
+ if family is None:
264
+ family = Polynomial.LOBATTO_GAUSS_LEGENDRE
265
+
266
+ shape = CubeTripolynomialShapeFunctions(degree, family=family)
267
+ topology = HexmeshDiscontinuousSpaceTopology(mesh, shape)
268
+
269
+ super().__init__(topology, shape)
270
+
271
+
272
+ class HexmeshSerendipitySpaceTopology(HexmeshSpaceTopology):
273
+ def __init__(self, grid: Hexmesh, shape: CubeSerendipityShapeFunctions):
274
+ super().__init__(grid, shape, need_hex_edge_indices=True, need_hex_face_indices=False)
275
+
276
+ self.element_node_index = self._make_element_node_index()
277
+
278
+ def node_count(self) -> int:
279
+ return self.geometry.vertex_count() + (self._shape.ORDER - 1) * self.geometry.edge_count()
280
+
281
+ def _make_element_node_index(self):
282
+ ORDER = self._shape.ORDER
283
+
284
+ @cache.dynamic_func(suffix=self.name)
285
+ def element_node_index(
286
+ cell_arg: Hexmesh.CellArg,
287
+ topo_arg: HexmeshSpaceTopology.TopologyArg,
288
+ element_index: ElementIndex,
289
+ node_index_in_elt: int,
290
+ ):
291
+ node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)
292
+
293
+ if node_type == CubeSerendipityShapeFunctions.VERTEX:
294
+ return cell_arg.hex_vertex_indices[element_index, _CUBE_VERTEX_INDICES[type_index]]
295
+
296
+ type_instance, index_in_edge = CubeSerendipityShapeFunctions._cube_edge_index(node_type, type_index)
297
+
298
+ edge_index = topo_arg.hex_edge_indices[element_index, type_instance]
299
+
300
+ v0 = cell_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[type_instance, 0]]
301
+ v1 = cell_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[type_instance, 1]]
302
+
303
+ if v0 > v1:
304
+ index_in_edge = ORDER - 1 - index_in_edge
305
+
306
+ return topo_arg.vertex_count + (ORDER - 1) * edge_index + index_in_edge
307
+
308
+ return element_node_index
309
+
310
+
311
+ class HexmeshSerendipityBasisSpace(HexmeshBasisSpace):
312
+ def __init__(
313
+ self,
314
+ mesh: Hexmesh,
315
+ degree: int,
316
+ family: Polynomial,
317
+ ):
318
+ if family is None:
319
+ family = Polynomial.LOBATTO_GAUSS_LEGENDRE
320
+
321
+ shape = CubeSerendipityShapeFunctions(degree, family=family)
322
+ topology = forward_base_topology(HexmeshSerendipitySpaceTopology, mesh, shape=shape)
323
+
324
+ super().__init__(topology=topology, shape=shape)
325
+
326
+
327
+ class HexmeshDGSerendipityBasisSpace(HexmeshBasisSpace):
328
+ def __init__(
329
+ self,
330
+ mesh: Hexmesh,
331
+ degree: int,
332
+ family: Polynomial,
333
+ ):
334
+ if family is None:
335
+ family = Polynomial.LOBATTO_GAUSS_LEGENDRE
336
+
337
+ shape = CubeSerendipityShapeFunctions(degree, family=family)
338
+ topology = HexmeshDiscontinuousSpaceTopology(mesh, shape=shape)
339
+
340
+ super().__init__(topology=topology, shape=shape)
341
+
342
+
343
+ class HexmeshPolynomialBasisSpace(HexmeshBasisSpace):
344
+ def __init__(
345
+ self,
346
+ mesh: Hexmesh,
347
+ degree: int,
348
+ ):
349
+ shape = CubeNonConformingPolynomialShapeFunctions(degree)
350
+ topology = HexmeshDiscontinuousSpaceTopology(mesh, shape)
351
+
352
+ super().__init__(topology, shape)