fluxfem 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. fluxfem/__init__.py +136 -161
  2. fluxfem/core/__init__.py +172 -41
  3. fluxfem/core/assembly.py +676 -91
  4. fluxfem/core/basis.py +73 -52
  5. fluxfem/core/context_types.py +36 -0
  6. fluxfem/core/dtypes.py +9 -1
  7. fluxfem/core/forms.py +15 -1
  8. fluxfem/core/mixed_assembly.py +263 -0
  9. fluxfem/core/mixed_space.py +348 -0
  10. fluxfem/core/mixed_weakform.py +97 -0
  11. fluxfem/core/solver.py +2 -0
  12. fluxfem/core/space.py +262 -17
  13. fluxfem/core/weakform.py +1503 -312
  14. fluxfem/helpers_wf.py +53 -0
  15. fluxfem/mesh/__init__.py +54 -2
  16. fluxfem/mesh/base.py +322 -8
  17. fluxfem/mesh/contact.py +825 -0
  18. fluxfem/mesh/dtypes.py +12 -0
  19. fluxfem/mesh/hex.py +18 -16
  20. fluxfem/mesh/io.py +8 -4
  21. fluxfem/mesh/mortar.py +3907 -0
  22. fluxfem/mesh/supermesh.py +316 -0
  23. fluxfem/mesh/surface.py +22 -4
  24. fluxfem/mesh/tet.py +10 -4
  25. fluxfem/physics/diffusion.py +3 -0
  26. fluxfem/physics/elasticity/hyperelastic.py +3 -0
  27. fluxfem/physics/elasticity/linear.py +9 -2
  28. fluxfem/solver/__init__.py +42 -2
  29. fluxfem/solver/bc.py +38 -2
  30. fluxfem/solver/block_matrix.py +132 -0
  31. fluxfem/solver/block_system.py +454 -0
  32. fluxfem/solver/cg.py +115 -33
  33. fluxfem/solver/dirichlet.py +334 -4
  34. fluxfem/solver/newton.py +237 -60
  35. fluxfem/solver/petsc.py +439 -0
  36. fluxfem/solver/preconditioner.py +106 -0
  37. fluxfem/solver/result.py +18 -0
  38. fluxfem/solver/solve_runner.py +168 -1
  39. fluxfem/solver/solver.py +12 -1
  40. fluxfem/solver/sparse.py +124 -9
  41. fluxfem-0.2.0.dist-info/METADATA +303 -0
  42. fluxfem-0.2.0.dist-info/RECORD +59 -0
  43. fluxfem-0.1.3.dist-info/METADATA +0 -125
  44. fluxfem-0.1.3.dist-info/RECORD +0 -47
  45. {fluxfem-0.1.3.dist-info → fluxfem-0.2.0.dist-info}/LICENSE +0 -0
  46. {fluxfem-0.1.3.dist-info → fluxfem-0.2.0.dist-info}/WHEEL +0 -0
fluxfem/helpers_wf.py CHANGED
@@ -10,18 +10,66 @@ from .core.weakform import (
10
10
  inner,
11
11
  action,
12
12
  gaction,
13
+ outer,
13
14
  I,
14
15
  det,
15
16
  inv,
16
17
  transpose,
17
18
  transpose_last2,
18
19
  matmul,
20
+ matmul_std,
19
21
  log,
22
+ einsum,
20
23
  normal,
21
24
  ds,
22
25
  dOmega,
26
+ ParamRef,
23
27
  )
24
28
 
29
+
30
+ def _voigt_A() -> tuple[tuple[tuple[float, ...], ...], ...]:
31
+ return (
32
+ ((1.0, 0.0, 0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 0.5, 0.0, 0.0), (0.0, 0.0, 0.0, 0.0, 0.0, 0.5)),
33
+ ((0.0, 0.0, 0.0, 0.5, 0.0, 0.0), (0.0, 1.0, 0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 0.0, 0.5, 0.0)),
34
+ ((0.0, 0.0, 0.0, 0.0, 0.0, 0.5), (0.0, 0.0, 0.0, 0.0, 0.5, 0.0), (0.0, 0.0, 1.0, 0.0, 0.0, 0.0)),
35
+ )
36
+
37
+
38
+ def _identity3() -> tuple[tuple[float, float, float], ...]:
39
+ return (
40
+ (1.0, 0.0, 0.0),
41
+ (0.0, 1.0, 0.0),
42
+ (0.0, 0.0, 1.0),
43
+ )
44
+
45
+
46
def voigt_to_tensor(sym_grad_u, p=None):
    """
    Convert Voigt-form symmetric gradient to a 3x3 tensor.

    Uses the Voigt mapping from _voigt_A (component order xx, yy, zz, xy, yz,
    xz, with 1/2 on the shear terms). ``sym_grad_u`` is contracted as
    ``einsum("ijk,qk...->qij...")``: its second axis is the Voigt component,
    and its leading axis is carried through (presumably quadrature points).

    If ``p`` is given and is not a ParamRef, an ``A`` attribute on ``p``
    replaces the default mapping.
    """
    default_map = _voigt_A()
    # ParamRef instances are skipped for the override lookup (presumably
    # symbolic parameter references; confirm in core.weakform).
    is_concrete = p is not None and not isinstance(p, ParamRef)
    mapping = getattr(p, "A", default_map) if is_concrete else default_map
    return einsum("ijk,qk...->qij...", mapping, sym_grad_u)
56
+
57
+
58
def linear_stress(sym_grad_u, p):
    """
    Linear elastic stress from symmetric gradient in Voigt notation.

    Computes sigma = lam * tr(eps) * I + 2 * mu * eps, where eps is the 3x3
    strain from voigt_to_tensor(sym_grad_u, p). ``p`` must expose ``lam`` and
    ``mu``. Unless ``p`` is a ParamRef, an ``I`` attribute on ``p`` replaces the
    identity used for the trace and the volumetric term.
    """
    # Named ``eye`` so the imported ``I`` helper is not shadowed.
    eye = _identity3()
    if not isinstance(p, ParamRef):
        eye = getattr(p, "I", eye)
    strain = voigt_to_tensor(sym_grad_u, p)
    trace = einsum("ij,qij...->q...", eye, strain)
    volumetric = p.lam * einsum("q...,ij->qij...", trace, eye)
    return volumetric + 2.0 * p.mu * strain
66
+
67
+
68
def traction(field, n, p):
    """
    Traction vector for a field using linear elastic stress.

    Contracts the stress linear_stress(sym_grad(field), p) with ``n`` over the
    second tensor index (einsum "qij...,qj->qi...").
    """
    return einsum("qij...,qj->qi...", linear_stress(sym_grad(field), p), n)
72
+
25
73
  __all__ = [
26
74
  "grad",
27
75
  "sym_grad",
@@ -31,13 +79,18 @@ __all__ = [
31
79
  "inner",
32
80
  "action",
33
81
  "gaction",
82
+ "outer",
34
83
  "I",
35
84
  "det",
36
85
  "inv",
37
86
  "transpose",
38
87
  "transpose_last2",
39
88
  "matmul",
89
+ "matmul_std",
40
90
  "log",
91
+ "voigt_to_tensor",
92
+ "linear_stress",
93
+ "traction",
41
94
  "normal",
42
95
  "ds",
43
96
  "dOmega",
fluxfem/mesh/__init__.py CHANGED
@@ -1,13 +1,40 @@
1
+ from .base import BaseMesh, BaseMeshPytree, SurfaceWithFacetMap
1
2
  from .hex import HexMesh, HexMeshPytree, StructuredHexBox, tag_axis_minmax_facets
2
3
  from .tet import TetMesh, TetMeshPytree, StructuredTetBox, StructuredTetTensorBox
3
- from .base import BaseMesh, BaseMeshPytree
4
4
  from .predicate import bbox_predicate, plane_predicate, axis_plane_predicate, slab_predicate
5
- from .surface import SurfaceMesh, SurfaceMeshPytree
5
+ from .surface import SurfaceMesh, SurfaceMeshPytree, SurfaceWithElemConn, surface_with_elem_conn
6
+ from .supermesh import SurfaceSupermesh, build_surface_supermesh
7
+ from .mortar import (
8
+ MortarMatrix,
9
+ assemble_mortar_matrices,
10
+ assemble_contact_onesided_floor,
11
+ map_surface_facets_to_tet_elements,
12
+ map_surface_facets_to_hex_elements,
13
+ assemble_mixed_surface_jacobian,
14
+ assemble_mixed_surface_residual,
15
+ tri_area,
16
+ tri_quadrature,
17
+ facet_triangles,
18
+ facet_shape_values,
19
+ volume_shape_values_at_points,
20
+ quad_shape_and_local,
21
+ quad9_shape_values,
22
+ hex27_gradN,
23
+ )
24
+ from .contact import (
25
+ ContactSurfaceSpace,
26
+ ContactSide,
27
+ OneSidedContact,
28
+ OneSidedContactSurfaceSpace,
29
+ facet_gap_values,
30
+ active_contact_facets,
31
+ )
6
32
  from .io import load_gmsh_mesh, load_gmsh_hex_mesh, load_gmsh_tet_mesh, make_surface_from_facets
7
33
 
8
34
  __all__ = [
9
35
  "BaseMesh",
10
36
  "BaseMeshPytree",
37
+ "SurfaceWithFacetMap",
11
38
  "bbox_predicate",
12
39
  "plane_predicate",
13
40
  "axis_plane_predicate",
@@ -22,6 +49,31 @@ __all__ = [
22
49
  "StructuredTetTensorBox",
23
50
  "SurfaceMesh",
24
51
  "SurfaceMeshPytree",
52
+ "SurfaceWithElemConn",
53
+ "surface_with_elem_conn",
54
+ "SurfaceSupermesh",
55
+ "build_surface_supermesh",
56
+ "MortarMatrix",
57
+ "assemble_mortar_matrices",
58
+ "assemble_contact_onesided_floor",
59
+ "assemble_mixed_surface_residual",
60
+ "assemble_mixed_surface_jacobian",
61
+ "map_surface_facets_to_tet_elements",
62
+ "map_surface_facets_to_hex_elements",
63
+ "tri_area",
64
+ "tri_quadrature",
65
+ "facet_triangles",
66
+ "facet_shape_values",
67
+ "volume_shape_values_at_points",
68
+ "quad_shape_and_local",
69
+ "quad9_shape_values",
70
+ "hex27_gradN",
71
+ "ContactSurfaceSpace",
72
+ "ContactSide",
73
+ "OneSidedContact",
74
+ "OneSidedContactSurfaceSpace",
75
+ "facet_gap_values",
76
+ "active_contact_facets",
25
77
  "load_gmsh_mesh",
26
78
  "load_gmsh_hex_mesh",
27
79
  "load_gmsh_tet_mesh",
fluxfem/mesh/base.py CHANGED
@@ -5,9 +5,16 @@ import numpy as np
5
5
  import jax
6
6
  import jax.numpy as jnp
7
7
 
8
+ from .dtypes import INDEX_DTYPE, NP_INDEX_DTYPE
9
+
8
10
 
9
11
  @dataclass
10
12
  class BaseMeshClosure:
13
+ """
14
+ Base mesh container with coordinates, connectivity, and optional tags.
15
+
16
+ Concrete mesh types should implement face_node_patterns() for boundary queries.
17
+ """
11
18
  coords: jnp.ndarray
12
19
  conn: jnp.ndarray
13
20
  cell_tags: Optional[jnp.ndarray] = None
@@ -173,12 +180,12 @@ class BaseMeshClosure:
173
180
  def make_node_tags(self, predicate: Callable[[np.ndarray], np.ndarray], tag: int, base: Optional[np.ndarray] = None) -> jnp.ndarray:
174
181
  """
175
182
  Build a node_tags array by applying predicate to coords and setting tag where True.
176
- Returns a jnp.ndarray (int32). Does not mutate the mesh.
183
+ Returns a jnp.ndarray (int64). Does not mutate the mesh.
177
184
  """
178
- base_tags = np.zeros(self.n_nodes, dtype=np.int32) if base is None else np.asarray(base, dtype=np.int32).copy()
185
+ base_tags = np.zeros(self.n_nodes, dtype=NP_INDEX_DTYPE) if base is None else np.asarray(base, dtype=NP_INDEX_DTYPE).copy()
179
186
  mask = predicate(np.asarray(self.coords))
180
187
  base_tags[mask] = int(tag)
181
- return jnp.asarray(base_tags, dtype=jnp.int32)
188
+ return jnp.asarray(base_tags, dtype=INDEX_DTYPE)
182
189
 
183
190
  def with_node_tags(self, node_tags: np.ndarray | jnp.ndarray):
184
191
  """
@@ -212,8 +219,8 @@ class BaseMeshClosure:
212
219
 
213
220
  if not facet_map:
214
221
  if tag is None:
215
- return jnp.empty((0, len(patterns[0]) if patterns else 0), dtype=jnp.int32)
216
- return jnp.empty((0, len(patterns[0]) if patterns else 0), dtype=jnp.int32), jnp.empty((0,), dtype=jnp.int32)
222
+ return jnp.empty((0, len(patterns[0]) if patterns else 0), dtype=INDEX_DTYPE)
223
+ return jnp.empty((0, len(patterns[0]) if patterns else 0), dtype=INDEX_DTYPE), jnp.empty((0,), dtype=INDEX_DTYPE)
217
224
 
218
225
  facets = []
219
226
  tags = []
@@ -222,14 +229,318 @@ class BaseMeshClosure:
222
229
  if tag is not None:
223
230
  tags.append(t if t is not None else 0)
224
231
 
225
- facets_arr = jnp.array(facets, dtype=jnp.int32)
232
+ facets_arr = jnp.array(facets, dtype=INDEX_DTYPE)
226
233
  if tag is None:
227
234
  return facets_arr
228
- return facets_arr, jnp.array(tags, dtype=jnp.int32)
235
+ return facets_arr, jnp.array(tags, dtype=INDEX_DTYPE)
236
+
237
+ def boundary_facets_plane(
238
+ self,
239
+ axis: int = 2,
240
+ value: float = 0.0,
241
+ *,
242
+ tol: float = 1e-8,
243
+ tag: int | None = None,
244
+ ):
245
+ """
246
+ Boundary facets on the plane x[axis] == value (within tol).
247
+ """
248
+ def pred(face: np.ndarray) -> bool:
249
+ return bool(np.allclose(face[:, axis], value, atol=tol))
250
+ return self.boundary_facets_where(pred, tag=tag)
251
+
252
+ def facets_on_plane(
253
+ self,
254
+ axis: int = 2,
255
+ value: float = 0.0,
256
+ *,
257
+ tol: float = 1e-8,
258
+ tag: int | None = None,
259
+ ):
260
+ """Alias for boundary_facets_plane (skfem-like naming)."""
261
+ cache = getattr(self, "_boundary_facets_cache", None)
262
+ key = ("plane", int(axis), float(value), float(tol), int(tag) if tag is not None else None)
263
+ if cache is not None and key in cache:
264
+ return cache[key]
265
+ coords = np.asarray(self.coords)
266
+ conn = np.asarray(self.conn)
267
+ patterns = self.face_node_patterns()
268
+ if patterns:
269
+ facets_list = []
270
+ for pattern in patterns:
271
+ face_nodes = conn[:, pattern]
272
+ face_coords = coords[face_nodes]
273
+ mask = np.all(np.isclose(face_coords[..., axis], value, atol=tol), axis=1)
274
+ if np.any(mask):
275
+ facets_list.append(face_nodes[mask])
276
+ if facets_list:
277
+ facets = np.concatenate(facets_list, axis=0)
278
+ keys = np.sort(facets, axis=1)
279
+ _, idx = np.unique(keys, axis=0, return_index=True)
280
+ facets = facets[np.sort(idx)]
281
+ else:
282
+ facets = np.empty((0, len(patterns[0])), dtype=int)
283
+ facets = jnp.asarray(facets, dtype=INDEX_DTYPE)
284
+ if tag is not None:
285
+ tags = jnp.full((facets.shape[0],), int(tag), dtype=INDEX_DTYPE)
286
+ facets = (facets, tags)
287
+ else:
288
+ facets = self.boundary_facets_plane(axis=axis, value=value, tol=tol, tag=tag)
289
+ if cache is None:
290
+ cache = {}
291
+ setattr(self, "_boundary_facets_cache", cache)
292
+ cache[key] = facets
293
+ return facets
294
+
295
+ def boundary_facets_plane_box(
296
+ self,
297
+ axis: int,
298
+ value: float,
299
+ *,
300
+ ranges: Sequence[tuple[float, float] | None] | None = None,
301
+ mode: str = "centroid",
302
+ tol: float = 1e-8,
303
+ tag: int | None = None,
304
+ ):
305
+ """
306
+ Boundary facets on a plane with additional box constraints.
307
+
308
+ ranges: sequence of (min, max) or None per axis. The plane axis can be None.
309
+ mode: "centroid" checks the face centroid, "all" requires all vertices inside.
310
+ """
311
+ dim = int(np.asarray(self.coords).shape[1])
312
+ if ranges is None:
313
+ ranges = [None] * dim
314
+ if len(ranges) != dim:
315
+ raise ValueError("ranges must have length equal to mesh dimension.")
316
+ if mode not in ("centroid", "all"):
317
+ raise ValueError("mode must be 'centroid' or 'all'.")
318
+
319
+ def pred(face: np.ndarray) -> bool:
320
+ if not np.allclose(face[:, axis], value, atol=tol):
321
+ return False
322
+ pts = face.mean(axis=0)[None, :] if mode == "centroid" else face
323
+ for ax, bounds in enumerate(ranges):
324
+ if bounds is None or ax == axis:
325
+ continue
326
+ lo, hi = bounds
327
+ if np.any(pts[:, ax] < lo - tol) or np.any(pts[:, ax] > hi + tol):
328
+ return False
329
+ return True
330
+
331
+ return self.boundary_facets_where(pred, tag=tag)
332
+
333
+ def facets_on_plane_box(
334
+ self,
335
+ axis: int,
336
+ value: float,
337
+ *,
338
+ x: tuple[float, float] | None = None,
339
+ y: tuple[float, float] | None = None,
340
+ z: tuple[float, float] | None = None,
341
+ ranges: Sequence[tuple[float, float] | None] | None = None,
342
+ mode: str = "centroid",
343
+ tol: float = 1e-8,
344
+ tag: int | None = None,
345
+ ):
346
+ """
347
+ Alias for boundary_facets_plane_box with axis-aligned range helpers.
348
+ Provide x/y/z or a full ranges sequence.
349
+ """
350
+ dim = int(np.asarray(self.coords).shape[1])
351
+ if ranges is None:
352
+ ranges = [None] * dim
353
+ if dim > 0:
354
+ ranges[0] = x
355
+ if dim > 1:
356
+ ranges[1] = y
357
+ if dim > 2:
358
+ ranges[2] = z
359
+ cache = getattr(self, "_boundary_facets_cache", None)
360
+ ranges_key = tuple(ranges)
361
+ key = (
362
+ "plane_box",
363
+ int(axis),
364
+ float(value),
365
+ ranges_key,
366
+ str(mode),
367
+ float(tol),
368
+ int(tag) if tag is not None else None,
369
+ )
370
+ if cache is not None and key in cache:
371
+ return cache[key]
372
+ coords = np.asarray(self.coords)
373
+ conn = np.asarray(self.conn)
374
+ patterns = self.face_node_patterns()
375
+ if patterns:
376
+ facets_list = []
377
+ for pattern in patterns:
378
+ face_nodes = conn[:, pattern]
379
+ face_coords = coords[face_nodes]
380
+ mask = np.all(np.isclose(face_coords[..., axis], value, atol=tol), axis=1)
381
+ if np.any(mask):
382
+ if mode == "centroid":
383
+ pts = face_coords[mask].mean(axis=1)
384
+ mask_local = np.ones(pts.shape[0], dtype=bool)
385
+ for ax, bounds in enumerate(ranges):
386
+ if bounds is None or ax == axis:
387
+ continue
388
+ lo, hi = bounds
389
+ mask_local &= (pts[:, ax] >= lo - tol) & (pts[:, ax] <= hi + tol)
390
+ face_nodes = face_nodes[mask][mask_local]
391
+ else:
392
+ face_coords = face_coords[mask]
393
+ mask_local = np.ones(face_coords.shape[0], dtype=bool)
394
+ for ax, bounds in enumerate(ranges):
395
+ if bounds is None or ax == axis:
396
+ continue
397
+ lo, hi = bounds
398
+ in_range = (face_coords[..., ax] >= lo - tol) & (face_coords[..., ax] <= hi + tol)
399
+ mask_local &= np.all(in_range, axis=1)
400
+ face_nodes = face_nodes[mask][mask_local]
401
+ if face_nodes.size:
402
+ facets_list.append(face_nodes)
403
+ if facets_list:
404
+ facets = np.concatenate(facets_list, axis=0)
405
+ keys = np.sort(facets, axis=1)
406
+ _, idx = np.unique(keys, axis=0, return_index=True)
407
+ facets = facets[np.sort(idx)]
408
+ else:
409
+ facets = np.empty((0, len(patterns[0])), dtype=int)
410
+ facets = jnp.asarray(facets, dtype=INDEX_DTYPE)
411
+ if tag is not None:
412
+ tags = jnp.full((facets.shape[0],), int(tag), dtype=INDEX_DTYPE)
413
+ facets = (facets, tags)
414
+ else:
415
+ facets = self.boundary_facets_plane_box(
416
+ axis=axis,
417
+ value=value,
418
+ ranges=ranges,
419
+ mode=mode,
420
+ tol=tol,
421
+ tag=tag,
422
+ )
423
+ if cache is None:
424
+ cache = {}
425
+ setattr(self, "_boundary_facets_cache", cache)
426
+ cache[key] = facets
427
+ return facets
428
+
429
+ def boundary_dofs_plane(
430
+ self,
431
+ axis: int = 2,
432
+ value: float = 0.0,
433
+ *,
434
+ components: Sequence[int] | str = "xyz",
435
+ dof_per_node: Optional[int] = None,
436
+ tol: float = 1e-8,
437
+ ) -> np.ndarray:
438
+ """
439
+ DOF indices for boundary nodes on the plane x[axis] == value (within tol).
440
+ """
441
+ def pred(coords: np.ndarray) -> np.ndarray:
442
+ return np.isclose(coords[:, axis], value, atol=tol)
443
+ return self.boundary_dofs_where(pred, components=components, dof_per_node=dof_per_node)
444
+
445
def elements_touching_nodes(self, nodes: Iterable[int]) -> np.ndarray:
    """
    Return element indices that touch any node in the provided set.

    An element is selected when any entry of its connectivity row is in
    ``nodes``. Indices are returned in ascending order, each at most once.
    An empty ``nodes`` yields an empty int array.
    """
    wanted = np.asarray(list(nodes), dtype=int)
    if not wanted.size:
        return np.asarray([], dtype=int)
    is_target = np.zeros(self.n_nodes, dtype=bool)
    is_target[wanted] = True
    hits = is_target[np.asarray(self.conn)].any(axis=1)
    return np.flatnonzero(hits)
456
+
457
+ def nodes_from_facets(self, facets: np.ndarray | jnp.ndarray) -> np.ndarray:
458
+ """
459
+ Return unique node indices contained in the given facets array.
460
+ """
461
+ facets_arr = np.asarray(facets, dtype=int)
462
+ if facets_arr.size == 0:
463
+ return np.asarray([], dtype=int)
464
+ return np.unique(facets_arr.reshape(-1))
465
+
466
def elements_from_nodes(self, nodes: Iterable[int]) -> np.ndarray:
    """
    Return indices of elements touching any of ``nodes``.

    skfem-style name; forwards unchanged to elements_touching_nodes.
    """
    lookup = self.elements_touching_nodes
    return lookup(nodes)
471
+
472
+ def elements_from_facets(self, facets: np.ndarray | jnp.ndarray, *, mode: str = "touching") -> np.ndarray:
473
+ """
474
+ Return element indices associated with facets.
475
+
476
+ mode:
477
+ - "touching": any element sharing at least one facet node.
478
+ - "adjacent": elements that own the facet (exact face match).
479
+ """
480
+ if mode not in ("touching", "adjacent"):
481
+ raise ValueError("mode must be 'touching' or 'adjacent'.")
482
+ facets_arr = np.asarray(facets, dtype=int)
483
+ if facets_arr.size == 0:
484
+ return np.asarray([], dtype=int)
485
+ if mode == "touching":
486
+ nodes = self.nodes_from_facets(facets_arr)
487
+ return self.elements_touching_nodes(nodes)
488
+
489
+ patterns = self.face_node_patterns()
490
+ facet_keys = {tuple(sorted(face)) for face in facets_arr}
491
+ conn = np.asarray(self.conn)
492
+ elems = []
493
+ for e_idx, elem_conn in enumerate(conn):
494
+ for pattern in patterns:
495
+ face_nodes = tuple(sorted(int(elem_conn[i]) for i in pattern))
496
+ if face_nodes in facet_keys:
497
+ elems.append(e_idx)
498
+ break
499
+ return np.asarray(elems, dtype=int)
500
+
501
+ def elements_touching_facets(self, facets: np.ndarray | jnp.ndarray) -> np.ndarray:
502
+ """
503
+ Return element indices that touch any node in the provided facets.
504
+ """
505
+ facets_arr = np.asarray(facets, dtype=int)
506
+ if facets_arr.size == 0:
507
+ return np.asarray([], dtype=int)
508
+ nodes = np.unique(facets_arr.reshape(-1))
509
+ return self.elements_touching_nodes(nodes)
510
+
511
def surface_from_facets(self, facets, *, facet_tags=None):
    """
    Build a SurfaceMesh from facet connectivity.

    This mesh's full coordinate array and node_tags are forwarded to
    SurfaceMesh.from_facets, together with ``facets`` and ``facet_tags``.
    """
    from .surface import SurfaceMesh as _Surface

    shared_coords = self.coords
    shared_node_tags = self.node_tags
    return _Surface.from_facets(shared_coords, facets, facet_tags=facet_tags, node_tags=shared_node_tags)
517
+
518
def surface_with_elem_conn_from_facets(self, facets, *, mode: str = "touching"):
    """
    Build a SurfaceMesh and matching elem_conn for the given facets.

    Thin wrapper that forwards this mesh, ``facets`` and ``mode`` to
    surface_with_elem_conn in the sibling ``surface`` module.
    """
    from . import surface as _surface_mod

    return _surface_mod.surface_with_elem_conn(self, facets, mode=mode)
524
+
525
def surface_with_facet_map_from_facets(self, facets):
    """
    Build a SurfaceMesh and facet-to-element map for the given facets.

    The element family is chosen from the number of nodes per element:
    4/10 use the tet mapper, and 8/20/27 use the hex mapper, both from the
    mortar module.

    Raises NotImplementedError for any other connectivity width.
    """
    surface = self.surface_from_facets(facets)
    conn = np.asarray(self.conn, dtype=int)
    from .mortar import map_surface_facets_to_tet_elements, map_surface_facets_to_hex_elements

    nodes_per_elem = conn.shape[1]
    if nodes_per_elem in (4, 10):
        mapper = map_surface_facets_to_tet_elements
    elif nodes_per_elem in (8, 20, 27):
        mapper = map_surface_facets_to_hex_elements
    else:
        raise NotImplementedError("elem_conn must be tet/hex (4/10/8/20/27)")
    return SurfaceWithFacetMap(surface=surface, facet_map=mapper(surface, conn))
229
539
 
230
540
 
231
541
  @jax.tree_util.register_pytree_node_class
232
542
  class BaseMeshPytree(BaseMeshClosure):
543
+ """BaseMesh variant that registers as a JAX pytree."""
233
544
  def tree_flatten(self):
234
545
  children = (self.coords, self.conn, self.cell_tags, self.node_tags)
235
546
  return children, {}
@@ -241,4 +552,7 @@ class BaseMeshPytree(BaseMeshClosure):
241
552
 
242
553
 
243
554
  # Default public mesh name: the plain container, which is not registered as a
  # JAX pytree (BaseMeshPytree is the registered variant).
  BaseMesh = BaseMeshClosure
244
-
555
@dataclass(frozen=True, eq=False)
class SurfaceWithFacetMap:
    """
    Immutable pairing of a surface mesh with its facet-to-element map.

    Produced by BaseMeshClosure.surface_with_facet_map_from_facets.

    Equality and hashing are identity-based (eq=False). The dataclass-generated
    versions would break on the ndarray field: ``==`` would compare the
    ``facet_map`` arrays element-wise (typically raising "truth value of an
    array is ambiguous"), and ``hash`` would raise because NumPy arrays are
    unhashable.
    """

    # Surface mesh built from the selected facets (a SurfaceMesh when created
    # via surface_with_facet_map_from_facets).
    surface: object
    # Facet-to-element map as returned by the mortar helpers
    # map_surface_facets_to_{tet,hex}_elements.
    facet_map: np.ndarray