warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (271)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.dll +0 -0
  57. warp/bin/warp.dll +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/bin/warp-clang.so +0 -0
  257. warp/bin/warp.so +0 -0
  258. warp/fem/field/discrete_field.py +0 -80
  259. warp/fem/space/nodal_function_space.py +0 -233
  260. warp/tests/test_all.py +0 -223
  261. warp/tests/test_array_scan.py +0 -60
  262. warp/tests/test_base.py +0 -208
  263. warp/tests/test_unresolved_func.py +0 -7
  264. warp/tests/test_unresolved_symbol.py +0 -7
  265. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  266. warp_lang-1.0.0b2.dist-info/RECORD +0 -380
  267. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  268. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  269. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  270. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  271. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -1,532 +1,152 @@
1
1
  import warp as wp
2
2
  import numpy as np
3
3
 
4
- from warp.fem.types import ElementIndex, Coords, OUTSIDE
5
- from warp.fem.types import vec2i, vec3i
6
- from warp.fem.polynomial import Polynomial, lagrange_scales, quadrature_1d, is_closed
7
-
4
+ from warp.fem.types import ElementIndex, Coords
5
+ from warp.fem.polynomial import Polynomial, is_closed
8
6
  from warp.fem.geometry import Grid3D
7
+ from warp.fem import cache
9
8
 
10
- from .dof_mapper import DofMapper
11
- from .nodal_function_space import NodalFunctionSpace, NodalFunctionSpaceTrace
12
-
13
-
14
- class Grid3DFunctionSpace(NodalFunctionSpace):
15
- DIMENSION = wp.constant(3)
16
-
17
- @wp.struct
18
- class SpaceArg:
19
- geo_arg: Grid3D.SideArg
20
- inv_cell_size: wp.vec3
21
-
22
- def __init__(self, grid: Grid3D, dtype: type = float, dof_mapper: DofMapper = None):
23
- super().__init__(dtype, dof_mapper)
24
- self._grid = grid
25
-
26
- @property
27
- def geometry(self) -> Grid3D:
28
- return self._grid
29
-
30
- def space_arg_value(self, device):
31
- arg = self.SpaceArg()
32
- arg.geo_arg = self.geometry.side_arg_value(device)
33
- arg.inv_cell_size = wp.vec3(
34
- 1.0 / self.geometry.cell_size[0],
35
- 1.0 / self.geometry.cell_size[1],
36
- 1.0 / self.geometry.cell_size[2],
37
- )
9
+ from .topology import SpaceTopology, DiscontinuousSpaceTopologyMixin, forward_base_topology
10
+ from .basis_space import ShapeBasisSpace, TraceBasisSpace
38
11
 
39
- return arg
40
-
41
- class Trace(NodalFunctionSpaceTrace):
42
- def __init__(self, space: NodalFunctionSpace):
43
- super().__init__(space)
44
- self.ORDER = space.ORDER
45
-
46
- @wp.func
47
- def _inner_cell_index(args: SpaceArg, side_index: ElementIndex):
48
- cell_index = Grid3D.side_inner_cell_index(args.geo_arg, side_index)
49
- return cell_index
50
-
51
- @wp.func
52
- def _outer_cell_index(args: SpaceArg, side_index: ElementIndex):
53
- return Grid3D.side_outer_cell_index(args.geo_arg, side_index)
54
-
55
- @wp.func
56
- def _inner_cell_coords(args: SpaceArg, side_index: ElementIndex, side_coords: Coords):
57
- side = Grid3D.get_side(args.geo_arg, side_index)
12
+ from .shape import ShapeFunction, ConstantShapeFunction
13
+ from .shape.cube_shape_function import (
14
+ CubeTripolynomialShapeFunctions,
15
+ CubeSerendipityShapeFunctions,
16
+ CubeNonConformingPolynomialShapeFunctions,
17
+ )
58
18
 
59
- if side.origin[0] == 0:
60
- inner_alt = 0.0
61
- else:
62
- inner_alt = 1.0
63
19
 
64
- return Grid3D._local_to_world(side.axis, wp.vec3(inner_alt, side_coords[0], side_coords[1]))
65
-
66
- @wp.func
67
- def _outer_cell_coords(args: SpaceArg, side_index: ElementIndex, side_coords: Coords):
68
- side = Grid3D.get_side(args.geo_arg, side_index)
69
-
70
- alt_axis = Grid3D.LOC_TO_WORLD[side.axis, 0]
71
- if side.origin[0] == args.geo_arg.cell_arg.res[alt_axis]:
72
- outer_alt = 1.0
73
- else:
74
- outer_alt = 0.0
75
-
76
- return Grid3D._local_to_world(side.axis, wp.vec3(outer_alt, side_coords[0], side_coords[1]))
77
-
78
- @wp.func
79
- def _cell_to_side_coords(
80
- args: SpaceArg,
81
- side_index: ElementIndex,
82
- element_index: ElementIndex,
83
- element_coords: Coords,
84
- ):
85
- side = Grid3D.get_side(args.geo_arg, side_index)
86
- cell = Grid3D.get_cell(args.geo_arg.cell_arg.res, element_index)
87
-
88
- if float(side.origin[0] - cell[side.axis]) == element_coords[side.axis]:
89
- long_axis = Grid3D.LOC_TO_WORLD[side.axis, 1]
90
- lat_axis = Grid3D.LOC_TO_WORLD[side.axis, 2]
91
- return Coords(element_coords[long_axis], element_coords[lat_axis], 0.0)
92
-
93
- return Coords(OUTSIDE)
20
+ class Grid3DSpaceTopology(SpaceTopology):
21
+ def __init__(self, grid: Grid3D, shape: ShapeFunction):
22
+ super().__init__(grid, shape.NODES_PER_ELEMENT)
23
+ self._shape = shape
94
24
 
95
25
  @wp.func
96
26
  def _vertex_coords(vidx_in_cell: int):
97
27
  x = vidx_in_cell // 4
98
28
  y = (vidx_in_cell - 4 * x) // 2
99
29
  z = vidx_in_cell - 4 * x - 2 * y
100
- return vec3i(x, y, z)
101
-
102
- @wp.func
103
- def _vertex_coords_f(vidx_in_cell: int):
104
- x = vidx_in_cell // 4
105
- y = (vidx_in_cell - 4 * x) // 2
106
- z = vidx_in_cell - 4 * x - 2 * y
107
- return wp.vec3(float(x), float(y), float(z))
30
+ return wp.vec3i(x, y, z)
108
31
 
109
32
  @wp.func
110
- def _vertex_index(args: SpaceArg, cell_index: ElementIndex, vidx_in_cell: int):
111
- res = args.geo_arg.cell_arg.res
112
- strides = vec2i((res[1] + 1) * (res[2] + 1), res[2] + 1)
33
+ def _vertex_index(cell_arg: Grid3D.CellArg, cell_index: ElementIndex, vidx_in_cell: int):
34
+ res = cell_arg.res
35
+ strides = wp.vec2i((res[1] + 1) * (res[2] + 1), res[2] + 1)
113
36
 
114
- corner = Grid3D.get_cell(res, cell_index) + Grid3DFunctionSpace._vertex_coords(vidx_in_cell)
37
+ corner = Grid3D.get_cell(res, cell_index) + Grid3DSpaceTopology._vertex_coords(vidx_in_cell)
115
38
  return Grid3D._from_3d_index(strides, corner)
116
39
 
117
40
 
41
+ class Grid3DDiscontinuousSpaceTopology(
42
+ DiscontinuousSpaceTopologyMixin,
43
+ Grid3DSpaceTopology,
44
+ ):
45
+ pass
118
46
 
119
- class Grid3DPiecewiseConstantSpace(Grid3DFunctionSpace):
120
- ORDER = wp.constant(0)
121
- NODES_PER_ELEMENT = wp.constant(1)
122
47
 
123
- def __init__(self, grid: Grid3D, dtype: type = float, dof_mapper: DofMapper = None):
124
- super().__init__(grid, dtype, dof_mapper)
48
+ class Grid3DBasisSpace(ShapeBasisSpace):
49
+ def __init__(self, topology: Grid3DSpaceTopology, shape: ShapeFunction):
50
+ super().__init__(topology, shape)
125
51
 
126
- self.element_outer_weight = self.element_inner_weight
127
- self.element_outer_weight_gradient = self.element_inner_weight_gradient
52
+ self._grid: Grid3D = topology.geometry
128
53
 
129
- def node_count(self) -> int:
130
- return self._grid.cell_count()
131
54
 
132
- def node_positions(self):
55
+ class Grid3DPiecewiseConstantBasis(Grid3DBasisSpace):
56
+ def __init__(self, grid: Grid3D):
57
+ shape = ConstantShapeFunction(grid.reference_cell(), space_dimension=3)
58
+ topology = Grid3DDiscontinuousSpaceTopology(grid, shape)
59
+ super().__init__(shape=shape, topology=topology)
60
+
61
+ if isinstance(grid, Grid3D):
62
+ self.node_grid = self._node_grid
63
+
64
+ def _node_grid(self):
133
65
  X = (np.arange(0, self.geometry.res[0], dtype=float) + 0.5) * self._grid.cell_size[0] + self._grid.bounds_lo[0]
134
66
  Y = (np.arange(0, self.geometry.res[1], dtype=float) + 0.5) * self._grid.cell_size[1] + self._grid.bounds_lo[1]
135
67
  Z = (np.arange(0, self.geometry.res[2], dtype=float) + 0.5) * self._grid.cell_size[2] + self._grid.bounds_lo[2]
136
68
  return np.meshgrid(X, Y, Z, indexing="ij")
137
69
 
138
- @wp.func
139
- def element_node_index(
140
- args: Grid3DFunctionSpace.SpaceArg,
141
- element_index: ElementIndex,
142
- node_index_in_elt: int,
143
- ):
144
- return element_index
145
-
146
- @wp.func
147
- def node_coords_in_element(
148
- args: Grid3DFunctionSpace.SpaceArg,
149
- element_index: ElementIndex,
150
- node_index_in_elt: int,
151
- ):
152
- if node_index_in_elt == 0:
153
- return Coords(0.5, 0.5, 0.5)
154
-
155
- return Coords(OUTSIDE)
156
-
157
- @wp.func
158
- def node_quadrature_weight(
159
- args: Grid3DFunctionSpace.SpaceArg,
160
- element_index: ElementIndex,
161
- node_index_in_elt: int,
162
- ):
163
- return 1.0
164
-
165
- @wp.func
166
- def element_inner_weight(
167
- args: Grid3DFunctionSpace.SpaceArg,
168
- element_index: ElementIndex,
169
- coords: Coords,
170
- node_index_in_elt: int,
171
- ):
172
- if node_index_in_elt == 0:
173
- return 1.0
174
- return 0.0
175
-
176
- @wp.func
177
- def element_inner_weight_gradient(
178
- args: Grid3DFunctionSpace.SpaceArg,
179
- element_index: ElementIndex,
180
- coords: Coords,
181
- node_index_in_elt: int,
182
- ):
183
- return wp.vec3(0.0)
184
-
185
- class Trace(Grid3DFunctionSpace.Trace):
186
- NODES_PER_ELEMENT = wp.constant(2)
187
- ORDER = wp.constant(0)
188
-
189
- def __init__(self, space: "Grid3DPiecewiseConstantSpace"):
190
- super().__init__(space)
191
-
192
- self.element_node_index = self._make_element_node_index(space)
193
-
194
- self.element_inner_weight = self._make_element_inner_weight(space)
195
- self.element_inner_weight_gradient = self._make_element_inner_weight_gradient(space)
196
-
197
- self.element_outer_weight = self._make_element_outer_weight(space)
198
- self.element_outer_weight_gradient = self._make_element_outer_weight_gradient(space)
199
-
70
+ class Trace(TraceBasisSpace):
200
71
  @wp.func
201
- def node_coords_in_element(
202
- args: Grid3DFunctionSpace.SpaceArg,
72
+ def _node_coords_in_element(
73
+ side_arg: Grid3D.SideArg,
74
+ basis_arg: Grid3DBasisSpace.BasisArg,
203
75
  element_index: ElementIndex,
204
76
  node_index_in_element: int,
205
77
  ):
206
- if node_index_in_element >= 0:
207
- return Coords(0.5, 0.5, 0.0)
208
- elif node_index_in_element == 1:
209
- return Coords(0.5, 0.5, 0.0)
78
+ return Coords(0.5, 0.5, 0.0)
210
79
 
211
- return Coords(OUTSIDE)
212
-
213
- @wp.func
214
- def node_quadrature_weight(
215
- args: Grid3DFunctionSpace.SpaceArg,
216
- element_index: ElementIndex,
217
- node_index_in_elt: int,
218
- ):
219
- return 1.0
80
+ def make_node_coords_in_element(self):
81
+ return self._node_coords_in_element
220
82
 
221
83
  def trace(self):
222
- return Grid3DPiecewiseConstantSpace.Trace(self)
84
+ return Grid3DPiecewiseConstantBasis.Trace(self)
223
85
 
224
86
 
225
- class GridTripolynomialShapeFunctions:
226
- def __init__(self, degree: int, family: Polynomial):
227
- self.family = family
87
+ class GridTripolynomialSpaceTopology(Grid3DSpaceTopology):
88
+ def __init__(self, grid: Grid3D, shape: CubeTripolynomialShapeFunctions):
89
+ super().__init__(grid, shape)
228
90
 
229
- self.ORDER = wp.constant(degree)
230
- self.NODES_PER_ELEMENT = wp.constant((degree + 1) ** 3)
231
- self.NODES_PER_SIDE = wp.constant(degree + 1)
232
-
233
- lobatto_coords, lobatto_weight = quadrature_1d(point_count=degree + 1, family=family)
234
- lagrange_scale = lagrange_scales(lobatto_coords)
235
-
236
- NodeVec = wp.types.vector(length=degree + 1, dtype=wp.float32)
237
- self.LOBATTO_COORDS = wp.constant(NodeVec(lobatto_coords))
238
- self.LOBATTO_WEIGHT = wp.constant(NodeVec(lobatto_weight))
239
- self.LAGRANGE_SCALE = wp.constant(NodeVec(lagrange_scale))
240
-
241
- self._node_ijk = self._make_node_ijk()
242
-
243
- @property
244
- def name(self) -> str:
245
- return f"{self.family}_{self.ORDER}"
246
-
247
- def _make_node_ijk(self):
248
- ORDER = self.ORDER
249
-
250
- def node_ijk(
251
- node_index_in_elt: int,
252
- ):
253
- node_i = node_index_in_elt // ((ORDER + 1) * (ORDER + 1))
254
- node_jk = node_index_in_elt - (ORDER + 1) * (ORDER + 1) * node_i
255
- node_j = node_jk // (ORDER + 1)
256
- node_k = node_jk - (ORDER + 1) * node_j
257
- return node_i, node_j, node_k
258
-
259
- from warp.fem import cache
260
-
261
- return cache.get_func(node_ijk, self.name)
262
-
263
- def make_node_coords_in_element(self):
264
- LOBATTO_COORDS = self.LOBATTO_COORDS
265
-
266
- def node_coords_in_element(
267
- args: Grid3DFunctionSpace.SpaceArg,
268
- element_index: ElementIndex,
269
- node_index_in_elt: int,
270
- ):
271
- node_i, node_j, node_k = self._node_ijk(node_index_in_elt)
272
- return Coords(LOBATTO_COORDS[node_i], LOBATTO_COORDS[node_j], LOBATTO_COORDS[node_k])
273
-
274
- from warp.fem import cache
275
-
276
- return cache.get_func(node_coords_in_element, self.name)
277
-
278
- def make_node_quadrature_weight(self):
279
- ORDER = self.ORDER
280
- LOBATTO_WEIGHT = self.LOBATTO_WEIGHT
281
-
282
- def node_quadrature_weight(
283
- args: Grid3DFunctionSpace.SpaceArg,
284
- element_index: ElementIndex,
285
- node_index_in_elt: int,
286
- ):
287
- node_i, node_j, node_k = self._node_ijk(node_index_in_elt)
288
- return LOBATTO_WEIGHT[node_i] * LOBATTO_WEIGHT[node_j] * LOBATTO_WEIGHT[node_k]
289
-
290
- def node_quadrature_weight_linear(
291
- args: Grid3DFunctionSpace.SpaceArg,
292
- element_index: ElementIndex,
293
- node_index_in_elt: int,
294
- ):
295
- return 0.125
296
-
297
- from warp.fem import cache
298
-
299
- if ORDER == 1:
300
- return cache.get_func(node_quadrature_weight_linear, self.name)
301
-
302
- return cache.get_func(node_quadrature_weight, self.name)
303
-
304
- def make_trace_node_quadrature_weight(self):
305
- ORDER = self.ORDER
306
- NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
307
- LOBATTO_WEIGHT = self.LOBATTO_WEIGHT
308
-
309
- def trace_node_quadrature_weight(
310
- args: Grid3DFunctionSpace.SpaceArg,
311
- element_index: ElementIndex,
312
- node_index_in_elt: int,
313
- ):
314
- if node_index_in_elt >= NODES_PER_ELEMENT:
315
- node_index_in_cell = node_index_in_elt - NODES_PER_ELEMENT
316
- else:
317
- node_index_in_cell = node_index_in_elt
318
-
319
- # We're either on a side interior or at a vertex
320
- # If we find one index at extremum, pick the two other
321
-
322
- node_i, node_j, node_k = self._node_ijk(node_index_in_cell)
323
-
324
- if node_i == 0 or node_i == ORDER:
325
- return LOBATTO_WEIGHT[node_j] * LOBATTO_WEIGHT[node_k]
326
-
327
- if node_j == 0 or node_j == ORDER:
328
- return LOBATTO_WEIGHT[node_i] * LOBATTO_WEIGHT[node_k]
329
-
330
- return LOBATTO_WEIGHT[node_i] * LOBATTO_WEIGHT[node_j]
331
-
332
- def trace_node_quadrature_weight_linear(
333
- args: Grid3DFunctionSpace.SpaceArg,
334
- element_index: ElementIndex,
335
- node_index_in_elt: int,
336
- ):
337
- return 0.25
338
-
339
- from warp.fem import cache
340
-
341
- if ORDER == 1:
342
- return cache.get_func(trace_node_quadrature_weight_linear, self.name)
343
-
344
- return cache.get_func(trace_node_quadrature_weight, self.name)
345
-
346
- def make_element_inner_weight(self):
347
- ORDER = self.ORDER
348
- NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
349
- LOBATTO_COORDS = self.LOBATTO_COORDS
350
- LAGRANGE_SCALE = self.LAGRANGE_SCALE
351
-
352
- def element_inner_weight(
353
- args: Grid3DFunctionSpace.SpaceArg,
354
- element_index: ElementIndex,
355
- coords: Coords,
356
- node_index_in_elt: int,
357
- ):
358
- if node_index_in_elt < 0 or node_index_in_elt >= NODES_PER_ELEMENT:
359
- return 0.0
360
-
361
- node_i, node_j, node_k = self._node_ijk(node_index_in_elt)
362
-
363
- w = float(1.0)
364
- for k in range(ORDER + 1):
365
- if k != node_i:
366
- w *= coords[0] - LOBATTO_COORDS[k]
367
- if k != node_j:
368
- w *= coords[1] - LOBATTO_COORDS[k]
369
- if k != node_k:
370
- w *= coords[2] - LOBATTO_COORDS[k]
371
-
372
- w *= LAGRANGE_SCALE[node_i] * LAGRANGE_SCALE[node_j] * LAGRANGE_SCALE[node_k]
373
-
374
- return w
375
-
376
- def element_inner_weight_linear(
377
- args: Grid3DFunctionSpace.SpaceArg,
378
- element_index: ElementIndex,
379
- coords: Coords,
380
- node_index_in_elt: int,
381
- ):
382
- if node_index_in_elt < 0 or node_index_in_elt >= 8:
383
- return 0.0
384
-
385
- v = Grid3DFunctionSpace._vertex_coords_f(node_index_in_elt)
386
-
387
- wx = (1.0 - coords[0]) * (1.0 - v[0]) + v[0] * coords[0]
388
- wy = (1.0 - coords[1]) * (1.0 - v[1]) + v[1] * coords[1]
389
- wz = (1.0 - coords[2]) * (1.0 - v[2]) + v[2] * coords[2]
390
- return wx * wy * wz
391
-
392
- from warp.fem import cache
393
-
394
- if ORDER == 1:
395
- return cache.get_func(element_inner_weight_linear, self.name)
91
+ self.element_node_index = self._make_element_node_index()
396
92
 
397
- return cache.get_func(element_inner_weight, self.name)
93
+ def node_count(self) -> int:
94
+ return (
95
+ (self.geometry.res[0] * self._shape.ORDER + 1)
96
+ * (self.geometry.res[1] * self._shape.ORDER + 1)
97
+ * (self.geometry.res[2] * self._shape.ORDER + 1)
98
+ )
398
99
 
399
- def make_element_inner_weight_gradient(self):
400
- ORDER = self.ORDER
401
- NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
402
- LOBATTO_COORDS = self.LOBATTO_COORDS
403
- LAGRANGE_SCALE = self.LAGRANGE_SCALE
100
+ def _make_element_node_index(self):
101
+ ORDER = self._shape.ORDER
404
102
 
405
- def element_inner_weight_gradient(
406
- args: Grid3DFunctionSpace.SpaceArg,
407
- element_index: ElementIndex,
408
- coords: Coords,
409
- node_index_in_elt: int,
410
- ):
411
- if node_index_in_elt < 0 or node_index_in_elt >= NODES_PER_ELEMENT:
412
- return wp.vec3(0.0)
413
-
414
- node_i, node_j, node_k = self._node_ijk(node_index_in_elt)
415
-
416
- prefix_xy = float(1.0)
417
- prefix_yz = float(1.0)
418
- prefix_zx = float(1.0)
419
- for k in range(ORDER + 1):
420
- if k != node_i:
421
- prefix_yz *= coords[0] - LOBATTO_COORDS[k]
422
- if k != node_j:
423
- prefix_zx *= coords[1] - LOBATTO_COORDS[k]
424
- if k != node_k:
425
- prefix_xy *= coords[2] - LOBATTO_COORDS[k]
426
-
427
- prefix_x = prefix_zx * prefix_xy
428
- prefix_y = prefix_yz * prefix_xy
429
- prefix_z = prefix_zx * prefix_yz
430
-
431
- grad_x = float(0.0)
432
- grad_y = float(0.0)
433
- grad_z = float(0.0)
434
-
435
- for k in range(ORDER + 1):
436
- if k != node_i:
437
- delta_x = coords[0] - LOBATTO_COORDS[k]
438
- grad_x = grad_x * delta_x + prefix_x
439
- prefix_x *= delta_x
440
- if k != node_j:
441
- delta_y = coords[1] - LOBATTO_COORDS[k]
442
- grad_y = grad_y * delta_y + prefix_y
443
- prefix_y *= delta_y
444
- if k != node_k:
445
- delta_z = coords[2] - LOBATTO_COORDS[k]
446
- grad_z = grad_z * delta_z + prefix_z
447
- prefix_z *= delta_z
448
-
449
- grad = (
450
- LAGRANGE_SCALE[node_i]
451
- * LAGRANGE_SCALE[node_j]
452
- * LAGRANGE_SCALE[node_k]
453
- * wp.vec3(
454
- grad_x * args.inv_cell_size[0],
455
- grad_y * args.inv_cell_size[1],
456
- grad_z * args.inv_cell_size[2],
457
- )
458
- )
459
-
460
- return grad
461
-
462
- def element_inner_weight_gradient_linear(
463
- args: Grid3DFunctionSpace.SpaceArg,
103
+ @cache.dynamic_func(suffix=self.name)
104
+ def element_node_index(
105
+ cell_arg: Grid3D.CellArg,
106
+ topo_arg: Grid3DSpaceTopology.TopologyArg,
464
107
  element_index: ElementIndex,
465
- coords: Coords,
466
108
  node_index_in_elt: int,
467
109
  ):
468
- if node_index_in_elt < 0 or node_index_in_elt >= 8:
469
- return wp.vec3(0.0)
470
-
471
- v = Grid3DFunctionSpace._vertex_coords_f(node_index_in_elt)
472
-
473
- wx = (1.0 - coords[0]) * (1.0 - v[0]) + v[0] * coords[0]
474
- wy = (1.0 - coords[1]) * (1.0 - v[1]) + v[1] * coords[1]
475
- wz = (1.0 - coords[2]) * (1.0 - v[2]) + v[2] * coords[2]
110
+ res = cell_arg.res
111
+ cell = Grid3D.get_cell(res, element_index)
476
112
 
477
- dx = (2.0 * v[0] - 1.0) * args.inv_cell_size[0]
478
- dy = (2.0 * v[1] - 1.0) * args.inv_cell_size[1]
479
- dz = (2.0 * v[2] - 1.0) * args.inv_cell_size[2]
113
+ node_i, node_j, node_k = self._shape._node_ijk(node_index_in_elt)
480
114
 
481
- return wp.vec3(dx * wy * wz, dy * wz * wx, dz * wx * wy)
115
+ node_x = ORDER * cell[0] + node_i
116
+ node_y = ORDER * cell[1] + node_j
117
+ node_z = ORDER * cell[2] + node_k
482
118
 
483
- from warp.fem import cache
119
+ node_pitch_y = (res[2] * ORDER) + 1
120
+ node_pitch_x = node_pitch_y * ((res[1] * ORDER) + 1)
121
+ node_index = node_pitch_x * node_x + node_pitch_y * node_y + node_z
484
122
 
485
- if ORDER == 1:
486
- return cache.get_func(element_inner_weight_gradient_linear, self.name)
123
+ return node_index
487
124
 
488
- return cache.get_func(element_inner_weight_gradient, self.name)
125
+ return element_node_index
489
126
 
490
127
 
491
- class GridTripolynomialSpace(Grid3DFunctionSpace):
128
+ class GridTripolynomialBasisSpace(Grid3DBasisSpace):
492
129
  def __init__(
493
130
  self,
494
131
  grid: Grid3D,
495
132
  degree: int,
496
- family: int,
497
- dtype: type = float,
498
- dof_mapper: DofMapper = None,
133
+ family: Polynomial,
499
134
  ):
500
- super().__init__(grid, dtype, dof_mapper)
501
-
502
135
  if family is None:
503
136
  family = Polynomial.LOBATTO_GAUSS_LEGENDRE
504
137
 
505
138
  if not is_closed(family):
506
- raise ValueError("A closed polynomial family is required to defined a continuous funciton space")
139
+ raise ValueError("A closed polynomial family is required to define a continuous function space")
507
140
 
508
- self._shape = GridTripolynomialShapeFunctions(degree, family=family)
141
+ shape = CubeTripolynomialShapeFunctions(degree, family=family)
142
+ topology = forward_base_topology(GridTripolynomialSpaceTopology, grid, shape)
509
143
 
510
- self.ORDER = self._shape.ORDER
511
- self.NODES_PER_ELEMENT = self._shape.NODES_PER_ELEMENT
144
+ super().__init__(topology, shape)
512
145
 
513
- self.element_node_index = self._make_element_node_index()
514
- self.node_coords_in_element = self._shape.make_node_coords_in_element()
515
- self.node_quadrature_weight = self._shape.make_node_quadrature_weight()
516
- self.element_inner_weight = self._shape.make_element_inner_weight()
517
- self.element_inner_weight_gradient = self._shape.make_element_inner_weight_gradient()
146
+ if isinstance(grid, Grid3D):
147
+ self.node_grid = self._node_grid
518
148
 
519
- self.element_outer_weight = self.element_inner_weight
520
- self.element_outer_weight_gradient = self.element_inner_weight_gradient
521
-
522
- def node_count(self) -> int:
523
- return (
524
- (self._grid.res[0] * self.ORDER + 1)
525
- * (self._grid.res[1] * self.ORDER + 1)
526
- * (self._grid.res[2] * self.ORDER + 1)
527
- )
528
-
529
- def node_positions(self):
149
+ def _node_grid(self):
530
150
  res = self._grid.res
531
151
 
532
152
  cell_coords = np.array(self._shape.LOBATTO_COORDS)[:-1]
@@ -551,83 +171,23 @@ class GridTripolynomialSpace(Grid3DFunctionSpace):
551
171
 
552
172
  return np.meshgrid(X, Y, Z, indexing="ij")
553
173
 
554
- def _make_element_node_index(self):
555
- ORDER = self.ORDER
556
-
557
- def element_node_index(
558
- args: Grid3DFunctionSpace.SpaceArg,
559
- element_index: ElementIndex,
560
- node_index_in_elt: int,
561
- ):
562
- res = args.geo_arg.cell_arg.res
563
- cell = Grid3D.get_cell(res, element_index)
564
-
565
- node_i, node_j, node_k = self._shape._node_ijk(node_index_in_elt)
566
-
567
- node_x = ORDER * cell[0] + node_i
568
- node_y = ORDER * cell[1] + node_j
569
- node_z = ORDER * cell[2] + node_k
570
-
571
- node_pitch_y = (res[2] * ORDER) + 1
572
- node_pitch_x = node_pitch_y * ((res[1] * ORDER) + 1)
573
- node_index = node_pitch_x * node_x + node_pitch_y * node_y + node_z
574
-
575
- return node_index
576
-
577
- from warp.fem import cache
578
-
579
- return cache.get_func(element_node_index, f"{self.name}_{ORDER}")
580
-
581
- class Trace(Grid3DFunctionSpace.Trace):
582
- def __init__(self, space: "GridTripolynomialSpace"):
583
- super().__init__(space)
584
-
585
- self.element_node_index = self._make_element_node_index(space)
586
- self.node_coords_in_element = self._make_node_coords_in_element(space)
587
- self.node_quadrature_weight = space._shape.make_trace_node_quadrature_weight()
588
174
 
589
- self.element_inner_weight = self._make_element_inner_weight(space)
590
- self.element_inner_weight_gradient = self._make_element_inner_weight_gradient(space)
591
-
592
- self.element_outer_weight = self._make_element_outer_weight(space)
593
- self.element_outer_weight_gradient = self._make_element_outer_weight_gradient(space)
594
-
595
- def trace(self):
596
- return GridTripolynomialSpace.Trace(self)
597
-
598
-
599
- class GridDGTripolynomialSpace(Grid3DFunctionSpace):
175
+ class GridDGTripolynomialBasisSpace(Grid3DBasisSpace):
600
176
  def __init__(
601
177
  self,
602
178
  grid: Grid3D,
603
179
  degree: int,
604
180
  family: Polynomial,
605
- dtype: type = float,
606
- dof_mapper: DofMapper = None,
607
181
  ):
608
- super().__init__(grid, dtype, dof_mapper)
609
-
610
182
  if family is None:
611
183
  family = Polynomial.LOBATTO_GAUSS_LEGENDRE
612
184
 
613
- self._shape = GridTripolynomialShapeFunctions(degree, family=family)
185
+ shape = CubeTripolynomialShapeFunctions(degree, family=family)
186
+ topology = Grid3DDiscontinuousSpaceTopology(grid, shape)
614
187
 
615
- self.ORDER = self._shape.ORDER
616
- self.NODES_PER_ELEMENT = self._shape.NODES_PER_ELEMENT
617
-
618
- self.element_node_index = self._make_element_node_index()
619
- self.node_coords_in_element = self._shape.make_node_coords_in_element()
620
- self.node_quadrature_weight = self._shape.make_node_quadrature_weight()
621
- self.element_inner_weight = self._shape.make_element_inner_weight()
622
- self.element_inner_weight_gradient = self._shape.make_element_inner_weight_gradient()
188
+ super().__init__(shape=shape, topology=topology)
623
189
 
624
- self.element_outer_weight = self.element_inner_weight
625
- self.element_outer_weight_gradient = self.element_inner_weight_gradient
626
-
627
- def node_count(self) -> int:
628
- return self._grid.cell_count() * (self.ORDER + 1) ** 3
629
-
630
- def node_positions(self):
190
+ def node_grid(self):
631
191
  res = self._grid.res
632
192
 
633
193
  cell_coords = np.array(self._shape.LOBATTO_COORDS)
@@ -649,23 +209,98 @@ class GridDGTripolynomialSpace(Grid3DFunctionSpace):
649
209
 
650
210
  return np.meshgrid(X, Y, Z, indexing="ij")
651
211
 
212
+
213
+ class Grid3DSerendipitySpaceTopology(Grid3DSpaceTopology):
214
+ def __init__(self, grid: Grid3D, shape: CubeSerendipityShapeFunctions):
215
+ super().__init__(grid, shape)
216
+
217
+ self.element_node_index = self._make_element_node_index()
218
+
219
+ def node_count(self) -> int:
220
+ return self.geometry.vertex_count() + (self._shape.ORDER - 1) * self.geometry.edge_count()
221
+
652
222
  def _make_element_node_index(self):
653
- NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
223
+ ORDER = self._shape.ORDER
654
224
 
225
+ @cache.dynamic_func(suffix=self.name)
655
226
  def element_node_index(
656
- args: Grid3DFunctionSpace.SpaceArg,
227
+ cell_arg: Grid3D.CellArg,
228
+ topo_arg: Grid3DSpaceTopology.TopologyArg,
657
229
  element_index: ElementIndex,
658
230
  node_index_in_elt: int,
659
231
  ):
660
- return element_index * NODES_PER_ELEMENT + node_index_in_elt
232
+ res = cell_arg.res
233
+ cell = Grid3D.get_cell(res, element_index)
661
234
 
662
- from warp.fem import cache
235
+ node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)
663
236
 
664
- return cache.get_func(element_node_index, f"{self.name}_{self.ORDER}")
237
+ if node_type == CubeSerendipityShapeFunctions.VERTEX:
238
+ return Grid3DSpaceTopology._vertex_index(cell_arg, element_index, type_index)
665
239
 
666
- class Trace(GridTripolynomialSpace.Trace):
667
- def __init__(self, space: "GridDGTripolynomialSpace"):
668
- super().__init__(space)
240
+ axis = CubeSerendipityShapeFunctions._edge_axis(node_type)
241
+ node_all = CubeSerendipityShapeFunctions._edge_coords(type_index)
669
242
 
670
- def trace(self):
671
- return GridDGTripolynomialSpace.Trace(self)
243
+ res = cell_arg.res
244
+
245
+ edge_index = 0
246
+ if axis > 0:
247
+ edge_index += (res[1] + 1) * (res[2] + 1) * res[0]
248
+ if axis > 1:
249
+ edge_index += (res[0] + 1) * (res[2] + 1) * res[1]
250
+
251
+ res_loc = Grid3D._world_to_local(axis, res)
252
+ cell_loc = Grid3D._world_to_local(axis, cell)
253
+
254
+ edge_index += (res_loc[1] + 1) * (res_loc[2] + 1) * cell_loc[0]
255
+ edge_index += (res_loc[2] + 1) * (cell_loc[1] + node_all[1])
256
+ edge_index += cell_loc[2] + node_all[2]
257
+
258
+ vertex_count = (res[0] + 1) * (res[1] + 1) * (res[2] + 1)
259
+
260
+ return vertex_count + (ORDER - 1) * edge_index + (node_all[0] - 1)
261
+
262
+ return element_node_index
263
+
264
+
265
+ class Grid3DSerendipityBasisSpace(Grid3DBasisSpace):
266
+ def __init__(
267
+ self,
268
+ grid: Grid3D,
269
+ degree: int,
270
+ family: Polynomial,
271
+ ):
272
+ if family is None:
273
+ family = Polynomial.LOBATTO_GAUSS_LEGENDRE
274
+
275
+ shape = CubeSerendipityShapeFunctions(degree, family=family)
276
+ topology = forward_base_topology(Grid3DSerendipitySpaceTopology, grid, shape=shape)
277
+
278
+ super().__init__(topology=topology, shape=shape)
279
+
280
+
281
+ class Grid3DDGSerendipityBasisSpace(Grid3DBasisSpace):
282
+ def __init__(
283
+ self,
284
+ grid: Grid3D,
285
+ degree: int,
286
+ family: Polynomial,
287
+ ):
288
+ if family is None:
289
+ family = Polynomial.LOBATTO_GAUSS_LEGENDRE
290
+
291
+ shape = CubeSerendipityShapeFunctions(degree, family=family)
292
+ topology = Grid3DDiscontinuousSpaceTopology(grid, shape=shape)
293
+
294
+ super().__init__(topology=topology, shape=shape)
295
+
296
+
297
+ class Grid3DDGPolynomialBasisSpace(Grid3DBasisSpace):
298
+ def __init__(
299
+ self,
300
+ grid: Grid3D,
301
+ degree: int,
302
+ ):
303
+ shape = CubeNonConformingPolynomialShapeFunctions(degree)
304
+ topology = Grid3DDiscontinuousSpaceTopology(grid, shape=shape)
305
+
306
+ super().__init__(topology=topology, shape=shape)