fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. fluxfem/__init__.py +69 -13
  2. fluxfem/core/__init__.py +140 -53
  3. fluxfem/core/assembly.py +691 -97
  4. fluxfem/core/basis.py +75 -54
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/dtypes.py +9 -1
  7. fluxfem/core/forms.py +10 -0
  8. fluxfem/core/mixed_assembly.py +263 -0
  9. fluxfem/core/mixed_space.py +382 -0
  10. fluxfem/core/mixed_weakform.py +97 -0
  11. fluxfem/core/solver.py +2 -0
  12. fluxfem/core/space.py +315 -30
  13. fluxfem/core/weakform.py +821 -42
  14. fluxfem/helpers_wf.py +49 -0
  15. fluxfem/mesh/__init__.py +54 -2
  16. fluxfem/mesh/base.py +318 -9
  17. fluxfem/mesh/contact.py +841 -0
  18. fluxfem/mesh/dtypes.py +12 -0
  19. fluxfem/mesh/hex.py +17 -16
  20. fluxfem/mesh/io.py +9 -6
  21. fluxfem/mesh/mortar.py +3970 -0
  22. fluxfem/mesh/supermesh.py +318 -0
  23. fluxfem/mesh/surface.py +104 -26
  24. fluxfem/mesh/tet.py +16 -7
  25. fluxfem/physics/diffusion.py +3 -0
  26. fluxfem/physics/elasticity/hyperelastic.py +35 -3
  27. fluxfem/physics/elasticity/linear.py +22 -4
  28. fluxfem/physics/elasticity/stress.py +9 -5
  29. fluxfem/physics/operators.py +12 -5
  30. fluxfem/physics/postprocess.py +29 -3
  31. fluxfem/solver/__init__.py +47 -2
  32. fluxfem/solver/bc.py +38 -2
  33. fluxfem/solver/block_matrix.py +284 -0
  34. fluxfem/solver/block_system.py +477 -0
  35. fluxfem/solver/cg.py +150 -55
  36. fluxfem/solver/dirichlet.py +358 -5
  37. fluxfem/solver/history.py +15 -3
  38. fluxfem/solver/newton.py +260 -70
  39. fluxfem/solver/petsc.py +445 -0
  40. fluxfem/solver/preconditioner.py +109 -0
  41. fluxfem/solver/result.py +18 -0
  42. fluxfem/solver/solve_runner.py +208 -23
  43. fluxfem/solver/solver.py +35 -12
  44. fluxfem/solver/sparse.py +149 -15
  45. fluxfem/tools/jit.py +19 -7
  46. fluxfem/tools/timer.py +14 -12
  47. fluxfem/tools/visualizer.py +16 -4
  48. fluxfem-0.2.1.dist-info/METADATA +314 -0
  49. fluxfem-0.2.1.dist-info/RECORD +59 -0
  50. fluxfem-0.1.4.dist-info/METADATA +0 -127
  51. fluxfem-0.1.4.dist-info/RECORD +0 -48
  52. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  53. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,318 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ import os
5
+ from typing import Iterable
6
+
7
+ import numpy as np
8
+
9
+ from .surface import SurfaceMesh
10
+
11
@dataclass(eq=False)
class SurfaceSupermesh:
    """
    Triangulated overlap ("supermesh") of two surface meshes.

    Every triangle is cut from the overlap of one facet of surface A with one
    coplanar facet of surface B; both originating facet indices are stored per
    triangle. ``eq=False`` leaves equality and hashing identity-based, so an
    instance compares equal only to itself.
    """

    # Vertex coordinates, one row per supermesh vertex.
    coords: np.ndarray
    # Triangle connectivity: each row holds three row indices into ``coords``.
    conn: np.ndarray
    # source_facets_a[i]: index of the surface-A facet that triangle i came from.
    source_facets_a: np.ndarray
    # source_facets_b[i]: index of the surface-B facet that triangle i came from.
    source_facets_b: np.ndarray


# Module-level memo for ``build_surface_supermesh``: maps a content signature of
# both input surfaces (plus the tolerance) to the supermesh built from them.
_SUPERMESH_CACHE: dict[tuple, SurfaceSupermesh] = {}
21
+
22
+
23
def _polygon_area_2d(pts: np.ndarray) -> float:
    """
    Signed shoelace area of a closed 2-D polygon.

    ``pts`` holds one vertex per row (x in column 0, y in column 1); the last
    vertex is implicitly joined back to the first. The result is positive for
    counter-clockwise and negative for clockwise vertex order.
    """
    xs, ys = pts[:, 0], pts[:, 1]
    x_next = np.roll(xs, -1)
    y_next = np.roll(ys, -1)
    cross_terms = xs * y_next - ys * x_next
    return 0.5 * float(np.sum(cross_terms))
27
+
28
+
29
def _cross2(a: np.ndarray, b: np.ndarray) -> float:
    """Scalar 2-D cross product ``a.x * b.y - a.y * b.x`` as a Python float."""
    ax, ay = a[0], a[1]
    bx, by = b[0], b[1]
    return float(ax * by - ay * bx)
31
+
32
+
33
def _line_intersection(p1, p2, p3, p4, *, tol: float):
    """
    Intersect the line through ``p1``-``p2`` with the line through ``p3``-``p4``.

    The result is ``p1 + t * (p2 - p1)`` for the parameter ``t`` at which the
    first line meets the second. When the two direction vectors are
    (near-)parallel, i.e. the magnitude of their 2-D cross product is below
    ``tol``, no unique intersection exists and ``p2`` is returned as-is.
    """
    seg_dir = p2 - p1
    clip_dir = p4 - p3
    denom = _cross2(seg_dir, clip_dir)
    if abs(denom) < tol:
        # Parallel / degenerate lines: fall back to the segment end point.
        return p2
    param = _cross2(p3 - p1, clip_dir) / denom
    return p1 + param * seg_dir
41
+
42
+
43
def _sutherland_hodgman(subject: list[np.ndarray], clip: list[np.ndarray], *, tol: float):
    """
    Clip polygon ``subject`` against polygon ``clip`` (Sutherland–Hodgman).

    Both polygons are lists of 2-D points. ``clip`` is processed one edge at a
    time; the sign of its shoelace area decides which side of each edge counts
    as inside, so either winding works. A point whose signed distance falls
    short of the inner side by at most ``tol`` (in cross-product units) is still
    treated as inside. Because each step clips against a single edge's
    half-plane, the result is only the true intersection when ``clip`` is convex.

    Returns the clipped vertex list, which is empty when ``subject`` is empty,
    when ``clip`` has zero signed area, or when nothing survives clipping.
    """
    if not subject:
        return []
    winding = np.sign(_polygon_area_2d(np.array(clip)))
    if winding == 0:
        return []

    def on_inner_side(pt, edge_start, edge_end):
        return winding * _cross2(edge_end - edge_start, pt - edge_start) >= -tol

    n_clip = len(clip)
    result = subject
    for k in range(n_clip):
        if not result:
            break
        edge_start = clip[k]
        edge_end = clip[(k + 1) % n_clip]
        candidates = result
        result = []
        prev = candidates[-1]
        for cur in candidates:
            prev_in = on_inner_side(prev, edge_start, edge_end)
            cur_in = on_inner_side(cur, edge_start, edge_end)
            # Crossing the edge (in either direction) contributes the crossing point...
            if cur_in != prev_in:
                result.append(_line_intersection(prev, cur, edge_start, edge_end, tol=tol))
            # ...and every vertex on the inner side is kept, after that point.
            if cur_in:
                result.append(cur)
            prev = cur
    return result
71
+
72
+
73
def _plane_basis(normal: np.ndarray):
    """
    Build a right-handed orthonormal frame ``(t1, t2, n)`` from a plane normal.

    ``n`` is ``normal`` scaled to unit length. ``t1`` is ``n x ref`` normalized,
    where ``ref`` is the global x axis unless ``|n . x| > 0.9`` (nearly
    parallel), in which case the y axis is used instead. ``t2 = n x t1``.
    """
    unit_n = normal / np.linalg.norm(normal)
    x_axis = np.array([1.0, 0.0, 0.0], dtype=float)
    y_axis = np.array([0.0, 1.0, 0.0], dtype=float)
    helper = y_axis if abs(np.dot(unit_n, x_axis)) > 0.9 else x_axis
    tangent1 = np.cross(unit_n, helper)
    tangent1 = tangent1 / np.linalg.norm(tangent1)
    tangent2 = np.cross(unit_n, tangent1)
    return tangent1, tangent2, unit_n
82
+
83
+
84
def _facet_plane(pts: np.ndarray, *, tol: float):
    """
    Fit a plane ``n . x + d = 0`` to a facet polygon.

    Consecutive vertex triples ``(pts[i], pts[i+1], pts[i+2])`` are tried in
    order; the first whose edge cross product has norm greater than ``tol``
    defines the unit normal ``n`` and the offset ``d = -n . pts[i]``.

    Returns ``(n, d)``, or ``(None, None)`` when every triple is (near-)collinear
    or the polygon has fewer than three vertices.
    """
    for a, b, c in zip(pts[:-2], pts[1:-1], pts[2:]):
        normal = np.cross(b - a, c - a)
        length = np.linalg.norm(normal)
        if length > tol:
            unit = normal / length
            return unit, -float(np.dot(unit, a))
    return None, None
96
+
97
+
98
def _coplanar(pts_a: np.ndarray, pts_b: np.ndarray, *, tol: float) -> bool:
    """
    Decide whether two facet polygons lie in the same plane.

    The facets are coplanar when both have a well-defined plane (see
    ``_facet_plane``), their unit normals agree up to sign (``| |n_a . n_b| - 1 |
    <= 1e-4``), and every vertex of both facets lies within ``tol`` of facet A's
    plane. Opposite orientations are therefore accepted.
    """
    normal_a, offset_a = _facet_plane(pts_a, tol=tol)
    if normal_a is None:
        return False
    normal_b, _unused = _facet_plane(pts_b, tol=tol)
    if normal_b is None:
        return False
    alignment = abs(np.dot(normal_a, normal_b))
    if abs(alignment - 1.0) > 1e-4:
        return False
    # Distances of all vertices (of both facets) to plane A.
    resid_a = np.abs(pts_a @ normal_a + offset_a)
    resid_b = np.abs(pts_b @ normal_a + offset_a)
    return np.max(resid_a) <= tol and np.max(resid_b) <= tol
110
+
111
+
112
def _project(points: np.ndarray, origin: np.ndarray, t1: np.ndarray, t2: np.ndarray):
    """
    Map 3-D ``points`` (one per row) to 2-D in-plane coordinates.

    Each row of the result is ``((p - origin) . t1, (p - origin) . t2)``.
    """
    offsets = points - origin[None, :]
    return np.column_stack((offsets @ t1, offsets @ t2))
117
+
118
+
119
def _unique_points(points: Iterable[np.ndarray], *, tol: float):
    """
    Merge points that snap to the same ``tol``-sized integer grid cell.

    Each point is keyed by ``round(p / tol)`` per coordinate; the first point
    seen for a key is kept. Because merging is by grid cell, two points closer
    than ``tol`` that straddle a cell boundary are not merged.

    Returns ``(coords, indices)``: the kept points as a float array, and for
    every input point (in input order) the index of its kept representative.
    """
    inv_tol = 1.0 / tol
    seen: dict[tuple[int, int, int], int] = {}
    uniq: list[np.ndarray] = []
    order: list[int] = []
    for pt in points:
        cell = tuple(np.round(pt * inv_tol).astype(int))
        if cell not in seen:
            seen[cell] = len(uniq)
            uniq.append(pt)
        order.append(seen[cell])
    return np.asarray(uniq, dtype=float), order
133
+
134
+
135
def _facet_polygon_coords(coords: np.ndarray, facet: np.ndarray) -> np.ndarray:
    """
    Return the boundary-polygon vertex coordinates of ``facet``.

    For 9-node facets only local nodes 0, 2, 8, 6 (in that order) are returned;
    every other facet returns all of its nodes in connectivity order.
    """
    facet_pts = coords[facet]
    if len(facet) != 9:
        return facet_pts
    # NOTE(review): assumes the 9 nodes form a 3x3 row-major grid so that
    # 0, 2, 8, 6 walks the corners cyclically — confirm against the mesh layout.
    return facet_pts[[0, 2, 8, 6]]
141
+
142
+
143
def _triangle_min_angle(p0: np.ndarray, p1: np.ndarray, p2: np.ndarray) -> float:
    """
    Smallest interior angle (radians) of the triangle ``(p0, p1, p2)``.

    An angle whose two adjacent edges include one of zero length is taken as
    0.0, so a triangle with a repeated vertex returns 0.0.
    """
    def vertex_angle(apex, u, w):
        e1 = u - apex
        e2 = w - apex
        len1 = np.linalg.norm(e1)
        len2 = np.linalg.norm(e2)
        if len1 == 0.0 or len2 == 0.0:
            return 0.0
        # Clip guards arccos against round-off just outside [-1, 1].
        return float(np.arccos(np.clip(np.dot(e1, e2) / (len1 * len2), -1.0, 1.0)))

    corners = ((p0, p1, p2), (p1, p0, p2), (p2, p0, p1))
    return min(vertex_angle(*corner) for corner in corners)
155
+
156
+
157
def _triangulate_polygon(indices: list[int], poly2d: np.ndarray) -> list[tuple[int, int, int]]:
    """
    Triangulate a convex polygon.

    Parameters
    ----------
    indices:
        Vertex ids of the polygon in cyclic order.
    poly2d:
        2-D coordinates; row ``k`` belongs to ``indices[k]``. Only read for the
        quality-based quadrilateral split.

    Returns
    -------
    list of ``(i, j, k)`` vertex-id triples.

    Notes
    -----
    Cyclically consecutive repeats of the same vertex id are collapsed first.
    Polygon clipping emits such repeats when a vertex lies exactly on a clip
    edge, and they would otherwise produce zero-area triangles such as
    ``(1, 1, 2)``. Polygons with fewer than three distinct vertices yield ``[]``.

    Quadrilaterals are split along one diagonal chosen by the
    ``FLUXFEM_SUPERMESH_QUAD_DIAG`` environment variable. ``"alt"`` (default)
    uses the 1-3 diagonal and ``"fan"`` the 0-2 diagonal. Any other value picks
    the split whose worst triangle has the larger minimum angle, with ties going
    to the 0-2 split. Larger polygons are fan-triangulated from the first vertex.
    """
    # Drop vertices identical to their cyclic predecessor (k - 1 wraps to -1).
    keep = [k for k in range(len(indices)) if indices[k] != indices[k - 1]]
    if len(keep) != len(indices):
        indices = [indices[k] for k in keep]
        poly2d = np.asarray(poly2d)[keep]

    n = len(indices)
    if n < 3:
        return []
    if n == 3:
        return [(indices[0], indices[1], indices[2])]
    if n == 4:
        i0, i1, i2, i3 = indices
        split_13 = [(i0, i1, i3), (i1, i2, i3)]
        split_02 = [(i0, i1, i2), (i0, i2, i3)]
        diag_pref = os.getenv("FLUXFEM_SUPERMESH_QUAD_DIAG", "alt").lower()
        if diag_pref == "alt":
            return split_13
        if diag_pref == "fan":
            return split_02
        p = poly2d
        min_02 = min(
            _triangle_min_angle(p[0], p[1], p[2]),
            _triangle_min_angle(p[0], p[2], p[3]),
        )
        min_13 = min(
            _triangle_min_angle(p[0], p[1], p[3]),
            _triangle_min_angle(p[1], p[2], p[3]),
        )
        return split_13 if min_13 > min_02 else split_02
    return [(indices[0], indices[i], indices[i + 1]) for i in range(1, n - 1)]
185
+
186
+
187
def build_surface_supermesh(
    surface_a: SurfaceMesh,
    surface_b: SurfaceMesh,
    *,
    tol: float = 1e-8,
    cache_enabled: bool | None = None,
    cache_trace: bool | None = None,
) -> SurfaceSupermesh:
    """
    Build the intersection supermesh of two surface meshes.

    Processing per facet pair (one facet from ``surface_a``, one from ``surface_b``):

    1. Keep only pairs whose bounding boxes overlap within ``tol`` and which are
       coplanar (see ``_coplanar``).
    2. Project both polygons into an in-plane frame anchored at the A facet.
    3. Clip them against each other (Sutherland–Hodgman).
    4. Triangulate the overlap polygon.

    Triangles are oriented so their normal does not oppose the facet-A normal.
    That normal is flipped if needed to agree with ``facet_normals(surface_a,
    outward_from=mean of surface_a.coords)``.

    Parameters
    ----------
    surface_a, surface_b:
        Surface meshes to intersect.
    tol:
        Geometric tolerance used for the bounding-box test, coplanarity,
        clipping and vertex merging within each overlap polygon. Overlap
        polygons whose absolute 2-D area is ``<= tol`` are discarded.
    cache_enabled:
        Memoize results in the module-level ``_SUPERMESH_CACHE``, keyed by
        content hashes of both surfaces' coords/conn and ``tol``. ``None`` reads
        the ``FLUXFEM_SUPERMESH_CACHE`` environment variable. This function
        never evicts cache entries.
    cache_trace:
        Print cache hits/stores. ``None`` reads ``FLUXFEM_SUPERMESH_CACHE_TRACE``.

    Returns
    -------
    SurfaceSupermesh
        Vertices are merged only within each overlap polygon, not across
        polygons. On a cache hit the cached instance itself is returned, not a copy.
    """
    from ..solver.bc import facet_normals
    import hashlib

    if cache_enabled is None:
        cache_enabled = os.getenv("FLUXFEM_SUPERMESH_CACHE", "0") not in ("0", "", "false", "False")
    if cache_trace is None:
        cache_trace = os.getenv("FLUXFEM_SUPERMESH_CACHE_TRACE", "0") not in ("0", "", "false", "False")

    def _array_sig(arr: np.ndarray) -> tuple:
        # Content signature: shape, dtype and a 64-bit BLAKE2b digest of the bytes.
        arr_c = np.ascontiguousarray(arr)
        h = hashlib.blake2b(arr_c.view(np.uint8), digest_size=8).hexdigest()
        return (arr_c.shape, str(arr_c.dtype), h)

    key: tuple | None = None
    if cache_enabled:
        key = (
            _array_sig(np.asarray(surface_a.coords)),
            _array_sig(np.asarray(surface_a.conn)),
            _array_sig(np.asarray(surface_b.coords)),
            _array_sig(np.asarray(surface_b.conn)),
            float(tol),
        )
        cached = _SUPERMESH_CACHE.get(key)
        if cached is not None:
            if cache_trace:
                print(f"[supermesh] cache hit n_tris={int(cached.conn.shape[0])}", flush=True)
            return cached

    coords_a = np.asarray(surface_a.coords, dtype=float)
    coords_b = np.asarray(surface_b.coords, dtype=float)
    facets_a = np.asarray(surface_a.conn, dtype=int)
    facets_b = np.asarray(surface_b.conn, dtype=int)
    normals_a = facet_normals(surface_a, outward_from=np.mean(coords_a, axis=0), normalize=True)

    # B-facet polygons and their bounding boxes do not depend on the A facet,
    # so compute them once instead of once per (A, B) pair.
    polys_b = [_facet_polygon_coords(coords_b, fb) for fb in facets_b]
    if polys_b:
        bmin_b = np.stack([p.min(axis=0) for p in polys_b])
        bmax_b = np.stack([p.max(axis=0) for p in polys_b])
    else:
        bmin_b = bmax_b = np.empty((0, coords_a.shape[-1]))

    all_coords: list[np.ndarray] = []
    all_conn: list[tuple[int, int, int]] = []
    src_a: list[int] = []
    src_b: list[int] = []

    for ia, fa in enumerate(facets_a):
        pts_a = _facet_polygon_coords(coords_a, fa)
        min_a = pts_a.min(axis=0)
        max_a = pts_a.max(axis=0)
        # Per-pair AABB rejection test, evaluated for all B facets at once
        # (candidates stay in ascending B-facet order).
        disjoint = np.any(bmax_b < min_a - tol, axis=1) | np.any(bmin_b > max_a + tol, axis=1)
        candidates = np.flatnonzero(~disjoint)
        if candidates.size == 0:
            continue

        # In-plane frame of facet A, oriented to agree with its reference normal.
        n, _d = _facet_plane(pts_a, tol=tol)
        if n is None:
            # Degenerate A facet: _coplanar rejects every pairing for it anyway.
            continue
        if np.dot(n, normals_a[int(ia)]) < 0.0:
            n = -n
        t1, t2, _ = _plane_basis(n)
        origin = pts_a[0]
        poly_a = _project(pts_a, origin, t1, t2)

        for ib in candidates:
            ib = int(ib)
            pts_b = polys_b[ib]
            if not _coplanar(pts_a, pts_b, tol=tol):
                continue
            poly_b = _project(pts_b, origin, t1, t2)

            inter = _sutherland_hodgman(
                [p.copy() for p in poly_a],
                [p.copy() for p in poly_b],
                tol=tol,
            )
            if len(inter) < 3:
                continue
            inter_np = np.asarray(inter)
            if abs(_polygon_area_2d(inter_np)) <= tol:
                continue
            # Re-sort the overlap vertices by angle about their centroid.
            center = np.mean(inter_np, axis=0)
            angles = np.arctan2(inter_np[:, 1] - center[1], inter_np[:, 0] - center[0])
            inter_np = inter_np[np.argsort(angles)]

            inter_3d = origin[None, :] + inter_np[:, 0:1] * t1 + inter_np[:, 1:2] * t2
            coords_local, idx = _unique_points(inter_3d, tol=tol)
            base = len(all_coords)
            all_coords.extend(coords_local)
            for a_idx, b_idx, c_idx in _triangulate_polygon(idx, inter_np):
                # Clipping can emit the same point twice (vertex on a clip edge);
                # after merging, that shows up as a repeated index -> zero area.
                if a_idx == b_idx or b_idx == c_idx or a_idx == c_idx:
                    continue
                a_id = base + a_idx
                b_id = base + b_idx
                c_id = base + c_idx
                pa = all_coords[a_id]
                n_tri = np.cross(all_coords[b_id] - pa, all_coords[c_id] - pa)
                if np.dot(n_tri, n) < 0.0:
                    b_id, c_id = c_id, b_id
                all_conn.append((a_id, b_id, c_id))
                src_a.append(ia)
                src_b.append(ib)

    if all_conn:
        coords = np.asarray(all_coords, dtype=float)
        conn = np.asarray(all_conn, dtype=int)
    else:
        # Keep a well-formed (0, 3) layout when nothing overlaps.
        coords = np.zeros((0, 3), dtype=float)
        conn = np.zeros((0, 3), dtype=int)
    sm = SurfaceSupermesh(
        coords=coords,
        conn=conn,
        source_facets_a=np.asarray(src_a, dtype=int),
        source_facets_b=np.asarray(src_b, dtype=int),
    )
    if cache_enabled:
        _SUPERMESH_CACHE[key] = sm
        if cache_trace:
            print(f"[supermesh] cache store n_tris={int(sm.conn.shape[0])}", flush=True)
    return sm
fluxfem/mesh/surface.py CHANGED
@@ -1,16 +1,32 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Optional
4
+ from typing import Callable, Optional, Protocol, Sequence, TYPE_CHECKING, TypeVar, cast
5
5
  import jax
6
6
  import jax.numpy as jnp
7
+
8
+ from .dtypes import INDEX_DTYPE
7
9
  import numpy as np
10
+ import numpy.typing as npt
8
11
 
9
12
  DTYPE = jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32
10
13
 
11
14
  from .base import BaseMesh, BaseMeshPytree
12
15
  from .hex import HexMesh, HexMeshPytree
13
16
 
17
+ P = TypeVar("P")
18
+
19
+ if TYPE_CHECKING:
20
+ from ..solver.bc import SurfaceFormContext
21
+
22
+
23
class SurfaceSpaceLike(Protocol):
    """
    Structural type for the ``space`` argument of ``assemble_linear_form_on_space``.

    Only two attributes are required: ``value_dim`` is forwarded as ``dim`` and
    ``mesh.n_nodes`` sizes the global vector (``n_total_nodes``).
    """

    # Forwarded as ``dim`` to ``assemble_linear_form``.
    value_dim: int
    # Volume mesh whose ``n_nodes`` becomes ``n_total_nodes``.
    mesh: BaseMesh
26
+
27
+
28
+ SurfaceLinearForm = Callable[["SurfaceFormContext", P], npt.ArrayLike]
29
+
14
30
 
15
31
  def _polygon_area(pts: np.ndarray) -> float:
16
32
  """
@@ -24,7 +40,7 @@ def _polygon_area(pts: np.ndarray) -> float:
24
40
  for i in range(1, pts.shape[0] - 1):
25
41
  v1 = pts[i] - p0
26
42
  v2 = pts[i + 1] - p0
27
- area += 0.5 * np.linalg.norm(np.cross(v1, v2))
43
+ area += float(0.5 * np.linalg.norm(np.cross(v1, v2)))
28
44
  return float(area)
29
45
 
30
46
 
@@ -53,8 +69,8 @@ class SurfaceMesh(BaseMesh):
53
69
  node_tags: Optional[jnp.ndarray] = None,
54
70
  ) -> "SurfaceMesh":
55
71
  coords_j = jnp.asarray(coords, dtype=DTYPE)
56
- facets_j = jnp.asarray(facets, dtype=jnp.int32)
57
- tags_j = None if facet_tags is None else jnp.asarray(facet_tags, dtype=jnp.int32)
72
+ facets_j = jnp.asarray(facets, dtype=INDEX_DTYPE)
73
+ tags_j = None if facet_tags is None else jnp.asarray(facet_tags, dtype=INDEX_DTYPE)
58
74
  node_tags_j = None if node_tags is None else jnp.asarray(node_tags)
59
75
  return cls(coords=coords_j, conn=facets_j, cell_tags=tags_j, node_tags=node_tags_j, facet_tags=tags_j)
60
76
 
@@ -101,31 +117,70 @@ class SurfaceMesh(BaseMesh):
101
117
  from ..solver.bc import facet_normals
102
118
  return facet_normals(self, outward_from=outward_from, normalize=normalize)
103
119
 
104
- def assemble_load(self, load, *, dim: int, n_total_nodes: int | None = None, F0=None):
120
+ def assemble_load(
121
+ self,
122
+ load: npt.ArrayLike,
123
+ *,
124
+ dim: int,
125
+ n_total_nodes: int | None = None,
126
+ F0: npt.ArrayLike | None = None,
127
+ ) -> np.ndarray:
105
128
  from ..solver.bc import assemble_surface_load
106
129
  return assemble_surface_load(self, load, dim=dim, n_total_nodes=n_total_nodes, F0=F0)
107
130
 
108
- def assemble_linear_form(self, form, params, *, dim: int, n_total_nodes: int | None = None, F0=None):
131
+ def assemble_linear_form(
132
+ self,
133
+ form: SurfaceLinearForm[P],
134
+ params: P,
135
+ *,
136
+ dim: int,
137
+ n_total_nodes: int | None = None,
138
+ F0: npt.ArrayLike | None = None,
139
+ ) -> np.ndarray:
109
140
  from ..solver.bc import assemble_surface_linear_form
110
141
  return assemble_surface_linear_form(self, form, params, dim=dim, n_total_nodes=n_total_nodes, F0=F0)
111
142
 
112
- def assemble_linear_form_on_space(self, space, form, params, *, F0=None):
143
+ def assemble_linear_form_on_space(
144
+ self,
145
+ space: SurfaceSpaceLike,
146
+ form: SurfaceLinearForm[P],
147
+ params: P,
148
+ *,
149
+ F0: npt.ArrayLike | None = None,
150
+ ) -> np.ndarray:
113
151
  """
114
152
  Assemble surface linear form using global size inferred from a volume space.
115
153
  """
116
154
  dim = int(getattr(space, "value_dim", 1))
117
- n_total_nodes = int(getattr(space, "mesh", self).n_nodes)
155
+ mesh = cast(BaseMesh, getattr(space, "mesh", self))
156
+ n_total_nodes = int(mesh.n_nodes)
118
157
  return self.assemble_linear_form(form, params, dim=dim, n_total_nodes=n_total_nodes, F0=F0)
119
158
 
159
+
160
@dataclass(frozen=True)
class SurfaceWithElemConn:
    """
    Immutable pair of a surface mesh and the element connectivity that goes with it.

    Produced by ``surface_with_elem_conn``; ``frozen=True`` forbids reassigning
    either attribute after construction.
    """

    # Surface mesh built from the selected facets.
    surface: SurfaceMesh
    # Connectivity rows of the volume elements selected for those facets.
    elem_conn: np.ndarray
164
+
165
+
166
def surface_with_elem_conn(mesh: BaseMesh, facets, *, mode: str = "touching") -> SurfaceWithElemConn:
    """
    Build a SurfaceMesh from ``facets`` of ``mesh`` plus the matching element rows.

    The surface is ``SurfaceMesh.from_facets(mesh.coords, facets,
    node_tags=mesh.node_tags)``. The elements come from
    ``mesh.elements_from_facets(facets, mode=mode)``; their rows of ``mesh.conn``
    (converted to an int NumPy array) become ``elem_conn``.
    """
    # NOTE(review): in the 0.2.1 diff this function (new lines 166-173) sits
    # between SurfaceMesh.assemble_linear_form_on_space and
    # SurfaceMesh.assemble_traction. If it is at module level, the indented
    # `def assemble_traction` that follows would no longer belong to
    # SurfaceMesh. Confirm the intended placement.
    surf = SurfaceMesh.from_facets(mesh.coords, facets, node_tags=mesh.node_tags)
    elem_ids = mesh.elements_from_facets(facets, mode=mode)
    conn_np = np.asarray(mesh.conn, dtype=int)
    return SurfaceWithElemConn(surface=surf, elem_conn=conn_np[elem_ids])
174
+
120
175
  def assemble_traction(
121
176
  self,
122
- traction,
177
+ traction: float | Sequence[float],
123
178
  *,
124
179
  dim: int = 3,
125
180
  n_total_nodes: int | None = None,
126
- F0=None,
127
- outward_from=None,
128
- ):
181
+ F0: npt.ArrayLike | None = None,
182
+ outward_from: npt.ArrayLike | None = None,
183
+ ) -> np.ndarray:
129
184
  from ..solver.bc import assemble_surface_traction
130
185
  return assemble_surface_traction(
131
186
  self,
@@ -162,8 +217,8 @@ class SurfaceMeshPytree(BaseMeshPytree):
162
217
  node_tags: Optional[jnp.ndarray] = None,
163
218
  ) -> "SurfaceMeshPytree":
164
219
  coords_j = jnp.asarray(coords, dtype=DTYPE)
165
- facets_j = jnp.asarray(facets, dtype=jnp.int32)
166
- tags_j = None if facet_tags is None else jnp.asarray(facet_tags, dtype=jnp.int32)
220
+ facets_j = jnp.asarray(facets, dtype=INDEX_DTYPE)
221
+ tags_j = None if facet_tags is None else jnp.asarray(facet_tags, dtype=INDEX_DTYPE)
167
222
  node_tags_j = None if node_tags is None else jnp.asarray(node_tags)
168
223
  return cls(coords=coords_j, conn=facets_j, cell_tags=tags_j, node_tags=node_tags_j, facet_tags=tags_j)
169
224
 
@@ -204,31 +259,54 @@ class SurfaceMeshPytree(BaseMeshPytree):
204
259
  from ..solver.bc import facet_normals
205
260
  return facet_normals(self, outward_from=outward_from, normalize=normalize)
206
261
 
207
- def assemble_load(self, load, *, dim: int, n_total_nodes: int | None = None, F0=None):
262
+ def assemble_load(
263
+ self,
264
+ load: npt.ArrayLike,
265
+ *,
266
+ dim: int,
267
+ n_total_nodes: int | None = None,
268
+ F0: npt.ArrayLike | None = None,
269
+ ) -> np.ndarray:
208
270
  from ..solver.bc import assemble_surface_load
209
271
  return assemble_surface_load(self, load, dim=dim, n_total_nodes=n_total_nodes, F0=F0)
210
272
 
211
- def assemble_linear_form(self, form, params, *, dim: int, n_total_nodes: int | None = None, F0=None):
273
+ def assemble_linear_form(
274
+ self,
275
+ form: SurfaceLinearForm[P],
276
+ params: P,
277
+ *,
278
+ dim: int,
279
+ n_total_nodes: int | None = None,
280
+ F0: npt.ArrayLike | None = None,
281
+ ) -> np.ndarray:
212
282
  from ..solver.bc import assemble_surface_linear_form
213
283
  return assemble_surface_linear_form(self, form, params, dim=dim, n_total_nodes=n_total_nodes, F0=F0)
214
284
 
215
- def assemble_linear_form_on_space(self, space, form, params, *, F0=None):
285
+ def assemble_linear_form_on_space(
286
+ self,
287
+ space: SurfaceSpaceLike,
288
+ form: SurfaceLinearForm[P],
289
+ params: P,
290
+ *,
291
+ F0: npt.ArrayLike | None = None,
292
+ ) -> np.ndarray:
216
293
  """
217
294
  Assemble surface linear form using global size inferred from a volume space.
218
295
  """
219
296
  dim = int(getattr(space, "value_dim", 1))
220
- n_total_nodes = int(getattr(space, "mesh", self).n_nodes)
297
+ mesh = cast(BaseMesh, getattr(space, "mesh", self))
298
+ n_total_nodes = int(mesh.n_nodes)
221
299
  return self.assemble_linear_form(form, params, dim=dim, n_total_nodes=n_total_nodes, F0=F0)
222
300
 
223
301
  def assemble_traction(
224
302
  self,
225
- traction,
303
+ traction: float | Sequence[float],
226
304
  *,
227
305
  dim: int = 3,
228
306
  n_total_nodes: int | None = None,
229
- F0=None,
230
- outward_from=None,
231
- ):
307
+ F0: npt.ArrayLike | None = None,
308
+ outward_from: npt.ArrayLike | None = None,
309
+ ) -> np.ndarray:
232
310
  from ..solver.bc import assemble_surface_traction
233
311
  return assemble_surface_traction(
234
312
  self,
@@ -241,12 +319,12 @@ class SurfaceMeshPytree(BaseMeshPytree):
241
319
 
242
320
  def assemble_flux(
243
321
  self,
244
- flux,
322
+ flux: npt.ArrayLike,
245
323
  *,
246
324
  n_total_nodes: int | None = None,
247
- F0=None,
248
- outward_from=None,
249
- ):
325
+ F0: npt.ArrayLike | None = None,
326
+ outward_from: npt.ArrayLike | None = None,
327
+ ) -> np.ndarray:
250
328
  from ..solver.bc import assemble_surface_flux
251
329
  return assemble_surface_flux(
252
330
  self,
fluxfem/mesh/tet.py CHANGED
@@ -5,6 +5,8 @@ import jax
5
5
  import jax.numpy as jnp
6
6
  import numpy as np
7
7
 
8
+ from .dtypes import NP_INDEX_DTYPE
9
+
8
10
  DTYPE = jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32
9
11
 
10
12
 
@@ -58,6 +60,10 @@ class StructuredTetBox:
58
60
  J = np.stack([p1 - p0, p2 - p0, p3 - p0], axis=1)
59
61
  if np.linalg.det(J) < 0:
60
62
  tet[[1, 2]] = tet[[2, 1]] # swap corner1/corner2
63
+ if tet.shape[0] == 10:
64
+ # keep edge-node ordering consistent with corner swap
65
+ tet[[4, 6]] = tet[[6, 4]] # edges (0-1) <-> (0-2)
66
+ tet[[8, 9]] = tet[[9, 8]] # edges (1-3) <-> (2-3)
61
67
  conn[idx] = tet
62
68
  return conn
63
69
 
@@ -115,10 +121,10 @@ class StructuredTetBox:
115
121
  n12 = add_node(mid(p1, p2))
116
122
  n13 = add_node(mid(p1, p3))
117
123
  n23 = add_node(mid(p2, p3))
118
- conn_list.append([n0, n1, n2, n3, n01, n02, n03, n12, n13, n23])
124
+ conn_list.append([n0, n1, n2, n3, n01, n12, n02, n03, n13, n23])
119
125
 
120
126
  coords = np.asarray(coords_list, dtype=DTYPE)
121
- conn = np.asarray(conn_list, dtype=np.int32)
127
+ conn = np.asarray(conn_list, dtype=NP_INDEX_DTYPE)
122
128
  conn = self._fix_orientation(coords, conn)
123
129
  return TetMesh(coords=jnp.array(coords), conn=jnp.array(conn))
124
130
 
@@ -169,7 +175,7 @@ class StructuredTetBox:
169
175
  [v010, v001, v011, v111],
170
176
  ]
171
177
  )
172
- conn = np.asarray(conn_list, dtype=np.int32)
178
+ conn = np.asarray(conn_list, dtype=NP_INDEX_DTYPE)
173
179
  conn = self._fix_orientation(coords, conn)
174
180
  return TetMesh(coords=jnp.array(coords), conn=jnp.array(conn))
175
181
 
@@ -217,11 +223,14 @@ class StructuredTetTensorBox:
217
223
  npx = len(xs)
218
224
  npy = len(ys)
219
225
  npz = len(zs)
226
+ X: np.ndarray
227
+ Y: np.ndarray
228
+ Z: np.ndarray
220
229
  X, Y, Z = np.meshgrid(np.sort(xs), np.sort(ys), np.sort(zs))
221
230
  p = np.vstack((X.flatten("F"), Y.flatten("F"), Z.flatten("F")))
222
- ix = np.arange(npx * npy * npz)
231
+ ix: np.ndarray = np.arange(npx * npy * npz)
223
232
  ne = (npx - 1) * (npy - 1) * (npz - 1)
224
- t = np.zeros((8, ne), dtype=np.int64)
233
+ t: np.ndarray = np.zeros((8, ne), dtype=np.int64)
225
234
  ix = ix.reshape(npy, npx, npz, order="F").copy()
226
235
  t[0] = ix[0:(npy - 1), 0:(npx - 1), 0:(npz - 1)].reshape(ne, 1, order="F").copy().flatten()
227
236
  t[1] = ix[1:npy, 0:(npx - 1), 0:(npz - 1)].reshape(ne, 1, order="F").copy().flatten()
@@ -232,7 +241,7 @@ class StructuredTetTensorBox:
232
241
  t[6] = ix[0:(npy - 1), 1:npx, 1:npz].reshape(ne, 1, order="F").copy().flatten()
233
242
  t[7] = ix[1:npy, 1:npx, 1:npz].reshape(ne, 1, order="F").copy().flatten()
234
243
 
235
- T = np.zeros((4, 6 * ne), dtype=np.int64)
244
+ T: np.ndarray = np.zeros((4, 6 * ne), dtype=np.int64)
236
245
  T[:, :ne] = t[[0, 1, 5, 7]]
237
246
  T[:, ne:(2 * ne)] = t[[0, 1, 4, 7]]
238
247
  T[:, (2 * ne):(3 * ne)] = t[[0, 2, 4, 7]]
@@ -241,6 +250,6 @@ class StructuredTetTensorBox:
241
250
  T[:, (5 * ne):] = t[[0, 3, 6, 7]]
242
251
 
243
252
  coords = p.T.astype(DTYPE, copy=False)
244
- conn = T.T.astype(np.int32, copy=False)
253
+ conn: np.ndarray = T.T.astype(NP_INDEX_DTYPE, copy=False)
245
254
  conn = self._fix_orientation(coords, conn)
246
255
  return TetMesh(coords=jnp.array(coords), conn=jnp.array(conn))
@@ -15,4 +15,7 @@ def diffusion_form(ctx: FormContext, kappa: float) -> jnp.ndarray:
15
15
  return kappa * G
16
16
 
17
17
 
18
+ diffusion_form._ff_kind = "bilinear"
19
+ diffusion_form._ff_domain = "volume"
20
+
18
21
  __all__ = ["diffusion_form"]
@@ -1,11 +1,23 @@
1
+ from typing import Mapping, TYPE_CHECKING, TypeAlias
2
+
1
3
  import jax
2
4
  import jax.numpy as jnp
3
5
  import numpy as np
4
6
 
5
7
  from ...core.forms import FormContext
8
+ from ...core.space import FESpace
9
+ from ...mesh import BaseMesh
6
10
  from ...core.basis import build_B_matrices_finite
7
11
  from ..postprocess import make_point_data_displacement, write_point_data_vtu
8
12
 
13
+ if TYPE_CHECKING:
14
+ from jax import Array as JaxArray
15
+
16
+ ArrayLike: TypeAlias = np.ndarray | JaxArray
17
+ else:
18
+ ArrayLike: TypeAlias = np.ndarray
19
+ ParamsLike: TypeAlias = Mapping[str, float] | tuple[float, float]
20
+
9
21
 
10
22
  def right_cauchy_green(F: jnp.ndarray) -> jnp.ndarray:
11
23
  """C = F^T F (right Cauchy-Green)."""
@@ -46,7 +58,9 @@ def pk2_neo_hookean(F: jnp.ndarray, mu: float, lam: float) -> jnp.ndarray:
46
58
  return mu * (I - C_inv) + lam * jnp.log(J)[..., None, None] * C_inv
47
59
 
48
60
 
49
- def neo_hookean_residual_form(ctx: FormContext, u_elem: jnp.ndarray, params) -> jnp.ndarray:
61
+ def neo_hookean_residual_form(
62
+ ctx: FormContext, u_elem: jnp.ndarray, params: ParamsLike
63
+ ) -> jnp.ndarray:
50
64
  """
51
65
  Compressible Neo-Hookean residual (Total Lagrangian, PK2).
52
66
  params: dict-like with keys \"mu\", \"lam\" or tuple (mu, lam)
@@ -78,6 +92,9 @@ def neo_hookean_residual_form(ctx: FormContext, u_elem: jnp.ndarray, params) ->
78
92
  return jnp.einsum("qik,qk->qi", BT, S_voigt) # (n_q, n_ldofs)
79
93
 
80
94
 
95
+ neo_hookean_residual_form._ff_kind = "residual"
96
+ neo_hookean_residual_form._ff_domain = "volume"
97
+
81
98
  __all__ = [
82
99
  "right_cauchy_green",
83
100
  "green_lagrange_strain",
@@ -89,11 +106,26 @@ __all__ = [
89
106
  ]
90
107
 
91
108
 
92
- def make_elastic_point_data(mesh, space, u, *, compute_j: bool = True, deformed_scale: float = 1.0):
109
+ def make_elastic_point_data(
110
+ mesh: BaseMesh,
111
+ space: FESpace,
112
+ u: ArrayLike,
113
+ *,
114
+ compute_j: bool = True,
115
+ deformed_scale: float = 1.0,
116
+ ) -> dict[str, np.ndarray]:
93
117
  """Alias to postprocess.make_point_data_displacement for backward compatibility."""
94
118
  return make_point_data_displacement(mesh, space, u, compute_j=compute_j, deformed_scale=deformed_scale)
95
119
 
96
120
 
97
- def write_elastic_vtu(mesh, space, u, filepath: str, *, compute_j: bool = True, deformed_scale: float = 1.0):
121
+ def write_elastic_vtu(
122
+ mesh: BaseMesh,
123
+ space: FESpace,
124
+ u: ArrayLike,
125
+ filepath: str,
126
+ *,
127
+ compute_j: bool = True,
128
+ deformed_scale: float = 1.0,
129
+ ) -> None:
98
130
  """Alias to postprocess.write_point_data_vtu for backward compatibility."""
99
131
  return write_point_data_vtu(mesh, space, u, filepath, compute_j=compute_j, deformed_scale=deformed_scale)