fluxfem 0.1.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. fluxfem/__init__.py +343 -0
  2. fluxfem/core/__init__.py +316 -0
  3. fluxfem/core/assembly.py +788 -0
  4. fluxfem/core/basis.py +996 -0
  5. fluxfem/core/data.py +64 -0
  6. fluxfem/core/dtypes.py +4 -0
  7. fluxfem/core/forms.py +234 -0
  8. fluxfem/core/interp.py +55 -0
  9. fluxfem/core/solver.py +113 -0
  10. fluxfem/core/space.py +419 -0
  11. fluxfem/core/weakform.py +818 -0
  12. fluxfem/helpers_num.py +11 -0
  13. fluxfem/helpers_wf.py +42 -0
  14. fluxfem/mesh/__init__.py +29 -0
  15. fluxfem/mesh/base.py +244 -0
  16. fluxfem/mesh/hex.py +327 -0
  17. fluxfem/mesh/io.py +87 -0
  18. fluxfem/mesh/predicate.py +45 -0
  19. fluxfem/mesh/surface.py +257 -0
  20. fluxfem/mesh/tet.py +246 -0
  21. fluxfem/physics/__init__.py +53 -0
  22. fluxfem/physics/diffusion.py +18 -0
  23. fluxfem/physics/elasticity/__init__.py +39 -0
  24. fluxfem/physics/elasticity/hyperelastic.py +99 -0
  25. fluxfem/physics/elasticity/linear.py +58 -0
  26. fluxfem/physics/elasticity/materials.py +32 -0
  27. fluxfem/physics/elasticity/stress.py +46 -0
  28. fluxfem/physics/operators.py +109 -0
  29. fluxfem/physics/postprocess.py +113 -0
  30. fluxfem/solver/__init__.py +47 -0
  31. fluxfem/solver/bc.py +439 -0
  32. fluxfem/solver/cg.py +326 -0
  33. fluxfem/solver/dirichlet.py +126 -0
  34. fluxfem/solver/history.py +31 -0
  35. fluxfem/solver/newton.py +400 -0
  36. fluxfem/solver/result.py +62 -0
  37. fluxfem/solver/solve_runner.py +534 -0
  38. fluxfem/solver/solver.py +148 -0
  39. fluxfem/solver/sparse.py +188 -0
  40. fluxfem/tools/__init__.py +7 -0
  41. fluxfem/tools/jit.py +51 -0
  42. fluxfem/tools/timer.py +659 -0
  43. fluxfem/tools/visualizer.py +101 -0
  44. fluxfem-0.1.1a0.dist-info/METADATA +111 -0
  45. fluxfem-0.1.1a0.dist-info/RECORD +47 -0
  46. fluxfem-0.1.1a0.dist-info/WHEEL +4 -0
  47. fluxfem-0.1.1a0.dist-info/licenses/LICENSE +201 -0
fluxfem/helpers_num.py ADDED
@@ -0,0 +1,11 @@
1
+ """Numeric helpers (array operators)."""
2
+ from __future__ import annotations
3
+
4
+ from .physics.operators import dot, ddot, sym_grad, transpose_last2
5
+
6
# Public re-exports: array-level operators provided by ``physics.operators``.
__all__ = ["dot", "ddot", "sym_grad", "transpose_last2"]
fluxfem/helpers_wf.py ADDED
@@ -0,0 +1,42 @@
1
+ """WeakForm/Expr helpers (symbolic operators)."""
2
+ from __future__ import annotations
3
+
4
+ from .core.weakform import (
5
+ grad,
6
+ sym_grad,
7
+ dot,
8
+ sdot,
9
+ ddot,
10
+ inner,
11
+ action,
12
+ gaction,
13
+ I,
14
+ det,
15
+ inv,
16
+ transpose,
17
+ transpose_last2,
18
+ log,
19
+ normal,
20
+ ds,
21
+ dOmega,
22
+ )
23
+
24
# Public re-exports: symbolic WeakForm/Expr helpers provided by ``core.weakform``.
# Order matches the import list above.
__all__ = [
    "grad", "sym_grad", "dot", "sdot", "ddot", "inner",
    "action", "gaction", "I", "det", "inv", "transpose",
    "transpose_last2", "log", "normal", "ds", "dOmega",
]
fluxfem/mesh/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ from .hex import HexMesh, HexMeshPytree, StructuredHexBox, tag_axis_minmax_facets
2
+ from .tet import TetMesh, TetMeshPytree, StructuredTetBox, StructuredTetTensorBox
3
+ from .base import BaseMesh, BaseMeshPytree
4
+ from .predicate import bbox_predicate, plane_predicate, axis_plane_predicate, slab_predicate
5
+ from .surface import SurfaceMesh, SurfaceMeshPytree
6
+ from .io import load_gmsh_mesh, load_gmsh_hex_mesh, load_gmsh_tet_mesh, make_surface_from_facets
7
+
8
# Public mesh API, grouped by the submodule each name is imported from.
__all__ = [
    # .base
    "BaseMesh", "BaseMeshPytree",
    # .predicate
    "bbox_predicate", "plane_predicate", "axis_plane_predicate", "slab_predicate",
    # .hex
    "HexMesh", "HexMeshPytree", "StructuredHexBox", "tag_axis_minmax_facets",
    # .tet
    "TetMesh", "TetMeshPytree", "StructuredTetBox", "StructuredTetTensorBox",
    # .surface
    "SurfaceMesh", "SurfaceMeshPytree",
    # .io
    "load_gmsh_mesh", "load_gmsh_hex_mesh", "load_gmsh_tet_mesh", "make_surface_from_facets",
]
fluxfem/mesh/base.py ADDED
@@ -0,0 +1,244 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import Callable, Iterable, Optional, Sequence
4
+ import numpy as np
5
+ import jax
6
+ import jax.numpy as jnp
7
+
8
+
9
@dataclass
class BaseMeshClosure:
    """
    Generic mesh container: node coordinates plus element connectivity.

    Concrete mesh types subclass this and implement ``face_node_patterns``; the
    remaining methods are NumPy-side helpers for boundary detection, node/DOF
    selection and tagging.

    Fields:
        coords: node coordinates; row ``n`` holds the coordinates of node ``n``.
        conn: element connectivity; row ``e`` holds the node ids of element ``e``.
        cell_tags: optional per-element tags (stored, not interpreted here).
        node_tags: optional per-node tags (see ``make_node_tags`` / ``with_node_tags``).
    """

    coords: jnp.ndarray
    conn: jnp.ndarray
    cell_tags: Optional[jnp.ndarray] = None
    node_tags: Optional[jnp.ndarray] = None

    @property
    def n_nodes(self) -> int:
        """Number of nodes (first dimension of ``coords``)."""
        return self.coords.shape[0]

    @property
    def n_elems(self) -> int:
        """Number of elements (first dimension of ``conn``)."""
        return self.conn.shape[0]

    def element_coords(self) -> jnp.ndarray:
        """Coordinates gathered per element, i.e. ``coords[conn]``."""
        return self.coords[self.conn]

    # ------------------------------------------------------------------
    # Face patterns must be provided by concrete mesh types.
    def face_node_patterns(self):
        """
        Return a list of tuples, each giving the local node indices of one face.

        Must be overridden by concrete mesh classes (HexMesh, TetMesh, ...).

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError("face_node_patterns must be implemented by mesh subtype")

    # ------------------------------------------------------------------
    # Convenience helpers for boundary tagging / DOF lookup
    def node_indices_where(self, predicate: Callable[[np.ndarray], np.ndarray]) -> np.ndarray:
        """
        Return the indices of nodes whose coordinates satisfy ``predicate``.

        predicate: callable receiving all coordinates as a NumPy array of shape
        (n_nodes, dim) and returning a per-node mask (any truthy dtype).
        """
        coords_np = np.asarray(self.coords)
        mask = np.asarray(predicate(coords_np), dtype=bool)
        return np.nonzero(mask)[0]

    def node_indices_where_point(self, predicate: Callable[[np.ndarray], bool]) -> np.ndarray:
        """
        Return the indices of nodes for which ``predicate(coord)`` is truthy.

        predicate: callable accepting a single point of shape (dim,).
        """
        coords_np = np.asarray(self.coords)
        mask = np.array([bool(predicate(pt)) for pt in coords_np], dtype=bool)
        return np.nonzero(mask)[0]

    def axis_extrema_nodes(self, axis: int = 0, side: str = "min", tol: float = 1e-8) -> np.ndarray:
        """
        Return nodes lying on the minimum or maximum coordinate along ``axis``.

        side: "min" or "max".
        tol: absolute tolerance passed to ``np.isclose`` (its default rtol also applies).

        Raises:
            ValueError: if ``side`` is neither "min" nor "max".
        """
        if side not in ("min", "max"):
            raise ValueError(f"side must be 'min' or 'max', got {side!r}")
        coords_np = np.asarray(self.coords)
        vals = coords_np[:, axis]
        target = vals.min() if side == "min" else vals.max()
        return np.nonzero(np.isclose(vals, target, atol=tol))[0]

    def boundary_nodes_bbox(self, tol: float = 1e-8) -> np.ndarray:
        """
        Return nodes on the axis-aligned bounding box (min/max in any coordinate).

        Only meaningful for box-shaped meshes such as StructuredHexBox output.
        """
        coords_np = np.asarray(self.coords)
        mins = coords_np.min(axis=0)
        maxs = coords_np.max(axis=0)
        on_bbox_face = np.isclose(coords_np, mins, atol=tol) | np.isclose(coords_np, maxs, atol=tol)
        return np.nonzero(on_bbox_face.any(axis=1))[0]

    def node_dofs(self, nodes: Iterable[int], components: Sequence[int] | str = "xyz", dof_per_node: Optional[int] = None) -> np.ndarray:
        """
        Build flattened DOF indices ``dof_per_node * node + component`` for the given nodes.

        components:
            - sequence of component indices (e.g., [0, 1, 2])
            - or a string like "x", "xy", "xyz" (case-insensitive; maps x/y/z -> 0/1/2)
        dof_per_node: optional; inferred as max component index + 1 if not provided.

        Returns DOFs in node-major order (all requested components of the first node,
        then the next node, ...). An empty ``components`` yields an empty array.

        Raises:
            ValueError: if a component index is negative or >= dof_per_node.
            KeyError: if a string component is not one of x/y/z.
        """
        nodes_arr = np.asarray(list(nodes), dtype=int).reshape(-1)
        if isinstance(components, str):
            comp_map = {"x": 0, "y": 1, "z": 2}
            comps = np.asarray([comp_map[c.lower()] for c in components], dtype=int)
        else:
            comps = np.asarray(list(components), dtype=int).reshape(-1)
        if comps.size == 0:
            # Nothing requested; also avoids max() on an empty array.
            return np.zeros(0, dtype=int)
        if comps.min() < 0:
            raise ValueError(f"component indices must be non-negative, got {comps.min()}")
        dofpn = int(comps.max()) + 1 if dof_per_node is None else int(dof_per_node)
        if dofpn <= comps.max():
            raise ValueError(f"dof_per_node={dofpn} is inconsistent with requested component {comps.max()}")
        # Outer product of nodes x components, flattened row-major => node-major order.
        return (dofpn * nodes_arr[:, None] + comps[None, :]).reshape(-1)

    def dofs_where(self, predicate: Callable[[np.ndarray], np.ndarray], components: Sequence[int] | str = "xyz", dof_per_node: Optional[int] = None) -> np.ndarray:
        """
        Return DOF indices for nodes selected by a predicate over all coords.

        predicate takes coords (np.ndarray, shape (n_nodes, dim)) and returns a boolean mask.
        """
        nodes = self.node_indices_where(predicate)
        return self.node_dofs(nodes, components=components, dof_per_node=dof_per_node)

    def dofs_where_point(self, predicate: Callable[[np.ndarray], bool], components: Sequence[int] | str = "xyz", dof_per_node: Optional[int] = None) -> np.ndarray:
        """
        Return DOF indices for nodes selected by a per-point predicate.

        predicate takes a single coord of shape (dim,) and returns a bool.
        """
        nodes = self.node_indices_where_point(predicate)
        return self.node_dofs(nodes, components=components, dof_per_node=dof_per_node)

    def boundary_dofs_where(self, predicate: Callable[[np.ndarray], np.ndarray], components: Sequence[int] | str = "xyz", dof_per_node: Optional[int] = None) -> np.ndarray:
        """
        Return DOF indices for boundary nodes whose coordinates satisfy ``predicate``.

        predicate takes coords (np.ndarray, shape (n_nodes, dim)) and returns a boolean mask.
        Boundary nodes come from ``boundary_node_mask`` (face adjacency).
        """
        coords_np = np.asarray(self.coords)
        mask = np.asarray(predicate(coords_np), dtype=bool)
        nodes = np.nonzero(mask & self.boundary_node_mask())[0]
        return self.node_dofs(nodes, components=components, dof_per_node=dof_per_node)

    def boundary_dofs_bbox(
        self,
        *,
        components: Sequence[int] | str = "xyz",
        dof_per_node: Optional[int] = None,
        tol: float = 1e-8,
    ) -> np.ndarray:
        """Return DOF indices of nodes on the axis-aligned bounding box (min/max in each coordinate)."""
        nodes = self.boundary_nodes_bbox(tol=tol)
        return self.node_dofs(nodes, components=components, dof_per_node=dof_per_node)

    def boundary_node_indices(self) -> np.ndarray:
        """
        Return sorted node indices on the boundary, based on element face adjacency.

        A face (identified by its sorted node ids) that occurs in exactly one element
        is a boundary face; every node of such a face is a boundary node. Only the
        nodes listed in ``face_node_patterns`` are considered.

        The result is cached on the instance; the cache is invalidated when ``conn``
        is reassigned. A copy is returned so callers cannot corrupt the cache.
        """
        from collections import Counter

        cached = getattr(self, "_boundary_nodes_cache", None)
        if cached is not None and getattr(self, "_boundary_nodes_cache_conn", None) is self.conn:
            return cached.copy()

        conn = np.asarray(self.conn)
        patterns = self.face_node_patterns()
        face_counts = Counter(
            tuple(sorted(int(elem_conn[i]) for i in pattern))
            for elem_conn in conn
            for pattern in patterns
        )
        bnodes = {node for face, count in face_counts.items() if count == 1 for node in face}
        out = np.asarray(sorted(bnodes), dtype=int)
        setattr(self, "_boundary_nodes_cache", out)
        setattr(self, "_boundary_nodes_cache_conn", self.conn)
        return out.copy()

    def boundary_node_mask(self) -> np.ndarray:
        """Return a boolean mask of shape (n_nodes,) that is True on boundary nodes."""
        mask = np.zeros(self.n_nodes, dtype=bool)
        mask[self.boundary_node_indices()] = True
        return mask

    def make_node_tags(self, predicate: Callable[[np.ndarray], np.ndarray], tag: int, base: Optional[np.ndarray] = None) -> jnp.ndarray:
        """
        Build a node_tags array: ``tag`` where ``predicate(coords)`` is truthy, else ``base``.

        The predicate result is converted to a boolean mask, so 0/1 integer masks
        select nodes rather than acting as fancy indices.
        Returns a jnp.ndarray (int32). Does not mutate the mesh or ``base``.
        """
        if base is None:
            tags = np.zeros(self.n_nodes, dtype=np.int32)
        else:
            tags = np.array(base, dtype=np.int32)  # np.array copies: never write into the caller's array
        mask = np.asarray(predicate(np.asarray(self.coords)), dtype=bool)
        tags[mask] = int(tag)
        return jnp.asarray(tags, dtype=jnp.int32)

    def with_node_tags(self, node_tags: np.ndarray | jnp.ndarray):
        """
        Return a new mesh instance of the same class with ``node_tags`` replaced.

        All other dataclass fields (including any added by subclasses) are carried
        over unchanged; ``node_tags`` is converted with ``jnp.asarray``.
        """
        import dataclasses

        return dataclasses.replace(self, node_tags=jnp.asarray(node_tags))

    def boundary_facets_where(self, predicate: Callable[[np.ndarray], bool], tag: int | None = None):
        """
        Collect element facets whose node coordinates satisfy ``predicate``.

        predicate receives the (n_face_nodes, dim) NumPy coordinates of one facet
        and returns True/False. Facets shared by several elements are reported once,
        with the node order of the first element that produced them.
        NOTE(review): despite the name, interior facets are not filtered out;
        only the predicate decides. Confirm whether boundary-only is intended.

        Returns:
            facets: (n_facets, n_face_nodes) int32, or
            (facets, tags) when ``tag`` is given, with tags an (n_facets,) int32 array filled with ``tag``.
        """
        coords = np.asarray(self.coords)
        conn = np.asarray(self.conn)
        patterns = self.face_node_patterns()
        pattern_idx = [list(pattern) for pattern in patterns]
        n_face_nodes = len(patterns[0]) if patterns else 0

        # Keyed by sorted node ids so a shared facet is kept only once.
        facet_map: dict[tuple[int, ...], list[int]] = {}
        for elem_conn in conn:
            elem_nodes = coords[elem_conn]
            for idx in pattern_idx:
                if not predicate(elem_nodes[idx]):
                    continue
                nodes = [int(elem_conn[i]) for i in idx]
                facet_map.setdefault(tuple(sorted(nodes)), nodes)

        if facet_map:
            facets_arr = jnp.array(list(facet_map.values()), dtype=jnp.int32)
        else:
            facets_arr = jnp.empty((0, n_face_nodes), dtype=jnp.int32)
        if tag is None:
            return facets_arr
        return facets_arr, jnp.full((facets_arr.shape[0],), int(tag), dtype=jnp.int32)


@jax.tree_util.register_pytree_node_class
class BaseMeshPytree(BaseMeshClosure):
    """
    BaseMeshClosure registered as a JAX pytree.

    Children are (coords, conn, cell_tags, node_tags); there is no static metadata.
    """

    def tree_flatten(self):
        """Split into pytree children and (hashable) auxiliary data."""
        children = (self.coords, self.conn, self.cell_tags, self.node_tags)
        # JAX stores aux data in the treedef and expects it to be hashable;
        # a dict is not, so use None since there is nothing static to carry.
        return children, None

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild an instance from the children produced by ``tree_flatten``."""
        coords, conn, cell_tags, node_tags = children
        return cls(coords, conn, cell_tags, node_tags)


BaseMesh = BaseMeshClosure
244
+
fluxfem/mesh/hex.py ADDED
@@ -0,0 +1,327 @@
1
+
2
+
3
+ from __future__ import annotations
4
+ from dataclasses import dataclass
5
+ from typing import Optional, Dict, Tuple, List, Callable
6
+ import jax
7
+ import jax.numpy as jnp
8
+ from ..core.dtypes import DEFAULT_DTYPE
9
+ import numpy as np
10
+
11
+ from .base import BaseMesh, BaseMeshPytree
12
+
13
+
14
@dataclass
class HexMesh(BaseMesh):
    """
    Hexahedral mesh stored as node coordinates plus element connectivity.

    coords: (n_nodes, 3) node coordinates
    conn:   (n_elems, n_nodes_per_elem) int32 node ids; 8 for linear hexes
            (StructuredHexBox also emits 20- and 27-node connectivity)
    """

    coords: jnp.ndarray  # (n_nodes, 3)
    conn: jnp.ndarray  # (n_elems, 8) for linear hexes, int32

    def face_node_patterns(self):
        """
        Return the local node indices of the six quadrilateral faces.

        Order: -z, +z, -y, +y, -x, +x, using the 8-node corner numbering.
        NOTE(review): StructuredHexBox(order=3) lists its 27 nodes in lexicographic
        (k, j, i) order, for which these indices do not select faces. Confirm how
        27-node meshes are meant to be handled.
        """
        minus_z, plus_z = (0, 1, 2, 3), (4, 5, 6, 7)
        minus_y, plus_y = (0, 1, 5, 4), (3, 2, 6, 7)
        minus_x, plus_x = (0, 4, 7, 3), (1, 2, 6, 5)
        return [minus_z, plus_z, minus_y, plus_y, minus_x, plus_x]
34
+
35
+
36
@jax.tree_util.register_pytree_node_class
@dataclass(eq=False)
class HexMeshPytree(BaseMeshPytree):
    """Hex mesh registered as a JAX pytree (flatten/unflatten inherited from BaseMeshPytree)."""

    def face_node_patterns(self):
        """Return the six quad faces (-z, +z, -y, +y, -x, +x) in 8-node corner numbering."""
        z_faces = [(0, 1, 2, 3), (4, 5, 6, 7)]  # -z, +z
        y_faces = [(0, 1, 5, 4), (3, 2, 6, 7)]  # -y, +y
        x_faces = [(0, 4, 7, 3), (1, 2, 6, 5)]  # -x, +x
        return z_faces + y_faces + x_faces
48
+
49
+
50
def tag_axis_minmax_facets(
    mesh: HexMesh,
    axis: int = 0,
    dirichlet_tag: int = 1,
    neumann_tag: int = 2,
    tol: float = 1e-8,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Tag element facets lying on the min / max plane of one coordinate axis.

    A facet is selected when all of its nodes have ``coords[:, axis]`` close to the
    global minimum (tagged ``dirichlet_tag``) or the global maximum (tagged
    ``neumann_tag``), tested with ``np.allclose(..., atol=tol)``; NumPy's default
    ``rtol`` also applies. A facet shared by two elements is reported once, with the
    node order of the first element that produced it.

    Face patterns are taken from ``mesh.face_node_patterns()`` when available. If the
    mesh has no such method, or it raises ``NotImplementedError``, the 8-node hex
    corner patterns (-z, +z, -y, +y, -x, +x) are used.

    Returns:
        facets: (n_facets, n_face_nodes) int32 node ids (quads, i.e. 4 nodes, for hex meshes)
        facet_tags: (n_facets,) int32, dirichlet_tag on the min side, neumann_tag on the max side
    """
    hex8_face_patterns: List[Tuple[int, ...]] = [
        (0, 1, 2, 3),  # -z
        (4, 5, 6, 7),  # +z
        (0, 1, 5, 4),  # -y
        (3, 2, 6, 7),  # +y
        (0, 4, 7, 3),  # -x
        (1, 2, 6, 5),  # +x
    ]
    face_patterns = hex8_face_patterns
    get_patterns = getattr(mesh, "face_node_patterns", None)
    if callable(get_patterns):
        try:
            face_patterns = [tuple(p) for p in get_patterns()]
        except NotImplementedError:
            pass  # generic base mesh: keep the hex8 default

    coords = np.asarray(mesh.coords)
    conn = np.asarray(mesh.conn)
    axis_vals = coords[:, axis]
    v_min = float(axis_vals.min())
    v_max = float(axis_vals.max())
    pattern_idx = [list(p) for p in face_patterns]

    facet_map: Dict[Tuple[int, ...], Tuple[List[int], int]] = {}
    for elem_conn in conn:
        elem_vals = axis_vals[elem_conn]
        for idx in pattern_idx:
            vals = elem_vals[idx]
            if np.allclose(vals, v_min, atol=tol):
                tag = dirichlet_tag
            elif np.allclose(vals, v_max, atol=tol):
                tag = neumann_tag
            else:
                continue
            nodes = [int(elem_conn[i]) for i in idx]
            facet_map.setdefault(tuple(sorted(nodes)), (nodes, tag))

    if not facet_map:
        n_face_nodes = len(face_patterns[0]) if face_patterns else 4
        return jnp.empty((0, n_face_nodes), dtype=jnp.int32), jnp.empty((0,), dtype=jnp.int32)

    facets = [nodes for nodes, _ in facet_map.values()]
    tags = [tag for _, tag in facet_map.values()]
    return jnp.array(facets, dtype=jnp.int32), jnp.array(tags, dtype=jnp.int32)
110
+
111
+
112
@dataclass
class StructuredHexBox:
    """
    Uniform hex mesh generator on an axis-aligned box, returned as an unstructured HexMesh.

    Fields:
        nx, ny, nz: number of elements along x, y and z (must be positive).
        lx, ly, lz: box edge lengths along x, y and z.
        origin: lower corner (x, y, z) of the box.
        order: 1 -> 8-node hex, 2 -> 20-node serendipity hex, 3 -> 27-node triquadratic hex.
    """

    nx: int
    ny: int
    nz: int
    lx: float = 1.0
    ly: float = 1.0
    lz: float = 1.0
    origin: tuple[float, float, float] = (0.0, 0.0, 0.0)
    order: int = 1  # 1: 8-node Hex, 2: 20-node serendipity Hex, 3: 27-node triquadratic Hex

    def build(self) -> HexMesh:
        """
        Build a regular grid of nx x ny x nz elements over [origin, origin + (lx, ly, lz)].

        order=1 -> 8-node hex, order=2 -> 20-node serendipity hex,
        order=3 -> 27-node triquadratic hex.

        Raises:
            ValueError: if nx, ny or nz is not positive, or order is not 1, 2 or 3.
        """
        if self.nx <= 0 or self.ny <= 0 or self.nz <= 0:
            raise ValueError("nx, ny, nz must be positive")
        if self.order not in (1, 2, 3):
            raise ValueError("order must be 1, 2, or 3")

        ox, oy, oz = self.origin
        # Grid values come from jnp.linspace in DEFAULT_DTYPE. They are moved to NumPy
        # once, because the builders below read individual entries in Python loops,
        # and every element access on a JAX array is a separately dispatched operation.
        xs = np.asarray(jnp.linspace(ox, ox + self.lx, self.nx + 1, dtype=DEFAULT_DTYPE))
        ys = np.asarray(jnp.linspace(oy, oy + self.ly, self.ny + 1, dtype=DEFAULT_DTYPE))
        zs = np.asarray(jnp.linspace(oz, oz + self.lz, self.nz + 1, dtype=DEFAULT_DTYPE))

        if self.order == 1:
            return self._build_hex8(xs, ys, zs)
        if self.order == 2:
            return self._build_hex20(xs, ys, zs)
        return self._build_hex27(xs, ys, zs)

    def _build_hex8(self, xs, ys, zs) -> HexMesh:
        """Wrap ``_hex8_arrays`` into an 8-node HexMesh."""
        coords, conn = self._hex8_arrays(xs, ys, zs)
        return HexMesh(
            coords=jnp.asarray(coords, dtype=DEFAULT_DTYPE),
            conn=jnp.asarray(conn, dtype=jnp.int32),
        )

    def _build_hex20(self, xs, ys, zs) -> HexMesh:
        """Build a 20-node serendipity Hex mesh (corners + edge midpoints)."""
        coords_list, conn_list = self._hex20_arrays(xs, ys, zs)
        return HexMesh(
            coords=jnp.array(coords_list, dtype=DEFAULT_DTYPE),
            conn=jnp.array(conn_list, dtype=jnp.int32),
        )

    def _build_hex27(self, xs, ys, zs) -> HexMesh:
        """Build a 27-node triquadratic Hex mesh (corners, edge mids, face centers, body center)."""
        coords_list, conn_list = self._hex27_arrays(xs, ys, zs)
        return HexMesh(
            coords=jnp.array(coords_list, dtype=DEFAULT_DTYPE),
            conn=jnp.array(conn_list, dtype=jnp.int32),
        )

    @classmethod
    def _hex8_arrays(cls, xs, ys, zs) -> Tuple[np.ndarray, np.ndarray]:
        """
        Return NumPy (coords, conn) of an 8-node hex grid built from 1-D grid coordinates.

        Node id = k * (ny+1) * (nx+1) + j * (nx+1) + i (x fastest). Elements are
        numbered in the same k, j, i order and list their corners at (i, j, k)
        offsets 000, 100, 110, 010, 001, 101, 111, 011.
        """
        xs, ys, zs = np.asarray(xs), np.asarray(ys), np.asarray(zs)
        nx, ny, nz = len(xs) - 1, len(ys) - 1, len(zs) - 1

        zz, yy, xx = np.meshgrid(zs, ys, xs, indexing="ij")  # each (nz+1, ny+1, nx+1)
        coords = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1)
        ids = np.arange(coords.shape[0]).reshape(nz + 1, ny + 1, nx + 1)

        def corner(di: int, dj: int, dk: int) -> np.ndarray:
            # Node id of the (di, dj, dk) corner of every element, in element order.
            return ids[dk:dk + nz, dj:dj + ny, di:di + nx].ravel()

        conn = np.stack(
            [
                corner(0, 0, 0), corner(1, 0, 0), corner(1, 1, 0), corner(0, 1, 0),
                corner(0, 0, 1), corner(1, 0, 1), corner(1, 1, 1), corner(0, 1, 1),
            ],
            axis=1,
        )
        return coords, conn

    @staticmethod
    def _node_registry() -> Tuple[List[List[float]], Callable[[np.ndarray], int]]:
        """
        Create an empty node list plus an ``add_node(pt) -> id`` function.

        Points are deduplicated on their float64 coordinates rounded to 8 decimals,
        so a node shared by neighbouring elements gets a single id. New nodes are
        numbered in insertion order.
        """
        coords_list: List[List[float]] = []
        node_map: Dict[Tuple[float, ...], int] = {}

        def add_node(pt: np.ndarray) -> int:
            key = tuple(np.round(pt.astype(np.float64), 8))
            node_id = node_map.get(key)
            if node_id is None:
                node_id = node_map[key] = len(coords_list)
                coords_list.append([float(pt[0]), float(pt[1]), float(pt[2])])
            return node_id

        return coords_list, add_node

    @classmethod
    def _hex20_arrays(cls, xs, ys, zs) -> Tuple[List[List[float]], List[List[int]]]:
        """
        Return (coords_list, conn_list) of a 20-node serendipity hex grid.

        Per element: 8 corners in hex8 order, then the 12 edge midpoints
        01, 12, 23, 30, 45, 56, 67, 74, 04, 15, 26, 37 (pairs of local corners).
        """
        xs, ys, zs = np.asarray(xs), np.asarray(ys), np.asarray(zs)
        nx, ny, nz = len(xs) - 1, len(ys) - 1, len(zs) - 1
        coords_list, add_node = cls._node_registry()

        def point(i: int, j: int, k: int) -> np.ndarray:
            return np.array([xs[i], ys[j], zs[k]], dtype=np.float64)

        # Corner nodes first, x fastest, so their ids match the 8-node numbering.
        for k in range(nz + 1):
            for j in range(ny + 1):
                for i in range(nx + 1):
                    add_node(point(i, j, k))

        edge_pairs = (
            (0, 1), (1, 2), (2, 3), (3, 0),
            (4, 5), (5, 6), (6, 7), (7, 4),
            (0, 4), (1, 5), (2, 6), (3, 7),
        )
        conn_list: List[List[int]] = []
        for k in range(nz):
            for j in range(ny):
                for i in range(nx):
                    corners = [
                        point(i, j, k), point(i + 1, j, k), point(i + 1, j + 1, k), point(i, j + 1, k),
                        point(i, j, k + 1), point(i + 1, j, k + 1), point(i + 1, j + 1, k + 1), point(i, j + 1, k + 1),
                    ]
                    element = [add_node(p) for p in corners]
                    element += [add_node(0.5 * (corners[a] + corners[b])) for a, b in edge_pairs]
                    conn_list.append(element)
        return coords_list, conn_list

    @classmethod
    def _hex27_arrays(cls, xs, ys, zs) -> Tuple[List[List[float]], List[List[int]]]:
        """
        Return (coords_list, conn_list) of a 27-node triquadratic hex grid.

        Global numbering follows creation order: vertices, x-/y-/z-edge midpoints,
        face centers (normal to z, y, x), then cell centers. Each element lists its
        27 nodes in lexicographic (k, j, i) order over {low, mid, high} per axis.
        """
        xs, ys, zs = np.asarray(xs), np.asarray(ys), np.asarray(zs)
        # Midpoints evaluated as 0.5 * (a + b) in the grid dtype.
        xs_mid = 0.5 * (xs[:-1] + xs[1:])
        ys_mid = 0.5 * (ys[:-1] + ys[1:])
        zs_mid = 0.5 * (zs[:-1] + zs[1:])
        coords_list, add_node = cls._node_registry()

        def add(x, y, z) -> int:
            return add_node(np.array([x, y, z], dtype=np.float64))

        # Pre-create every node; the loop orders below fix the global numbering.
        for z in zs:  # vertices
            for y in ys:
                for x in xs:
                    add(x, y, z)
        for z in zs:  # midpoints of x-directed edges
            for y in ys:
                for x in xs_mid:
                    add(x, y, z)
        for z in zs:  # midpoints of y-directed edges
            for x in xs:
                for y in ys_mid:
                    add(x, y, z)
        for y in ys:  # midpoints of z-directed edges
            for x in xs:
                for z in zs_mid:
                    add(x, y, z)
        for z in zs:  # centers of faces normal to z
            for y in ys_mid:
                for x in xs_mid:
                    add(x, y, z)
        for z in zs_mid:  # centers of faces normal to y
            for y in ys:
                for x in xs_mid:
                    add(x, y, z)
        for z in zs_mid:  # centers of faces normal to x
            for y in ys_mid:
                for x in xs:
                    add(x, y, z)
        for z in zs_mid:  # cell centers
            for y in ys_mid:
                for x in xs_mid:
                    add(x, y, z)

        conn_list: List[List[int]] = []
        for k in range(len(zs) - 1):
            for j in range(len(ys) - 1):
                for i in range(len(xs) - 1):
                    x_vals = (xs[i], xs_mid[i], xs[i + 1])
                    y_vals = (ys[j], ys_mid[j], ys[j + 1])
                    z_vals = (zs[k], zs_mid[k], zs[k + 1])
                    # Lexicographic k, j, i order -> 27 entries (all lookups of existing nodes).
                    conn_list.append([add(x, y, z) for z in z_vals for y in y_vals for x in x_vals])
        return coords_list, conn_list
+ return HexMesh(coords=coords, conn=conn)