fluxfem 0.1.3a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fluxfem might be problematic. Click here for more details.

Files changed (47) hide show
  1. fluxfem/__init__.py +343 -0
  2. fluxfem/core/__init__.py +318 -0
  3. fluxfem/core/assembly.py +788 -0
  4. fluxfem/core/basis.py +996 -0
  5. fluxfem/core/data.py +64 -0
  6. fluxfem/core/dtypes.py +4 -0
  7. fluxfem/core/forms.py +234 -0
  8. fluxfem/core/interp.py +55 -0
  9. fluxfem/core/solver.py +113 -0
  10. fluxfem/core/space.py +419 -0
  11. fluxfem/core/weakform.py +828 -0
  12. fluxfem/helpers_ts.py +11 -0
  13. fluxfem/helpers_wf.py +44 -0
  14. fluxfem/mesh/__init__.py +29 -0
  15. fluxfem/mesh/base.py +244 -0
  16. fluxfem/mesh/hex.py +327 -0
  17. fluxfem/mesh/io.py +87 -0
  18. fluxfem/mesh/predicate.py +45 -0
  19. fluxfem/mesh/surface.py +257 -0
  20. fluxfem/mesh/tet.py +246 -0
  21. fluxfem/physics/__init__.py +53 -0
  22. fluxfem/physics/diffusion.py +18 -0
  23. fluxfem/physics/elasticity/__init__.py +39 -0
  24. fluxfem/physics/elasticity/hyperelastic.py +99 -0
  25. fluxfem/physics/elasticity/linear.py +58 -0
  26. fluxfem/physics/elasticity/materials.py +32 -0
  27. fluxfem/physics/elasticity/stress.py +46 -0
  28. fluxfem/physics/operators.py +109 -0
  29. fluxfem/physics/postprocess.py +113 -0
  30. fluxfem/solver/__init__.py +47 -0
  31. fluxfem/solver/bc.py +439 -0
  32. fluxfem/solver/cg.py +326 -0
  33. fluxfem/solver/dirichlet.py +126 -0
  34. fluxfem/solver/history.py +31 -0
  35. fluxfem/solver/newton.py +400 -0
  36. fluxfem/solver/result.py +62 -0
  37. fluxfem/solver/solve_runner.py +534 -0
  38. fluxfem/solver/solver.py +148 -0
  39. fluxfem/solver/sparse.py +188 -0
  40. fluxfem/tools/__init__.py +7 -0
  41. fluxfem/tools/jit.py +51 -0
  42. fluxfem/tools/timer.py +659 -0
  43. fluxfem/tools/visualizer.py +101 -0
  44. fluxfem-0.1.3a0.dist-info/LICENSE +201 -0
  45. fluxfem-0.1.3a0.dist-info/METADATA +125 -0
  46. fluxfem-0.1.3a0.dist-info/RECORD +47 -0
  47. fluxfem-0.1.3a0.dist-info/WHEEL +4 -0
fluxfem/mesh/io.py ADDED
@@ -0,0 +1,87 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+ DTYPE = jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32
8
+
9
+ try:
10
+ import meshio
11
+ except Exception as e: # pragma: no cover
12
+ meshio = None
13
+ meshio_import_error = e
14
+ else:
15
+ meshio_import_error = None
16
+
17
+ from .hex import HexMesh
18
+ from .tet import TetMesh
19
+ from .surface import SurfaceMesh
20
+
21
+
22
def load_gmsh_mesh(path: str):
    """
    Read a Gmsh .msh file with meshio and wrap its volume cells in a mesh object.

    Hexahedra take precedence over tetrahedra when both cell types are present.
    Boundary facets are taken from the first of "quad", "triangle", "tri" that
    meshio reports; other facet types in the same file are ignored.

    Returns:
        mesh: HexMesh or TetMesh
        facets: (n_facets, 3 or 4) int32 array, or None if no facet cells exist
        facet_tags: (n_facets,) int32 array of "gmsh:physical" tags, or None

    Raises:
        ImportError: meshio could not be imported.
        ValueError: the file has neither "hexahedron" nor "tetra" cells.
    """
    if meshio is None:
        raise ImportError(f"meshio is required to load gmsh meshes: {meshio_import_error}")

    msh = meshio.read(path)
    coords = np.asarray(msh.points[:, :3], dtype=DTYPE)
    cell_dict = msh.cells_dict

    # Volume cells: first supported type wins.
    for cell_type, mesh_cls in (("hexahedron", HexMesh), ("tetra", TetMesh)):
        if cell_type in cell_dict:
            conn = np.asarray(cell_dict[cell_type], dtype=np.int32)
            mesh = mesh_cls(jnp.asarray(coords), jnp.asarray(conn))
            break
    else:
        raise ValueError("gmsh mesh does not contain hexahedron or tetra cells")

    # Surface facets (quad/tri): first matching key wins.
    facet_type = next((key for key in ("quad", "triangle", "tri") if key in cell_dict), None)
    if facet_type is None:
        return mesh, None, None
    facets = np.asarray(cell_dict[facet_type], dtype=np.int32)

    # Physical tags: prefer the per-type dict view, else scan the raw cell blocks.
    tags_raw = None
    if "gmsh:physical" in msh.cell_data_dict:
        tags_raw = msh.cell_data_dict["gmsh:physical"].get(facet_type, None)
    elif "gmsh:physical" in msh.cell_data:
        for block, data in zip(msh.cells, msh.cell_data["gmsh:physical"]):
            if block.type == facet_type:
                tags_raw = data
                break

    facet_tags = None if tags_raw is None else np.asarray(tags_raw, dtype=np.int32)
    return mesh, facets, facet_tags
69
+
70
+
71
def load_gmsh_hex_mesh(path: str):
    """
    Load a Gmsh file via load_gmsh_mesh and require the result to be a HexMesh.

    Returns the same (mesh, facets, facet_tags) triple as load_gmsh_mesh.

    Raises:
        ValueError: the loaded mesh is not a HexMesh.
    """
    mesh, facets, tags = load_gmsh_mesh(path)
    if isinstance(mesh, HexMesh):
        return mesh, facets, tags
    raise ValueError("gmsh mesh is not hexahedral")
76
+
77
+
78
def load_gmsh_tet_mesh(path: str):
    """
    Load a Gmsh file via load_gmsh_mesh and require the result to be a TetMesh.

    Returns the same (mesh, facets, facet_tags) triple as load_gmsh_mesh.

    Raises:
        ValueError: the loaded mesh is not a TetMesh.
    """
    mesh, facets, tags = load_gmsh_mesh(path)
    if isinstance(mesh, TetMesh):
        return mesh, facets, tags
    raise ValueError("gmsh mesh is not tetrahedral")
83
+
84
+
85
def make_surface_from_facets(coords: np.ndarray, facets: np.ndarray, tags=None) -> SurfaceMesh:
    """Build a SurfaceMesh from raw coords/facets; ``tags`` is forwarded as ``facet_tags``."""
    surface = SurfaceMesh.from_facets(coords, facets, facet_tags=tags)
    return surface
fluxfem/mesh/predicate.py ADDED
@@ -0,0 +1,45 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+
6
def plane_predicate(axis: int, value: float, tol: float = 1e-8):
    """
    Build a predicate that is True only when every point lies on the plane x[axis]=value.

    Closeness is judged by NumPy's isclose rule with ``atol=tol`` and NumPy's
    default relative tolerance. An empty point set yields True.
    """
    def pred(pts: np.ndarray) -> bool:
        on_plane = np.isclose(pts[:, axis], value, atol=tol)
        return bool(on_plane.all())

    return pred
11
+
12
+
13
def axis_plane_predicate(axis: int, value: float, tol: float = 1e-8):
    """Readability alias: returns exactly what plane_predicate(axis, value, tol=tol) returns."""
    pred = plane_predicate(axis, value, tol=tol)
    return pred
16
+
17
+
18
def slab_predicate(axis: int, min_val: float, max_val: float, tol: float = 1e-8):
    """
    Build a per-point predicate for the slab min_val <= x[axis] <= max_val.

    The returned callable converts its input to a float array and gives back a
    boolean array with one entry per point; both bounds are widened by ``tol``.
    """
    def pred(pts: np.ndarray) -> np.ndarray:
        along_axis = np.asarray(pts, dtype=float)[:, axis]
        above_lower = along_axis >= min_val - tol
        below_upper = along_axis <= max_val + tol
        return above_lower & below_upper

    return pred
24
+
25
+
26
def bbox_predicate(mins: np.ndarray, maxs: np.ndarray, tol: float = 1e-8):
    """
    Build a per-point predicate for points touching an axis-aligned bounding box.

    A point is flagged when at least one of its coordinates is close (NumPy
    isclose with ``atol=tol``) to the matching entry of ``mins`` or ``maxs``.
    The other coordinates are not checked against the box extent.
    """
    lower = np.asarray(mins, dtype=float)
    upper = np.asarray(maxs, dtype=float)

    def pred(pts: np.ndarray) -> np.ndarray:
        xyz = np.asarray(pts, dtype=float)
        hits_lower = np.isclose(xyz, lower[None, :], atol=tol)
        hits_upper = np.isclose(xyz, upper[None, :], atol=tol)
        return np.logical_or(hits_lower, hits_upper).any(axis=1)

    return pred
38
+
39
+
40
+ __all__ = [
41
+ "plane_predicate",
42
+ "axis_plane_predicate",
43
+ "slab_predicate",
44
+ "bbox_predicate",
45
+ ]
fluxfem/mesh/surface.py ADDED
@@ -0,0 +1,257 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import numpy as np
8
+
9
+ DTYPE = jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32
10
+
11
+ from .base import BaseMesh, BaseMeshPytree
12
+ from .hex import HexMesh, HexMeshPytree
13
+
14
+
15
+ def _polygon_area(pts: np.ndarray) -> float:
16
+ """
17
+ Polygon area in 3D by fan triangulation (works for tri/quad faces).
18
+ Assumes points are planar and ordered.
19
+ """
20
+ if pts.shape[0] < 3:
21
+ return 0.0
22
+ area = 0.0
23
+ p0 = pts[0]
24
+ for i in range(1, pts.shape[0] - 1):
25
+ v1 = pts[i] - p0
26
+ v2 = pts[i + 1] - p0
27
+ area += 0.5 * np.linalg.norm(np.cross(v1, v2))
28
+ return float(area)
29
+
30
+
31
@dataclass(eq=False)
class SurfaceMesh(BaseMesh):
    """
    Boundary mesh built from facets (tri/quad) that index into volume mesh nodes.

    The facet connectivity lives in the inherited ``conn`` field.
    """

    # Per-facet tags; __post_init__ keeps this and ``cell_tags`` in sync.
    facet_tags: Optional[jnp.ndarray] = None

    def __post_init__(self):
        # Whichever of cell_tags / facet_tags was supplied fills in the other,
        # so BaseMesh code reading cell_tags keeps working.
        if self.cell_tags is None:
            if self.facet_tags is not None:
                self.cell_tags = self.facet_tags
        elif self.facet_tags is None:
            self.facet_tags = self.cell_tags

    @classmethod
    def from_facets(
        cls,
        coords: jnp.ndarray,
        facets: jnp.ndarray,
        facet_tags: Optional[jnp.ndarray] = None,
        node_tags: Optional[jnp.ndarray] = None,
    ) -> "SurfaceMesh":
        """
        Build a surface mesh from raw arrays.

        Coordinates are converted to the module DTYPE, facets and facet tags to
        int32; node tags are converted without forcing a dtype.
        """
        tag_array = jnp.asarray(facet_tags, dtype=jnp.int32) if facet_tags is not None else None
        node_tag_array = jnp.asarray(node_tags) if node_tags is not None else None
        return cls(
            coords=jnp.asarray(coords, dtype=DTYPE),
            conn=jnp.asarray(facets, dtype=jnp.int32),
            cell_tags=tag_array,
            node_tags=node_tag_array,
            facet_tags=tag_array,
        )

    @classmethod
    def from_hex_mesh(
        cls,
        mesh: HexMesh,
        facets: jnp.ndarray,
        facet_tags: Optional[jnp.ndarray] = None,
    ) -> "SurfaceMesh":
        """
        Build a surface mesh on top of a volume mesh's coordinates and node tags.

        ``facets`` must use the node numbering of ``mesh``.
        """
        return cls.from_facets(
            mesh.coords,
            facets,
            facet_tags=facet_tags,
            node_tags=mesh.node_tags,
        )

    @property
    def n_facets(self) -> int:
        """Number of facets (same value as ``n_elems``)."""
        return self.n_elems

    def facet_areas(self) -> np.ndarray:
        """Return one area per facet as a float NumPy array (computed on the host)."""
        xyz = np.asarray(self.coords)
        facet_nodes = np.asarray(self.conn, dtype=int)
        return np.array([_polygon_area(xyz[nodes]) for nodes in facet_nodes], dtype=float)

    def select_by_tag(self, tag: int) -> "SurfaceMesh":
        """
        Return a new SurfaceMesh holding only the facets whose tag equals ``tag``.

        Coordinates and node tags are carried over unchanged.

        Raises:
            ValueError: this mesh has no facet tags.
        """
        if self.facet_tags is None:
            raise ValueError("facet_tags not set on this SurfaceMesh")
        keep = np.asarray(self.facet_tags) == tag
        return SurfaceMesh.from_facets(
            self.coords,
            self.conn[keep],
            facet_tags=self.facet_tags[keep],
            node_tags=self.node_tags,
        )

    def facet_normals(self, *, outward_from=None, normalize: bool = True) -> np.ndarray:
        """Delegate to ``solver.bc.facet_normals`` for this mesh."""
        # Imported lazily inside the method rather than at module import time.
        from ..solver.bc import facet_normals as _facet_normals
        return _facet_normals(self, outward_from=outward_from, normalize=normalize)

    def assemble_load(self, load, *, dim: int, n_total_nodes: int | None = None, F0=None):
        """Delegate to ``solver.bc.assemble_surface_load`` for this mesh."""
        from ..solver.bc import assemble_surface_load as _assemble
        return _assemble(self, load, dim=dim, n_total_nodes=n_total_nodes, F0=F0)

    def assemble_linear_form(self, form, params, *, dim: int, n_total_nodes: int | None = None, F0=None):
        """Delegate to ``solver.bc.assemble_surface_linear_form`` for this mesh."""
        from ..solver.bc import assemble_surface_linear_form as _assemble
        return _assemble(self, form, params, dim=dim, n_total_nodes=n_total_nodes, F0=F0)

    def assemble_linear_form_on_space(self, space, form, params, *, F0=None):
        """
        Assemble a surface linear form sized from a volume space.

        ``dim`` comes from ``space.value_dim`` (1 when absent); the node count
        comes from ``space.mesh`` and falls back to this surface mesh.
        """
        value_dim = int(getattr(space, "value_dim", 1))
        size_source = getattr(space, "mesh", self)
        return self.assemble_linear_form(
            form,
            params,
            dim=value_dim,
            n_total_nodes=int(size_source.n_nodes),
            F0=F0,
        )

    def assemble_traction(
        self,
        traction,
        *,
        dim: int = 3,
        n_total_nodes: int | None = None,
        F0=None,
        outward_from=None,
    ):
        """Delegate to ``solver.bc.assemble_surface_traction`` for this mesh."""
        from ..solver.bc import assemble_surface_traction as _assemble
        return _assemble(
            self,
            traction,
            dim=dim,
            n_total_nodes=n_total_nodes,
            F0=F0,
            outward_from=outward_from,
        )
138
+
139
+
140
@jax.tree_util.register_pytree_node_class
@dataclass(eq=False)
class SurfaceMeshPytree(BaseMeshPytree):
    """
    Pytree-registered boundary mesh built from facets (tri/quad) of a volume mesh.

    The facet connectivity lives in the inherited ``conn`` field.
    """

    # Per-facet tags; __post_init__ keeps this and ``cell_tags`` in sync.
    facet_tags: Optional[jnp.ndarray] = None

    def __post_init__(self):
        # Fill whichever of cell_tags / facet_tags is missing from the other.
        if self.cell_tags is None:
            if self.facet_tags is not None:
                self.cell_tags = self.facet_tags
        elif self.facet_tags is None:
            self.facet_tags = self.cell_tags

    @classmethod
    def from_facets(
        cls,
        coords: jnp.ndarray,
        facets: jnp.ndarray,
        facet_tags: Optional[jnp.ndarray] = None,
        node_tags: Optional[jnp.ndarray] = None,
    ) -> "SurfaceMeshPytree":
        """
        Build a surface mesh from raw arrays.

        Coordinates are converted to the module DTYPE, facets and facet tags to
        int32; node tags are converted without forcing a dtype.
        """
        tag_array = jnp.asarray(facet_tags, dtype=jnp.int32) if facet_tags is not None else None
        node_tag_array = jnp.asarray(node_tags) if node_tags is not None else None
        return cls(
            coords=jnp.asarray(coords, dtype=DTYPE),
            conn=jnp.asarray(facets, dtype=jnp.int32),
            cell_tags=tag_array,
            node_tags=node_tag_array,
            facet_tags=tag_array,
        )

    @classmethod
    def from_hex_mesh(
        cls,
        mesh: HexMesh | HexMeshPytree,
        facets: jnp.ndarray,
        facet_tags: Optional[jnp.ndarray] = None,
    ) -> "SurfaceMeshPytree":
        """Build a surface mesh that reuses ``mesh.coords`` and ``mesh.node_tags``."""
        return cls.from_facets(
            mesh.coords,
            facets,
            facet_tags=facet_tags,
            node_tags=mesh.node_tags,
        )

    @property
    def n_facets(self) -> int:
        """Number of facets (same value as ``n_elems``)."""
        return self.n_elems

    def facet_areas(self) -> np.ndarray:
        """Return one area per facet as a float NumPy array (computed on the host)."""
        xyz = np.asarray(self.coords)
        facet_nodes = np.asarray(self.conn, dtype=int)
        return np.fromiter(
            (_polygon_area(xyz[nodes]) for nodes in facet_nodes),
            dtype=float,
            count=facet_nodes.shape[0],
        )

    def select_by_tag(self, tag: int) -> "SurfaceMeshPytree":
        """
        Return a new SurfaceMeshPytree holding only the facets whose tag equals ``tag``.

        Raises:
            ValueError: this mesh has no facet tags.
        """
        if self.facet_tags is None:
            raise ValueError("facet_tags not set on this SurfaceMesh")
        keep = np.asarray(self.facet_tags) == tag
        return SurfaceMeshPytree.from_facets(
            self.coords,
            self.conn[keep],
            facet_tags=self.facet_tags[keep],
            node_tags=self.node_tags,
        )

    def facet_normals(self, *, outward_from=None, normalize: bool = True) -> np.ndarray:
        """Delegate to ``solver.bc.facet_normals`` for this mesh."""
        # Imported lazily inside the method rather than at module import time.
        from ..solver.bc import facet_normals as _facet_normals
        return _facet_normals(self, outward_from=outward_from, normalize=normalize)

    def assemble_load(self, load, *, dim: int, n_total_nodes: int | None = None, F0=None):
        """Delegate to ``solver.bc.assemble_surface_load`` for this mesh."""
        from ..solver.bc import assemble_surface_load as _assemble
        return _assemble(self, load, dim=dim, n_total_nodes=n_total_nodes, F0=F0)

    def assemble_linear_form(self, form, params, *, dim: int, n_total_nodes: int | None = None, F0=None):
        """Delegate to ``solver.bc.assemble_surface_linear_form`` for this mesh."""
        from ..solver.bc import assemble_surface_linear_form as _assemble
        return _assemble(self, form, params, dim=dim, n_total_nodes=n_total_nodes, F0=F0)

    def assemble_linear_form_on_space(self, space, form, params, *, F0=None):
        """
        Assemble a surface linear form sized from a volume space.

        ``dim`` comes from ``space.value_dim`` (1 when absent); the node count
        comes from ``space.mesh`` and falls back to this surface mesh.
        """
        value_dim = int(getattr(space, "value_dim", 1))
        size_source = getattr(space, "mesh", self)
        return self.assemble_linear_form(
            form,
            params,
            dim=value_dim,
            n_total_nodes=int(size_source.n_nodes),
            F0=F0,
        )

    def assemble_traction(
        self,
        traction,
        *,
        dim: int = 3,
        n_total_nodes: int | None = None,
        F0=None,
        outward_from=None,
    ):
        """Delegate to ``solver.bc.assemble_surface_traction`` for this mesh."""
        from ..solver.bc import assemble_surface_traction as _assemble
        return _assemble(
            self,
            traction,
            dim=dim,
            n_total_nodes=n_total_nodes,
            F0=F0,
            outward_from=outward_from,
        )

    def assemble_flux(
        self,
        flux,
        *,
        n_total_nodes: int | None = None,
        F0=None,
        outward_from=None,
    ):
        """Delegate to ``solver.bc.assemble_surface_flux`` for this mesh."""
        from ..solver.bc import assemble_surface_flux as _assemble
        return _assemble(
            self,
            flux,
            n_total_nodes=n_total_nodes,
            F0=F0,
            outward_from=outward_from,
        )
fluxfem/mesh/tet.py ADDED
@@ -0,0 +1,246 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import Optional
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import numpy as np
7
+
8
+ DTYPE = jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32
9
+
10
+
11
+ from .base import BaseMesh, BaseMeshPytree
12
+
13
+
14
@dataclass
class TetMesh(BaseMesh):
    """Unstructured tetrahedral mesh."""

    def face_node_patterns(self):
        """
        Return the local corner indices of the four triangular faces of a tet.

        Each face leaves out one corner, giving
        [(0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)].
        """
        return [
            tuple(corner for corner in range(4) if corner != omitted)
            for omitted in (3, 2, 1, 0)
        ]
25
+
26
+
27
@jax.tree_util.register_pytree_node_class
@dataclass(eq=False)
class TetMeshPytree(BaseMeshPytree):
    """Unstructured tetrahedral mesh (pytree)."""

    def face_node_patterns(self):
        """
        Return the local corner indices of the four triangular faces of a tet.

        Each face leaves out one corner, giving
        [(0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)].
        """
        return [
            tuple(corner for corner in range(4) if corner != omitted)
            for omitted in (3, 2, 1, 0)
        ]
38
+
39
+
40
@dataclass
class StructuredTetBox:
    """
    Regular grid subdivided into 5 tets per cube (simplex) in a structured layout.

    ``build()`` returns a TetMesh with 4-node elements (``order=1``) or 10-node
    elements (``order=2``: four corners followed by the edge midpoints in the
    order 01, 02, 03, 12, 13, 23).

    NOTE(review): every cube is split with the same 5-tet pattern, so the face
    diagonals of neighbouring cubes do not appear to line up (the x=0 face is
    cut along p010-p001 while the x=1 face is cut along p100-p111). Confirm
    whether downstream code needs a conforming mesh.
    """

    nx: int
    ny: int
    nz: int
    lx: float = 1.0
    ly: float = 1.0
    lz: float = 1.0
    origin: tuple[float, float, float] = (0.0, 0.0, 0.0)
    order: int = 1  # 1: 4-node, 2: 10-node (edge mids)

    def _fix_orientation(self, coords: np.ndarray, conn: np.ndarray) -> np.ndarray:
        """
        Return a copy of ``conn`` in which every tet has a non-negative Jacobian.

        The Jacobian columns are p1-p0, p2-p0, p3-p0 of the first four (corner)
        nodes. Inverted tets get corners 1 and 2 swapped. For 10-node rows the
        midpoint slots are swapped as well (01<->02 and 13<->23) so that each
        midpoint slot still names the edge between its two corner slots.
        """
        conn = conn.copy()
        corner_xyz = coords[conn[:, :4]]                          # (n_tets, 4, 3)
        edge_rows = corner_xyz[:, 1:4, :] - corner_xyz[:, :1, :]  # rows: p1-p0, p2-p0, p3-p0
        jacobians = np.swapaxes(edge_rows, 1, 2)                  # edge vectors as columns
        inverted = np.linalg.det(jacobians) < 0
        if not inverted.any():
            return conn

        n_local = conn.shape[1]
        perm = np.arange(n_local)
        perm[[1, 2]] = [2, 1]
        if n_local == 10:
            # Slots: 4=n01, 5=n02, 6=n03, 7=n12, 8=n13, 9=n23.
            perm[[4, 5]] = [5, 4]
            perm[[8, 9]] = [9, 8]
        conn[inverted] = conn[inverted][:, perm]
        return conn

    def _midpoint(self, a, b):
        """Return the midpoint of ``a`` and ``b``."""
        return 0.5 * (a + b)

    def _build_tet10(self, xs, ys, zs) -> "TetMesh":
        """
        Build the 10-node mesh on the grid spanned by ``xs``, ``ys``, ``zs``.

        Nodes are de-duplicated by their coordinates rounded to 10 decimals.
        Grid corners are registered first (x fastest, then y, then z), so they
        receive the lowest node indices.
        """
        coords_list: list[list[float]] = []
        node_map: dict[tuple[float, float, float], int] = {}

        def add_node(pt: np.ndarray) -> int:
            key = tuple(np.round(pt.astype(np.float64), 10))
            if key not in node_map:
                node_map[key] = len(coords_list)
                coords_list.append([float(pt[0]), float(pt[1]), float(pt[2])])
            return node_map[key]

        def grid_point(i: int, j: int, k: int) -> np.ndarray:
            return np.array([xs[i], ys[j], zs[k]], dtype=np.float64)

        # corners
        for k in range(self.nz + 1):
            for j in range(self.ny + 1):
                for i in range(self.nx + 1):
                    add_node(grid_point(i, j, k))

        # Local corner pairs of the six edges, in element slot order.
        edge_pairs = ((0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3))
        conn_list = []
        mid = self._midpoint
        for k in range(self.nz):
            for j in range(self.ny):
                for i in range(self.nx):
                    p000 = grid_point(i, j, k)
                    p100 = grid_point(i + 1, j, k)
                    p010 = grid_point(i, j + 1, k)
                    p110 = grid_point(i + 1, j + 1, k)
                    p001 = grid_point(i, j, k + 1)
                    p101 = grid_point(i + 1, j, k + 1)
                    p011 = grid_point(i, j + 1, k + 1)
                    p111 = grid_point(i + 1, j + 1, k + 1)

                    tets = [
                        (p000, p100, p010, p001),
                        (p100, p110, p010, p111),
                        (p100, p010, p001, p111),
                        (p100, p001, p101, p111),
                        (p010, p001, p011, p111),
                    ]

                    for tet_pts in tets:
                        row = [add_node(p) for p in tet_pts]
                        # edge midpoints
                        row.extend(add_node(mid(tet_pts[a], tet_pts[b])) for a, b in edge_pairs)
                        conn_list.append(row)

        coords = np.asarray(coords_list, dtype=DTYPE)
        conn = np.asarray(conn_list, dtype=np.int32)
        conn = self._fix_orientation(coords, conn)
        return TetMesh(coords=jnp.array(coords), conn=jnp.array(conn))

    def build(self) -> "TetMesh":
        """
        Build the mesh for the configured grid size, extents, origin and order.

        Raises:
            ValueError: a cell count is not positive, or ``order`` is not 1 or 2.
        """
        if self.nx <= 0 or self.ny <= 0 or self.nz <= 0:
            raise ValueError("nx, ny, nz must be positive")
        if self.order not in (1, 2):
            raise ValueError("order must be 1 or 2")

        ox, oy, oz = self.origin
        xs = np.linspace(ox, ox + self.lx, self.nx + 1, dtype=np.float64)
        ys = np.linspace(oy, oy + self.ly, self.ny + 1, dtype=np.float64)
        zs = np.linspace(oz, oz + self.lz, self.nz + 1, dtype=np.float64)

        if self.order == 1:
            return self._build_linear(xs, ys, zs)
        return self._build_tet10(xs, ys, zs)

    # keep linear build for order=1
    def _build_linear(self, xs, ys, zs) -> "TetMesh":
        """Build the 4-node mesh; nodes are numbered x fastest, then y, then z."""
        coords = np.asarray(
            [
                [xs[i], ys[j], zs[k]]
                for k in range(self.nz + 1)
                for j in range(self.ny + 1)
                for i in range(self.nx + 1)
            ],
            dtype=DTYPE,
        )

        def node_id(i: int, j: int, k: int) -> int:
            return k * (self.ny + 1) * (self.nx + 1) + j * (self.nx + 1) + i

        conn_list = []
        for k in range(self.nz):
            for j in range(self.ny):
                for i in range(self.nx):
                    v000 = node_id(i, j, k)
                    v100 = node_id(i + 1, j, k)
                    v010 = node_id(i, j + 1, k)
                    v110 = node_id(i + 1, j + 1, k)
                    v001 = node_id(i, j, k + 1)
                    v101 = node_id(i + 1, j, k + 1)
                    v011 = node_id(i, j + 1, k + 1)
                    v111 = node_id(i + 1, j + 1, k + 1)
                    conn_list.extend(
                        [
                            [v000, v100, v010, v001],
                            [v100, v110, v010, v111],
                            [v100, v010, v001, v111],
                            [v100, v001, v101, v111],
                            [v010, v001, v011, v111],
                        ]
                    )
        conn = np.asarray(conn_list, dtype=np.int32)
        conn = self._fix_orientation(coords, conn)
        return TetMesh(coords=jnp.array(coords), conn=jnp.array(conn))
175
+
176
+
177
@dataclass
class StructuredTetTensorBox:
    """
    Regular grid split into 6 tets per cube, following skfem's MeshTet.init_tensor.

    Only linear (4-node) elements are produced.
    """

    nx: int
    ny: int
    nz: int
    lx: float = 1.0
    ly: float = 1.0
    lz: float = 1.0
    origin: tuple[float, float, float] = (0.0, 0.0, 0.0)
    order: int = 1  # only linear supported

    def _fix_orientation(self, coords: np.ndarray, conn: np.ndarray) -> np.ndarray:
        """
        Return a copy of ``conn`` where tets with a negative Jacobian determinant
        have their second and third nodes exchanged.
        """
        fixed = conn.copy()
        for row in range(fixed.shape[0]):
            corner_xyz = coords[fixed[row]]
            v0, v1, v2, v3 = corner_xyz[0], corner_xyz[1], corner_xyz[2], corner_xyz[3]
            jacobian = np.column_stack((v1 - v0, v2 - v0, v3 - v0))
            if np.linalg.det(jacobian) < 0:
                fixed[row, 1], fixed[row, 2] = fixed[row, 2], fixed[row, 1]
        return fixed

    def build(self) -> "TetMesh":
        """
        Build the mesh for the configured grid size, extents and origin.

        Raises:
            ValueError: a cell count is not positive, or ``order`` is not 1.
        """
        if min(self.nx, self.ny, self.nz) <= 0:
            raise ValueError("nx, ny, nz must be positive")
        if self.order != 1:
            raise ValueError("StructuredTetTensorBox only supports order=1")

        ox, oy, oz = self.origin
        xs = np.linspace(ox, ox + self.lx, self.nx + 1, dtype=np.float64)
        ys = np.linspace(oy, oy + self.ly, self.ny + 1, dtype=np.float64)
        zs = np.linspace(oz, oz + self.lz, self.nz + 1, dtype=np.float64)
        return self._build_linear(xs, ys, zs)

    def _build_linear(self, xs, ys, zs) -> "TetMesh":
        """Build points and 4-node connectivity from the three 1D grids."""
        # Mirror scikit-fem MeshTet.init_tensor.
        npx, npy, npz = len(xs), len(ys), len(zs)
        X, Y, Z = np.meshgrid(np.sort(xs), np.sort(ys), np.sort(zs))
        p = np.vstack((X.flatten("F"), Y.flatten("F"), Z.flatten("F")))

        ne = (npx - 1) * (npy - 1) * (npz - 1)
        # Node indices laid out like the meshgrid output: (y, x, z), Fortran order.
        ix = np.arange(npx * npy * npz).reshape(npy, npx, npz, order="F").copy()

        # Index offsets of the 8 cube corners along the (y, x, z) axes of ``ix``.
        corner_offsets = (
            (0, 0, 0),
            (1, 0, 0),
            (0, 1, 0),
            (0, 0, 1),
            (1, 1, 0),
            (1, 0, 1),
            (0, 1, 1),
            (1, 1, 1),
        )
        t = np.zeros((8, ne), dtype=np.int64)
        for corner, (dy, dx, dz) in enumerate(corner_offsets):
            window = ix[dy:dy + npy - 1, dx:dx + npx - 1, dz:dz + npz - 1]
            t[corner] = window.reshape(ne, 1, order="F").copy().flatten()

        # Six tets per cube; every one contains cube corners 0 and 7.
        tet_corners = (
            [0, 1, 5, 7],
            [0, 1, 4, 7],
            [0, 2, 4, 7],
            [0, 3, 5, 7],
            [0, 2, 6, 7],
            [0, 3, 6, 7],
        )
        T = np.zeros((4, 6 * ne), dtype=np.int64)
        for slab, pattern in enumerate(tet_corners):
            T[:, slab * ne:(slab + 1) * ne] = t[pattern]

        coords = p.T.astype(DTYPE, copy=False)
        conn = T.T.astype(np.int32, copy=False)
        conn = self._fix_orientation(coords, conn)
        return TetMesh(coords=jnp.array(coords), conn=jnp.array(conn))
fluxfem/physics/__init__.py ADDED
@@ -0,0 +1,53 @@
1
+ """Physics-level helpers (constitutive models, material laws, etc.)."""
2
+
3
+ from .elasticity import (
4
+ lame_parameters,
5
+ isotropic_3d_D,
6
+ linear_elasticity_form,
7
+ vector_body_force_form,
8
+ constant_body_force_vector_form,
9
+ assemble_constant_body_force,
10
+ right_cauchy_green,
11
+ green_lagrange_strain,
12
+ deformation_gradient,
13
+ pk2_neo_hookean,
14
+ neo_hookean_residual_form,
15
+ make_elastic_point_data,
16
+ write_elastic_vtu,
17
+ principal_stresses,
18
+ principal_sum,
19
+ max_shear_stress,
20
+ von_mises_stress,
21
+ )
22
+ from .diffusion import diffusion_form
23
+ from .operators import dot, ddot, transpose_last2, sym_grad, sym_grad_u
24
+ from .postprocess import make_point_data_displacement, write_point_data_vtu, interpolate_at_points
25
+
26
+ __all__ = [
27
+ "lame_parameters",
28
+ "isotropic_3d_D",
29
+ "linear_elasticity_form",
30
+ "vector_body_force_form",
31
+ "constant_body_force_vector_form",
32
+ "assemble_constant_body_force",
33
+ "diffusion_form",
34
+ "dot",
35
+ "ddot",
36
+ "transpose_last2",
37
+ "sym_grad",
38
+ "sym_grad_u",
39
+ "right_cauchy_green",
40
+ "green_lagrange_strain",
41
+ "deformation_gradient",
42
+ "pk2_neo_hookean",
43
+ "neo_hookean_residual_form",
44
+ "make_elastic_point_data",
45
+ "write_elastic_vtu",
46
+ "make_point_data_displacement",
47
+ "write_point_data_vtu",
48
+ "interpolate_at_points",
49
+ "principal_stresses",
50
+ "principal_sum",
51
+ "max_shear_stress",
52
+ "von_mises_stress",
53
+ ]