fluxfem 0.1.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +343 -0
- fluxfem/core/__init__.py +316 -0
- fluxfem/core/assembly.py +788 -0
- fluxfem/core/basis.py +996 -0
- fluxfem/core/data.py +64 -0
- fluxfem/core/dtypes.py +4 -0
- fluxfem/core/forms.py +234 -0
- fluxfem/core/interp.py +55 -0
- fluxfem/core/solver.py +113 -0
- fluxfem/core/space.py +419 -0
- fluxfem/core/weakform.py +818 -0
- fluxfem/helpers_num.py +11 -0
- fluxfem/helpers_wf.py +42 -0
- fluxfem/mesh/__init__.py +29 -0
- fluxfem/mesh/base.py +244 -0
- fluxfem/mesh/hex.py +327 -0
- fluxfem/mesh/io.py +87 -0
- fluxfem/mesh/predicate.py +45 -0
- fluxfem/mesh/surface.py +257 -0
- fluxfem/mesh/tet.py +246 -0
- fluxfem/physics/__init__.py +53 -0
- fluxfem/physics/diffusion.py +18 -0
- fluxfem/physics/elasticity/__init__.py +39 -0
- fluxfem/physics/elasticity/hyperelastic.py +99 -0
- fluxfem/physics/elasticity/linear.py +58 -0
- fluxfem/physics/elasticity/materials.py +32 -0
- fluxfem/physics/elasticity/stress.py +46 -0
- fluxfem/physics/operators.py +109 -0
- fluxfem/physics/postprocess.py +113 -0
- fluxfem/solver/__init__.py +47 -0
- fluxfem/solver/bc.py +439 -0
- fluxfem/solver/cg.py +326 -0
- fluxfem/solver/dirichlet.py +126 -0
- fluxfem/solver/history.py +31 -0
- fluxfem/solver/newton.py +400 -0
- fluxfem/solver/result.py +62 -0
- fluxfem/solver/solve_runner.py +534 -0
- fluxfem/solver/solver.py +148 -0
- fluxfem/solver/sparse.py +188 -0
- fluxfem/tools/__init__.py +7 -0
- fluxfem/tools/jit.py +51 -0
- fluxfem/tools/timer.py +659 -0
- fluxfem/tools/visualizer.py +101 -0
- fluxfem-0.1.1a0.dist-info/METADATA +111 -0
- fluxfem-0.1.1a0.dist-info/RECORD +47 -0
- fluxfem-0.1.1a0.dist-info/WHEEL +4 -0
- fluxfem-0.1.1a0.dist-info/licenses/LICENSE +201 -0
fluxfem/helpers_num.py
ADDED
fluxfem/helpers_wf.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""WeakForm/Expr helpers (symbolic operators)."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from .core.weakform import (
|
|
5
|
+
grad,
|
|
6
|
+
sym_grad,
|
|
7
|
+
dot,
|
|
8
|
+
sdot,
|
|
9
|
+
ddot,
|
|
10
|
+
inner,
|
|
11
|
+
action,
|
|
12
|
+
gaction,
|
|
13
|
+
I,
|
|
14
|
+
det,
|
|
15
|
+
inv,
|
|
16
|
+
transpose,
|
|
17
|
+
transpose_last2,
|
|
18
|
+
log,
|
|
19
|
+
normal,
|
|
20
|
+
ds,
|
|
21
|
+
dOmega,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
# Public API of this helper module: the symbolic weak-form operators and
# measures re-exported from core.weakform (see the import above).
__all__ = [
    "grad",
    "sym_grad",
    "dot",
    "sdot",
    "ddot",
    "inner",
    "action",
    "gaction",
    "I",
    "det",
    "inv",
    "transpose",
    "transpose_last2",
    "log",
    "normal",
    "ds",
    "dOmega",
]
|
fluxfem/mesh/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from .hex import HexMesh, HexMeshPytree, StructuredHexBox, tag_axis_minmax_facets
|
|
2
|
+
from .tet import TetMesh, TetMeshPytree, StructuredTetBox, StructuredTetTensorBox
|
|
3
|
+
from .base import BaseMesh, BaseMeshPytree
|
|
4
|
+
from .predicate import bbox_predicate, plane_predicate, axis_plane_predicate, slab_predicate
|
|
5
|
+
from .surface import SurfaceMesh, SurfaceMeshPytree
|
|
6
|
+
from .io import load_gmsh_mesh, load_gmsh_hex_mesh, load_gmsh_tet_mesh, make_surface_from_facets
|
|
7
|
+
|
|
8
|
+
# Public API of the mesh subpackage: base/concrete mesh classes and their
# pytree variants, structured box generators, coordinate predicates, and
# gmsh loading / surface helpers imported above.
__all__ = [
    "BaseMesh",
    "BaseMeshPytree",
    "bbox_predicate",
    "plane_predicate",
    "axis_plane_predicate",
    "slab_predicate",
    "HexMesh",
    "HexMeshPytree",
    "StructuredHexBox",
    "tag_axis_minmax_facets",
    "TetMesh",
    "TetMeshPytree",
    "StructuredTetBox",
    "StructuredTetTensorBox",
    "SurfaceMesh",
    "SurfaceMeshPytree",
    "load_gmsh_mesh",
    "load_gmsh_hex_mesh",
    "load_gmsh_tet_mesh",
    "make_surface_from_facets",
]
|
fluxfem/mesh/base.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Callable, Iterable, Optional, Sequence
|
|
4
|
+
import numpy as np
|
|
5
|
+
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class BaseMeshClosure:
    """
    Plain mesh container: nodal coordinates plus element connectivity.

    Fields
    ------
    coords : (n_nodes, dim) nodal coordinates.
    conn : (n_elems, nodes_per_elem) node indices of each element.
    cell_tags : optional cell tags (not read by the helpers in this class).
    node_tags : optional per-node tags (see make_node_tags / with_node_tags).

    The selection and tagging helpers convert ``coords``/``conn`` to NumPy, so
    they are host-side setup utilities and cannot operate on traced JAX arrays.
    Concrete mesh types must implement ``face_node_patterns``.
    """

    coords: jnp.ndarray
    conn: jnp.ndarray
    cell_tags: Optional[jnp.ndarray] = None
    node_tags: Optional[jnp.ndarray] = None

    @property
    def n_nodes(self) -> int:
        """Number of nodes (rows of ``coords``)."""
        return self.coords.shape[0]

    @property
    def n_elems(self) -> int:
        """Number of elements (rows of ``conn``)."""
        return self.conn.shape[0]

    def element_coords(self) -> jnp.ndarray:
        """Gather nodal coordinates per element: ``coords[conn]``."""
        return self.coords[self.conn]

    # ------------------------------------------------------------------
    # Face patterns must be provided by concrete mesh types.
    def face_node_patterns(self):
        """
        Return a list of tuples, each tuple giving local node indices of a face.
        Override in concrete mesh classes (HexMesh, TetMesh, etc).

        Raises
        ------
        NotImplementedError
            Always, on the base class.
        """
        raise NotImplementedError("face_node_patterns must be implemented by mesh subtype")

    # Convenience helpers for boundary tagging / DOF lookup
    def node_indices_where(self, predicate: Callable[[np.ndarray], np.ndarray]) -> np.ndarray:
        """
        Return node indices whose coordinates satisfy the predicate.

        predicate: callable that takes coords (np.ndarray of shape (n_nodes, dim))
        and returns a mask over the nodes; non-zero entries are selected.
        """
        coords_np = np.asarray(self.coords)
        mask = predicate(coords_np)
        return np.nonzero(mask)[0]

    def node_indices_where_point(self, predicate: Callable[[np.ndarray], bool]) -> np.ndarray:
        """
        Return node indices for which predicate(coord) is True.

        predicate: callable accepting a single point (dim,) -> bool. It is called
        once per node in Python, so prefer node_indices_where for large meshes.
        """
        coords_np = np.asarray(self.coords)
        mask = [bool(predicate(pt)) for pt in coords_np]
        return np.nonzero(mask)[0]

    def axis_extrema_nodes(self, axis: int = 0, side: str = "min", tol: float = 1e-8) -> np.ndarray:
        """
        Nodes lying on the min or max of a given coordinate axis.

        side: "min" selects the minimum; any other value selects the maximum.
        tol: absolute tolerance passed to np.isclose (its default rtol also applies).
        """
        coords_np = np.asarray(self.coords)
        vals = coords_np[:, axis]
        target = vals.min() if side == "min" else vals.max()
        mask = np.isclose(vals, target, atol=tol)
        return np.nonzero(mask)[0]

    def boundary_nodes_bbox(self, tol: float = 1e-8) -> np.ndarray:
        """
        Nodes on the axis-aligned bounding box (min/max in each coordinate).

        Useful for box-shaped meshes like StructuredHexBox; for other shapes this
        is not the same as the topological boundary (see boundary_node_indices).
        """
        coords_np = np.asarray(self.coords)
        mins = coords_np.min(axis=0)
        maxs = coords_np.max(axis=0)
        mask = np.zeros(coords_np.shape[0], dtype=bool)
        for axis in range(coords_np.shape[1]):
            mask |= np.isclose(coords_np[:, axis], mins[axis], atol=tol)
            mask |= np.isclose(coords_np[:, axis], maxs[axis], atol=tol)
        return np.nonzero(mask)[0]

    def node_dofs(self, nodes: Iterable[int], components: Sequence[int] | str = "xyz", dof_per_node: Optional[int] = None) -> np.ndarray:
        """
        Build flattened DOF indices for given node ids.

        DOF index = dof_per_node * node + component, listed node-major (all
        requested components of the first node, then the next node, ...).

        components:
          - sequence of component indices (e.g., [0,1,2])
          - or string like "x", "xy", "xyz" (case-insensitive; maps x/y/z -> 0/1/2)
          An empty selection yields an empty result.
        dof_per_node: optional; inferred from max component index + 1 if not provided.

        Raises
        ------
        KeyError
            If a component letter is not x, y or z.
        ValueError
            If a component index is negative or not smaller than dof_per_node.
        """
        nodes_arr = np.asarray(list(nodes), dtype=int).reshape(-1)
        if isinstance(components, str):
            comp_map = {"x": 0, "y": 1, "z": 2}
            comps = np.asarray([comp_map[c.lower()] for c in components], dtype=int)
        else:
            comps = np.asarray(list(components), dtype=int)
        inferred = int(comps.max()) + 1 if comps.size else 1
        dofpn = inferred if dof_per_node is None else int(dof_per_node)
        # Guard on comps.size: max()/min() of an empty array raise.
        if comps.size:
            if comps.min() < 0:
                raise ValueError(f"component indices must be non-negative, got {comps.min()}")
            if dofpn <= comps.max():
                raise ValueError(f"dof_per_node={dofpn} is inconsistent with requested component {comps.max()}")
        # (n_nodes, n_comps) table flattened row-major -> node-major ordering.
        dofs = dofpn * nodes_arr[:, None] + comps[None, :]
        return dofs.reshape(-1)

    def dofs_where(self, predicate: Callable[[np.ndarray], np.ndarray], components: Sequence[int] | str = "xyz", dof_per_node: Optional[int] = None) -> np.ndarray:
        """
        DOF indices for nodes selected by a predicate over all coords.
        predicate takes coords (np.ndarray, shape (n_nodes, dim)) and returns boolean mask.
        """
        nodes = self.node_indices_where(predicate)
        return self.node_dofs(nodes, components=components, dof_per_node=dof_per_node)

    def dofs_where_point(self, predicate: Callable[[np.ndarray], bool], components: Sequence[int] | str = "xyz", dof_per_node: Optional[int] = None) -> np.ndarray:
        """
        DOF indices for nodes selected by a per-point predicate.
        predicate takes a single coord (dim,) -> bool.
        """
        nodes = self.node_indices_where_point(predicate)
        return self.node_dofs(nodes, components=components, dof_per_node=dof_per_node)

    def boundary_dofs_where(self, predicate: Callable[[np.ndarray], np.ndarray], components: Sequence[int] | str = "xyz", dof_per_node: Optional[int] = None) -> np.ndarray:
        """
        Return DOF indices for boundary nodes whose coordinates satisfy predicate.

        predicate takes coords (np.ndarray, shape (n_nodes, dim)) and returns
        boolean mask. Boundary nodes come from boundary_node_mask (face adjacency).
        """
        coords_np = np.asarray(self.coords)
        mask = np.asarray(predicate(coords_np), dtype=bool)
        bmask = self.boundary_node_mask()
        nodes = np.nonzero(mask & bmask)[0]
        return self.node_dofs(nodes, components=components, dof_per_node=dof_per_node)

    def boundary_dofs_bbox(
        self,
        *,
        components: Sequence[int] | str = "xyz",
        dof_per_node: Optional[int] = None,
        tol: float = 1e-8,
    ) -> np.ndarray:
        """
        DOF indices on the axis-aligned bounding box (min/max in each coordinate).
        """
        nodes = self.boundary_nodes_bbox(tol=tol)
        return self.node_dofs(nodes, components=components, dof_per_node=dof_per_node)

    def boundary_node_indices(self) -> np.ndarray:
        """
        Return sorted node indices on the boundary based on element face adjacency.

        A face (identified by its sorted node ids) that belongs to exactly one
        element is a boundary face; its nodes are boundary nodes. The result is
        cached on the instance on first call (the cache is not invalidated if
        ``conn`` is later reassigned); a fresh copy is returned each time so
        callers cannot corrupt the cache.
        """
        cached = getattr(self, "_boundary_nodes_cache", None)
        if cached is not None:
            return cached.copy()
        conn = np.asarray(self.conn)
        patterns = self.face_node_patterns()
        face_counts: dict[tuple[int, ...], int] = {}
        for elem_conn in conn:
            for pattern in patterns:
                nodes = tuple(sorted(int(elem_conn[i]) for i in pattern))
                face_counts[nodes] = face_counts.get(nodes, 0) + 1
        bnodes = set()
        for nodes, count in face_counts.items():
            if count == 1:
                bnodes.update(nodes)
        out = np.asarray(sorted(bnodes), dtype=int)
        setattr(self, "_boundary_nodes_cache", out)
        return out.copy()

    def boundary_node_mask(self) -> np.ndarray:
        """
        Return boolean mask for boundary nodes (shape: n_nodes).
        """
        mask = np.zeros(self.n_nodes, dtype=bool)
        nodes = self.boundary_node_indices()
        mask[nodes] = True
        return mask

    def make_node_tags(self, predicate: Callable[[np.ndarray], np.ndarray], tag: int, base: Optional[np.ndarray] = None) -> jnp.ndarray:
        """
        Build a node_tags array by applying predicate to coords and setting tag where True.

        predicate takes coords (np.ndarray, shape (n_nodes, dim)) and returns a
        per-node mask; it is coerced to bool. base (optional) supplies the
        starting tags and is copied, not modified.
        Returns a jnp.ndarray (int32). Does not mutate the mesh.
        """
        base_tags = np.zeros(self.n_nodes, dtype=np.int32) if base is None else np.asarray(base, dtype=np.int32).copy()
        # Coerce to bool so an integer 0/1 result acts as a mask rather than as
        # fancy indices (which would tag nodes 0 and 1 instead).
        mask = np.asarray(predicate(np.asarray(self.coords)), dtype=bool)
        base_tags[mask] = int(tag)
        return jnp.asarray(base_tags, dtype=jnp.int32)

    def with_node_tags(self, node_tags: np.ndarray | jnp.ndarray):
        """
        Return a new mesh instance (same class) with provided node_tags.
        coords, conn and cell_tags are shared with this instance.
        """
        return self.__class__(coords=self.coords, conn=self.conn, cell_tags=self.cell_tags, node_tags=jnp.asarray(node_tags))

    def boundary_facets_where(self, predicate: Callable[[np.ndarray], bool], tag: int | None = None):
        """
        Collect boundary facets whose node coordinates satisfy predicate.

        predicate receives a (n_face_nodes, dim) NumPy array and returns True/False.
        Only faces that belong to exactly one element (the same boundary
        criterion as boundary_node_indices) are returned; interior faces are
        skipped even if they satisfy the predicate. Each facet keeps the node
        order of its face pattern in the element that contains it.

        Returns
        -------
        facets : (n_facets, n_face_nodes) int32 array, or
        (facets, tags) when tag is given, with tags a (n_facets,) int32 array
        filled with ``tag``.
        """
        coords = np.asarray(self.coords)
        conn = np.asarray(self.conn)
        patterns = self.face_node_patterns()
        face_width = len(patterns[0]) if patterns else 0

        # Count every face occurrence (keyed by sorted node ids) so faces shared
        # by two elements, i.e. interior faces, can be dropped afterwards.
        face_counts: dict[tuple[int, ...], int] = {}
        selected: dict[tuple[int, ...], list[int]] = {}
        for elem_conn in conn:
            elem_nodes = coords[elem_conn]
            for pattern in patterns:
                nodes = [int(elem_conn[i]) for i in pattern]
                key = tuple(sorted(nodes))
                face_counts[key] = face_counts.get(key, 0) + 1
                if key in selected:
                    continue
                if predicate(elem_nodes[list(pattern)]):
                    selected[key] = nodes

        facets = [nodes for key, nodes in selected.items() if face_counts[key] == 1]

        if facets:
            facets_arr = jnp.array(facets, dtype=jnp.int32)
        else:
            facets_arr = jnp.empty((0, face_width), dtype=jnp.int32)
        if tag is None:
            return facets_arr
        return facets_arr, jnp.full((len(facets),), tag, dtype=jnp.int32)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
@jax.tree_util.register_pytree_node_class
class BaseMeshPytree(BaseMeshClosure):
    """
    BaseMeshClosure registered as a JAX pytree.

    coords, conn, cell_tags and node_tags are the pytree children (None
    children are treated by JAX as empty subtrees). There is no static
    auxiliary data.
    """

    def tree_flatten(self):
        """Return (children, aux_data) for JAX; aux_data is None."""
        children = (self.coords, self.conn, self.cell_tags, self.node_tags)
        # JAX's pytree contract requires aux_data to be hashable; a dict is not,
        # so None is used to signal "no static data".
        return children, None

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild an instance of ``cls`` (or subclass) from the flattened children."""
        coords, conn, cell_tags, node_tags = children
        return cls(coords, conn, cell_tags, node_tags)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
# Public alias for the plain (non-pytree) mesh base class; concrete meshes such
# as HexMesh subclass it under this name.
BaseMesh = BaseMeshClosure
|
|
244
|
+
|
fluxfem/mesh/hex.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
1
|
+
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Optional, Dict, Tuple, List, Callable
|
|
6
|
+
import jax
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
from ..core.dtypes import DEFAULT_DTYPE
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from .base import BaseMesh, BaseMeshPytree
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class HexMesh(BaseMesh):
    """
    Structured / unstructured hexahedral mesh.

    coords: (n_nodes, 3) nodal coordinates (StructuredHexBox builds them with DEFAULT_DTYPE)
    conn: (n_elems, nodes_per_elem) int32 node indices of each element.
        nodes_per_elem is 8 for linear hexes; StructuredHexBox also produces
        20-node (order=2) and 27-node (order=3) meshes with this class.

    face_node_patterns indexes local nodes 0-7, i.e. it assumes the first
    eight entries of each conn row are the corners in the order
    (x0y0z0, x1y0z0, x1y1z0, x0y1z0, x0y0z1, x1y0z1, x1y1z1, x0y1z1),
    which holds for the 8- and 20-node StructuredHexBox layouts.
    NOTE(review): the 27-node StructuredHexBox layout is lexicographic
    (k, j, i), so its corners are local nodes 0, 2, 8, 6, 18, 20, 26, 24 and
    these patterns do not describe its faces. Confirm before using the
    face-based boundary helpers on such meshes.
    """

    coords: jnp.ndarray  # shape (n_nodes, 3)
    conn: jnp.ndarray  # shape (n_elems, nodes_per_elem), int32

    def face_node_patterns(self):
        """
        Return the six quadrilateral faces as tuples of local corner indices.

        Order: -z, +z, -y, +y, -x, +x.
        """
        return [
            (0, 1, 2, 3),  # -z
            (4, 5, 6, 7),  # +z
            (0, 1, 5, 4),  # -y
            (3, 2, 6, 7),  # +y
            (0, 4, 7, 3),  # -x
            (1, 2, 6, 5),  # +x
        ]
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@jax.tree_util.register_pytree_node_class
@dataclass(eq=False)
class HexMeshPytree(BaseMeshPytree):
    """Hexahedral mesh registered as a JAX pytree (flatten/unflatten inherited)."""

    def face_node_patterns(self):
        """
        Local corner indices of the six quadrilateral faces.

        Listed in the order -z, +z, -y, +y, -x, +x, assuming corners ordered
        (x0y0z0, x1y0z0, x1y1z0, x0y1z0, x0y0z1, x1y0z1, x1y1z1, x0y1z1).
        """
        faces_by_side = {
            "-z": (0, 1, 2, 3),
            "+z": (4, 5, 6, 7),
            "-y": (0, 1, 5, 4),
            "+y": (3, 2, 6, 7),
            "-x": (0, 4, 7, 3),
            "+x": (1, 2, 6, 5),
        }
        return list(faces_by_side.values())
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def tag_axis_minmax_facets(
    mesh: HexMesh,
    axis: int = 0,
    dirichlet_tag: int = 1,
    neumann_tag: int = 2,
    tol: float = 1e-8,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Find hex faces lying on the lowest and highest plane along ``axis``.

    A face qualifies when the ``axis`` coordinates of all four of its nodes are
    close (np.allclose with atol=tol) to the axis minimum, which tags it with
    ``dirichlet_tag``, or else to the axis maximum, which tags it with
    ``neumann_tag``. A face shared by several elements is reported once, with
    the node order of the first element that contains it.

    Returns:
        facets: (n_facets, 4) int32, quad node ids
        facet_tags: (n_facets,) int32, one tag per facet
    """
    node_xyz = np.asarray(mesh.coords)
    elements = np.asarray(mesh.conn)

    along_axis = node_xyz[:, axis]
    lo = float(along_axis.min())
    hi = float(along_axis.max())

    quads = (
        (0, 1, 2, 3),  # -z
        (4, 5, 6, 7),  # +z
        (0, 1, 5, 4),  # -y
        (3, 2, 6, 7),  # +y
        (0, 4, 7, 3),  # -x
        (1, 2, 6, 5),  # +x
    )

    def side_tag(values):
        # Minimum side wins when a face would match both (degenerate extent).
        if np.allclose(values, lo, atol=tol):
            return dirichlet_tag
        if np.allclose(values, hi, atol=tol):
            return neumann_tag
        return None

    found: Dict[Tuple[int, ...], Tuple[List[int], int]] = {}
    for element in elements:
        for quad in quads:
            label = side_tag(node_xyz[element[list(quad)], axis])
            if label is None:
                continue
            ids = [int(element[q]) for q in quad]
            found.setdefault(tuple(sorted(ids)), (ids, label))

    if not found:
        return jnp.empty((0, 4), dtype=jnp.int32), jnp.empty((0,), dtype=jnp.int32)

    quad_ids = [ids for ids, _ in found.values()]
    quad_tags = [label for _, label in found.values()]
    return jnp.array(quad_ids, dtype=jnp.int32), jnp.array(quad_tags, dtype=jnp.int32)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@dataclass
class StructuredHexBox:
    """
    Uniform hex mesh generator on a box, returned as an unstructured HexMesh.

    The box spans [origin, origin + (lx, ly, lz)] and is split into
    nx * ny * nz equal elements. ``order`` selects the element type:
    1 -> 8-node hex, 2 -> 20-node serendipity hex, 3 -> 27-node triquadratic hex.

    The (nx+1)*(ny+1)*(nz+1) grid vertices always come first, in (k, j, i)
    order with i fastest, followed by any higher-order nodes. Nodes are
    identified by exact integer indices on a half-step grid (not by rounded
    coordinates), so arbitrarily small elements never merge distinct nodes.
    """

    nx: int
    ny: int
    nz: int
    lx: float = 1.0
    ly: float = 1.0
    lz: float = 1.0
    origin: tuple[float, float, float] = (0.0, 0.0, 0.0)
    order: int = 1  # 1: 8-node Hex, 2: 20-node serendipity Hex, 3: 27-node triquadratic Hex

    def build(self) -> HexMesh:
        """
        Build a regular grid of nx×ny×nz elements over [origin, origin + (lx, ly, lz)].

        order=1 → 8-node Hex, order=2 → 20-node serendipity Hex,
        order=3 → 27-node triquadratic Hex.

        Raises
        ------
        ValueError
            If nx, ny or nz is not positive, or order is not 1, 2 or 3.
        """
        if self.nx <= 0 or self.ny <= 0 or self.nz <= 0:
            raise ValueError("nx, ny, nz must be positive")
        if self.order not in (1, 2, 3):
            raise ValueError("order must be 1, 2, or 3")

        ox, oy, oz = self.origin
        xs = jnp.linspace(ox, ox + self.lx, self.nx + 1, dtype=DEFAULT_DTYPE)
        ys = jnp.linspace(oy, oy + self.ly, self.ny + 1, dtype=DEFAULT_DTYPE)
        zs = jnp.linspace(oz, oz + self.lz, self.nz + 1, dtype=DEFAULT_DTYPE)

        if self.order == 1:
            return self._build_hex8(xs, ys, zs)
        if self.order == 2:
            return self._build_hex20(xs, ys, zs)
        return self._build_hex27(xs, ys, zs)

    @staticmethod
    def _half_grid(vals) -> np.ndarray:
        """
        Refine a 1D vertex array to half steps in float64.

        Entry 2*i is vals[i]; entry 2*i + 1 is 0.5 * (vals[i] + vals[i + 1]).
        """
        v = np.asarray(vals, dtype=np.float64)
        out = np.empty(2 * v.size - 1, dtype=np.float64)
        out[0::2] = v
        out[1::2] = 0.5 * (v[:-1] + v[1:])
        return out

    @staticmethod
    def _node_registry(x2, y2, z2):
        """
        Return ``(add, coords_list)`` for half-grid node numbering.

        ``add(ix, iy, iz)`` returns the id of the node at half-grid index
        (ix, iy, iz), appending its coordinates to ``coords_list`` on first use,
        so ids follow first-use order.
        """
        node_map: Dict[Tuple[int, int, int], int] = {}
        coords_list: List[List[float]] = []

        def add(ix: int, iy: int, iz: int) -> int:
            key = (ix, iy, iz)
            node = node_map.get(key)
            if node is None:
                node = len(coords_list)
                node_map[key] = node
                coords_list.append([float(x2[ix]), float(y2[iy]), float(z2[iz])])
            return node

        return add, coords_list

    def _build_hex8(self, xs, ys, zs) -> HexMesh:
        """
        Build the 8-node mesh (vectorized over the structured grid).

        Node id = k*(ny+1)*(nx+1) + j*(nx+1) + i; element corners are ordered
        (n000, n100, n110, n010, n001, n101, n111, n011).
        """
        x = np.asarray(xs)
        y = np.asarray(ys)
        z = np.asarray(zs)
        # Arrays indexed [k, j, i]; C-order ravel gives the node id formula above.
        zz, yy, xx = np.meshgrid(z, y, x, indexing="ij")
        coords = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1)

        ids = np.arange(x.size * y.size * z.size).reshape(z.size, y.size, x.size)
        lo, hi = slice(None, -1), slice(1, None)
        conn = np.stack(
            [
                ids[lo, lo, lo],  # n000
                ids[lo, lo, hi],  # n100
                ids[lo, hi, hi],  # n110
                ids[lo, hi, lo],  # n010
                ids[hi, lo, lo],  # n001
                ids[hi, lo, hi],  # n101
                ids[hi, hi, hi],  # n111
                ids[hi, hi, lo],  # n011
            ],
            axis=-1,
        ).reshape(-1, 8)

        return HexMesh(
            coords=jnp.asarray(coords, dtype=DEFAULT_DTYPE),
            conn=jnp.asarray(conn, dtype=jnp.int32),
        )

    def _build_hex20(self, xs, ys, zs) -> HexMesh:
        """
        Build 20-node serendipity Hex mesh (corner + edge midpoints).

        Per element: the 8 corners (same order and ids as the 8-node mesh), then
        edge midpoints 0-1, 1-2, 2-3, 3-0, 4-5, 5-6, 6-7, 7-4, 0-4, 1-5, 2-6, 3-7.
        Edge nodes are numbered in order of first use.
        """
        add, coords_list = self._node_registry(
            self._half_grid(xs), self._half_grid(ys), self._half_grid(zs)
        )

        # pre-create corner nodes so they keep the 8-node numbering
        for k in range(self.nz + 1):
            for j in range(self.ny + 1):
                for i in range(self.nx + 1):
                    add(2 * i, 2 * j, 2 * k)

        conn_list = []
        for k in range(self.nz):
            for j in range(self.ny):
                for i in range(self.nx):
                    a, b, c = 2 * i, 2 * j, 2 * k  # half-grid index of corner 000
                    conn_list.append(
                        [
                            # corners 0-3 (bottom) and 4-7 (top)
                            add(a, b, c), add(a + 2, b, c), add(a + 2, b + 2, c), add(a, b + 2, c),
                            add(a, b, c + 2), add(a + 2, b, c + 2), add(a + 2, b + 2, c + 2), add(a, b + 2, c + 2),
                            # bottom edge mids: 01, 12, 23, 30
                            add(a + 1, b, c), add(a + 2, b + 1, c), add(a + 1, b + 2, c), add(a, b + 1, c),
                            # top edge mids: 45, 56, 67, 74
                            add(a + 1, b, c + 2), add(a + 2, b + 1, c + 2), add(a + 1, b + 2, c + 2), add(a, b + 1, c + 2),
                            # vertical edge mids: 04, 15, 26, 37
                            add(a, b, c + 1), add(a + 2, b, c + 1), add(a + 2, b + 2, c + 1), add(a, b + 2, c + 1),
                        ]
                    )

        coords = jnp.array(coords_list, dtype=DEFAULT_DTYPE)
        conn = jnp.array(conn_list, dtype=jnp.int32)
        return HexMesh(coords=coords, conn=conn)

    def _build_hex27(self, xs, ys, zs) -> HexMesh:
        """
        Build 27-node triquadratic Hex mesh (tensor-product nodes: corners, edge mids, face centers, body center).

        Global numbering: vertices, x-edge mids, y-edge mids, z-edge mids,
        z-/y-/x-normal face centers, then cell centers. Each element lists its
        3x3x3 nodes in lexicographic (k, j, i) order, i fastest.
        """
        nx, ny, nz = self.nx, self.ny, self.nz
        add, coords_list = self._node_registry(
            self._half_grid(xs), self._half_grid(ys), self._half_grid(zs)
        )

        # vertices
        for k in range(nz + 1):
            for j in range(ny + 1):
                for i in range(nx + 1):
                    add(2 * i, 2 * j, 2 * k)
        # x-edge midpoints
        for k in range(nz + 1):
            for j in range(ny + 1):
                for i in range(nx):
                    add(2 * i + 1, 2 * j, 2 * k)
        # y-edge midpoints (loop order k, i, j fixes their numbering)
        for k in range(nz + 1):
            for i in range(nx + 1):
                for j in range(ny):
                    add(2 * i, 2 * j + 1, 2 * k)
        # z-edge midpoints (loop order j, i, k fixes their numbering)
        for j in range(ny + 1):
            for i in range(nx + 1):
                for k in range(nz):
                    add(2 * i, 2 * j, 2 * k + 1)
        # centers of faces normal to z
        for k in range(nz + 1):
            for j in range(ny):
                for i in range(nx):
                    add(2 * i + 1, 2 * j + 1, 2 * k)
        # centers of faces normal to y
        for k in range(nz):
            for j in range(ny + 1):
                for i in range(nx):
                    add(2 * i + 1, 2 * j, 2 * k + 1)
        # centers of faces normal to x
        for k in range(nz):
            for j in range(ny):
                for i in range(nx + 1):
                    add(2 * i, 2 * j + 1, 2 * k + 1)
        # cell centers (unique per cell)
        for k in range(nz):
            for j in range(ny):
                for i in range(nx):
                    add(2 * i + 1, 2 * j + 1, 2 * k + 1)

        conn_list = []
        for k in range(nz):
            for j in range(ny):
                for i in range(nx):
                    # order in lexicographic k,j,i -> length 27
                    conn_list.append(
                        [
                            add(2 * i + ii, 2 * j + jj, 2 * k + kk)
                            for kk in range(3)
                            for jj in range(3)
                            for ii in range(3)
                        ]
                    )

        coords = jnp.array(coords_list, dtype=DEFAULT_DTYPE)
        conn = jnp.array(conn_list, dtype=jnp.int32)
        return HexMesh(coords=coords, conn=conn)
|