fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. fluxfem/__init__.py +69 -13
  2. fluxfem/core/__init__.py +140 -53
  3. fluxfem/core/assembly.py +691 -97
  4. fluxfem/core/basis.py +75 -54
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/dtypes.py +9 -1
  7. fluxfem/core/forms.py +10 -0
  8. fluxfem/core/mixed_assembly.py +263 -0
  9. fluxfem/core/mixed_space.py +382 -0
  10. fluxfem/core/mixed_weakform.py +97 -0
  11. fluxfem/core/solver.py +2 -0
  12. fluxfem/core/space.py +315 -30
  13. fluxfem/core/weakform.py +821 -42
  14. fluxfem/helpers_wf.py +49 -0
  15. fluxfem/mesh/__init__.py +54 -2
  16. fluxfem/mesh/base.py +318 -9
  17. fluxfem/mesh/contact.py +841 -0
  18. fluxfem/mesh/dtypes.py +12 -0
  19. fluxfem/mesh/hex.py +17 -16
  20. fluxfem/mesh/io.py +9 -6
  21. fluxfem/mesh/mortar.py +3970 -0
  22. fluxfem/mesh/supermesh.py +318 -0
  23. fluxfem/mesh/surface.py +104 -26
  24. fluxfem/mesh/tet.py +16 -7
  25. fluxfem/physics/diffusion.py +3 -0
  26. fluxfem/physics/elasticity/hyperelastic.py +35 -3
  27. fluxfem/physics/elasticity/linear.py +22 -4
  28. fluxfem/physics/elasticity/stress.py +9 -5
  29. fluxfem/physics/operators.py +12 -5
  30. fluxfem/physics/postprocess.py +29 -3
  31. fluxfem/solver/__init__.py +47 -2
  32. fluxfem/solver/bc.py +38 -2
  33. fluxfem/solver/block_matrix.py +284 -0
  34. fluxfem/solver/block_system.py +477 -0
  35. fluxfem/solver/cg.py +150 -55
  36. fluxfem/solver/dirichlet.py +358 -5
  37. fluxfem/solver/history.py +15 -3
  38. fluxfem/solver/newton.py +260 -70
  39. fluxfem/solver/petsc.py +445 -0
  40. fluxfem/solver/preconditioner.py +109 -0
  41. fluxfem/solver/result.py +18 -0
  42. fluxfem/solver/solve_runner.py +208 -23
  43. fluxfem/solver/solver.py +35 -12
  44. fluxfem/solver/sparse.py +149 -15
  45. fluxfem/tools/jit.py +19 -7
  46. fluxfem/tools/timer.py +14 -12
  47. fluxfem/tools/visualizer.py +16 -4
  48. fluxfem-0.2.1.dist-info/METADATA +314 -0
  49. fluxfem-0.2.1.dist-info/RECORD +59 -0
  50. fluxfem-0.1.4.dist-info/METADATA +0 -127
  51. fluxfem-0.1.4.dist-info/RECORD +0 -48
  52. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  53. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/helpers_wf.py CHANGED
@@ -19,11 +19,57 @@ from .core.weakform import (
19
19
  matmul,
20
20
  matmul_std,
21
21
  log,
22
+ einsum,
22
23
  normal,
23
24
  ds,
24
25
  dOmega,
26
+ ParamRef,
25
27
  )
26
28
 
29
+
30
+ def _voigt_A() -> tuple[tuple[tuple[float, ...], ...], ...]:
31
+ return (
32
+ ((1.0, 0.0, 0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 0.5, 0.0, 0.0), (0.0, 0.0, 0.0, 0.0, 0.0, 0.5)),
33
+ ((0.0, 0.0, 0.0, 0.5, 0.0, 0.0), (0.0, 1.0, 0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 0.0, 0.5, 0.0)),
34
+ ((0.0, 0.0, 0.0, 0.0, 0.0, 0.5), (0.0, 0.0, 0.0, 0.0, 0.5, 0.0), (0.0, 0.0, 1.0, 0.0, 0.0, 0.0)),
35
+ )
36
+
37
+
38
+ def _identity3() -> tuple[tuple[float, float, float], ...]:
39
+ return (
40
+ (1.0, 0.0, 0.0),
41
+ (0.0, 1.0, 0.0),
42
+ (0.0, 0.0, 1.0),
43
+ )
44
+
45
+
46
def voigt_to_tensor(sym_grad_u, p=None):
    """
    Expand a Voigt-form symmetric gradient into a 3x3 tensor.

    The expansion table defaults to ``_voigt_A()``, which uses the standard
    Voigt mapping with 1/2 on the shear terms. When ``p`` is given and is not
    a ``ParamRef``, an ``A`` attribute on ``p`` (if present) replaces it.

    The contraction is ``"ijk,qk...->qij..."``, so the Voigt components are
    read from axis 1 of ``sym_grad_u``.
    """
    default_table = _voigt_A()
    use_param_table = p is not None and not isinstance(p, ParamRef)
    table = getattr(p, "A", default_table) if use_param_table else default_table
    return einsum("ijk,qk...->qij...", table, sym_grad_u)
56
+
57
+
58
def linear_stress(sym_grad_u, p):
    """
    Linear elastic stress from a symmetric gradient in Voigt notation.

    Computes ``lam * tr(eps) * I + 2 * mu * eps``, where ``eps`` comes from
    ``voigt_to_tensor(sym_grad_u, p)``. ``p`` must provide ``lam`` and ``mu``.
    When ``p`` is not a ``ParamRef``, an ``I`` attribute on ``p`` (if present)
    replaces the default 3x3 identity.
    """
    if isinstance(p, ParamRef):
        identity = _identity3()
    else:
        identity = getattr(p, "I", _identity3())
    strain = voigt_to_tensor(sym_grad_u, p)
    trace = einsum("ij,qij...->q...", identity, strain)
    volumetric = p.lam * einsum("q...,ij->qij...", trace, identity)
    return volumetric + 2.0 * p.mu * strain
66
+
67
+
68
def traction(field, n, p):
    """
    Traction vector ``sigma . n`` for ``field`` using linear elastic stress.

    The stress comes from ``linear_stress(sym_grad(field), p)``. It is then
    contracted with ``n`` over its second tensor index (``"qij...,qj->qi..."``).
    """
    sigma = linear_stress(sym_grad(field), p)
    return einsum("qij...,qj->qi...", sigma, n)
72
+
27
73
  __all__ = [
28
74
  "grad",
29
75
  "sym_grad",
@@ -42,6 +88,9 @@ __all__ = [
42
88
  "matmul",
43
89
  "matmul_std",
44
90
  "log",
91
+ "voigt_to_tensor",
92
+ "linear_stress",
93
+ "traction",
45
94
  "normal",
46
95
  "ds",
47
96
  "dOmega",
fluxfem/mesh/__init__.py CHANGED
@@ -1,13 +1,40 @@
1
+ from .base import BaseMesh, BaseMeshPytree, SurfaceWithFacetMap
1
2
  from .hex import HexMesh, HexMeshPytree, StructuredHexBox, tag_axis_minmax_facets
2
3
  from .tet import TetMesh, TetMeshPytree, StructuredTetBox, StructuredTetTensorBox
3
- from .base import BaseMesh, BaseMeshPytree
4
4
  from .predicate import bbox_predicate, plane_predicate, axis_plane_predicate, slab_predicate
5
- from .surface import SurfaceMesh, SurfaceMeshPytree
5
+ from .surface import SurfaceMesh, SurfaceMeshPytree, SurfaceWithElemConn, surface_with_elem_conn
6
+ from .supermesh import SurfaceSupermesh, build_surface_supermesh
7
+ from .mortar import (
8
+ MortarMatrix,
9
+ assemble_mortar_matrices,
10
+ assemble_contact_onesided_floor,
11
+ map_surface_facets_to_tet_elements,
12
+ map_surface_facets_to_hex_elements,
13
+ assemble_mixed_surface_jacobian,
14
+ assemble_mixed_surface_residual,
15
+ tri_area,
16
+ tri_quadrature,
17
+ facet_triangles,
18
+ facet_shape_values,
19
+ volume_shape_values_at_points,
20
+ quad_shape_and_local,
21
+ quad9_shape_values,
22
+ hex27_gradN,
23
+ )
24
+ from .contact import (
25
+ ContactSurfaceSpace,
26
+ ContactSide,
27
+ OneSidedContact,
28
+ OneSidedContactSurfaceSpace,
29
+ facet_gap_values,
30
+ active_contact_facets,
31
+ )
6
32
  from .io import load_gmsh_mesh, load_gmsh_hex_mesh, load_gmsh_tet_mesh, make_surface_from_facets
7
33
 
8
34
  __all__ = [
9
35
  "BaseMesh",
10
36
  "BaseMeshPytree",
37
+ "SurfaceWithFacetMap",
11
38
  "bbox_predicate",
12
39
  "plane_predicate",
13
40
  "axis_plane_predicate",
@@ -22,6 +49,31 @@ __all__ = [
22
49
  "StructuredTetTensorBox",
23
50
  "SurfaceMesh",
24
51
  "SurfaceMeshPytree",
52
+ "SurfaceWithElemConn",
53
+ "surface_with_elem_conn",
54
+ "SurfaceSupermesh",
55
+ "build_surface_supermesh",
56
+ "MortarMatrix",
57
+ "assemble_mortar_matrices",
58
+ "assemble_contact_onesided_floor",
59
+ "assemble_mixed_surface_residual",
60
+ "assemble_mixed_surface_jacobian",
61
+ "map_surface_facets_to_tet_elements",
62
+ "map_surface_facets_to_hex_elements",
63
+ "tri_area",
64
+ "tri_quadrature",
65
+ "facet_triangles",
66
+ "facet_shape_values",
67
+ "volume_shape_values_at_points",
68
+ "quad_shape_and_local",
69
+ "quad9_shape_values",
70
+ "hex27_gradN",
71
+ "ContactSurfaceSpace",
72
+ "ContactSide",
73
+ "OneSidedContact",
74
+ "OneSidedContactSurfaceSpace",
75
+ "facet_gap_values",
76
+ "active_contact_facets",
25
77
  "load_gmsh_mesh",
26
78
  "load_gmsh_hex_mesh",
27
79
  "load_gmsh_tet_mesh",
fluxfem/mesh/base.py CHANGED
@@ -5,6 +5,8 @@ import numpy as np
5
5
  import jax
6
6
  import jax.numpy as jnp
7
7
 
8
+ from .dtypes import INDEX_DTYPE, NP_INDEX_DTYPE
9
+
8
10
 
9
11
  @dataclass
10
12
  class BaseMeshClosure:
@@ -158,7 +160,7 @@ class BaseMeshClosure:
158
160
  for pattern in patterns:
159
161
  nodes = tuple(sorted(int(elem_conn[i]) for i in pattern))
160
162
  face_counts[nodes] = face_counts.get(nodes, 0) + 1
161
- bnodes = set()
163
+ bnodes: set[int] = set()
162
164
  for nodes, count in face_counts.items():
163
165
  if count == 1:
164
166
  bnodes.update(nodes)
@@ -170,7 +172,7 @@ class BaseMeshClosure:
170
172
  """
171
173
  Return boolean mask for boundary nodes (shape: n_nodes).
172
174
  """
173
- mask = np.zeros(self.n_nodes, dtype=bool)
175
+ mask: np.ndarray = np.zeros(self.n_nodes, dtype=bool)
174
176
  nodes = self.boundary_node_indices()
175
177
  mask[nodes] = True
176
178
  return mask
@@ -178,12 +180,12 @@ class BaseMeshClosure:
178
180
def make_node_tags(self, predicate: Callable[[np.ndarray], np.ndarray], tag: int, base: Optional[np.ndarray] = None) -> jnp.ndarray:
    """
    Build a node_tags array by applying predicate to coords and setting tag where True.

    Parameters
    ----------
    predicate : callable
        Called with ``np.asarray(self.coords)``. Its result indexes the tag
        array, so it should be a boolean mask over the nodes.
    tag : int
        Value written at the selected nodes (converted with ``int``).
    base : array-like, optional
        Initial tags. The array is copied, so the caller's array is not
        modified. Defaults to zeros of length ``self.n_nodes``.

    Returns
    -------
    jnp.ndarray
        Tag array with dtype ``INDEX_DTYPE`` (the package index dtype from
        ``.dtypes``). Does not mutate the mesh.
    """
    if base is None:
        base_tags = np.zeros(self.n_nodes, dtype=NP_INDEX_DTYPE)
    else:
        # Copy so the caller's array is never modified in place.
        base_tags = np.asarray(base, dtype=NP_INDEX_DTYPE).copy()
    mask = predicate(np.asarray(self.coords))
    base_tags[mask] = int(tag)
    return jnp.asarray(base_tags, dtype=INDEX_DTYPE)
187
189
 
188
190
  def with_node_tags(self, node_tags: np.ndarray | jnp.ndarray):
189
191
  """
@@ -217,8 +219,8 @@ class BaseMeshClosure:
217
219
 
218
220
  if not facet_map:
219
221
  if tag is None:
220
- return jnp.empty((0, len(patterns[0]) if patterns else 0), dtype=jnp.int32)
221
- return jnp.empty((0, len(patterns[0]) if patterns else 0), dtype=jnp.int32), jnp.empty((0,), dtype=jnp.int32)
222
+ return jnp.empty((0, len(patterns[0]) if patterns else 0), dtype=INDEX_DTYPE)
223
+ return jnp.empty((0, len(patterns[0]) if patterns else 0), dtype=INDEX_DTYPE), jnp.empty((0,), dtype=INDEX_DTYPE)
222
224
 
223
225
  facets = []
224
226
  tags = []
@@ -227,10 +229,313 @@ class BaseMeshClosure:
227
229
  if tag is not None:
228
230
  tags.append(t if t is not None else 0)
229
231
 
230
- facets_arr = jnp.array(facets, dtype=jnp.int32)
232
+ facets_arr = jnp.array(facets, dtype=INDEX_DTYPE)
231
233
  if tag is None:
232
234
  return facets_arr
233
- return facets_arr, jnp.array(tags, dtype=jnp.int32)
235
+ return facets_arr, jnp.array(tags, dtype=INDEX_DTYPE)
236
+
237
+ def boundary_facets_plane(
238
+ self,
239
+ axis: int = 2,
240
+ value: float = 0.0,
241
+ *,
242
+ tol: float = 1e-8,
243
+ tag: int | None = None,
244
+ ):
245
+ """
246
+ Boundary facets on the plane x[axis] == value (within tol).
247
+ """
248
+ def pred(face: np.ndarray) -> bool:
249
+ return bool(np.allclose(face[:, axis], value, atol=tol))
250
+ return self.boundary_facets_where(pred, tag=tag)
251
+
252
+ def facets_on_plane(
253
+ self,
254
+ axis: int = 2,
255
+ value: float = 0.0,
256
+ *,
257
+ tol: float = 1e-8,
258
+ tag: int | None = None,
259
+ ):
260
+ """Alias for boundary_facets_plane (skfem-like naming)."""
261
+ cache = getattr(self, "_boundary_facets_cache", None)
262
+ key = ("plane", int(axis), float(value), float(tol), int(tag) if tag is not None else None)
263
+ if cache is not None and key in cache:
264
+ return cache[key]
265
+ coords = np.asarray(self.coords)
266
+ conn = np.asarray(self.conn)
267
+ patterns = self.face_node_patterns()
268
+ if patterns:
269
+ facets_list = []
270
+ for pattern in patterns:
271
+ face_nodes = conn[:, pattern]
272
+ face_coords = coords[face_nodes]
273
+ mask = np.all(np.isclose(face_coords[..., axis], value, atol=tol), axis=1)
274
+ if np.any(mask):
275
+ facets_list.append(face_nodes[mask])
276
+ if facets_list:
277
+ facets = np.concatenate(facets_list, axis=0)
278
+ keys = np.sort(facets, axis=1)
279
+ _, idx = np.unique(keys, axis=0, return_index=True)
280
+ facets = facets[np.sort(idx)]
281
+ else:
282
+ facets = np.empty((0, len(patterns[0])), dtype=int)
283
+ facets = jnp.asarray(facets, dtype=INDEX_DTYPE)
284
+ if tag is not None:
285
+ tags = jnp.full((facets.shape[0],), int(tag), dtype=INDEX_DTYPE)
286
+ facets = (facets, tags)
287
+ else:
288
+ facets = self.boundary_facets_plane(axis=axis, value=value, tol=tol, tag=tag)
289
+ if cache is None:
290
+ cache = {}
291
+ setattr(self, "_boundary_facets_cache", cache)
292
+ cache[key] = facets
293
+ return facets
294
+
295
+ def boundary_facets_plane_box(
296
+ self,
297
+ axis: int,
298
+ value: float,
299
+ *,
300
+ ranges: Sequence[tuple[float, float] | None] | None = None,
301
+ mode: str = "centroid",
302
+ tol: float = 1e-8,
303
+ tag: int | None = None,
304
+ ):
305
+ """
306
+ Boundary facets on a plane with additional box constraints.
307
+
308
+ ranges: sequence of (min, max) or None per axis. The plane axis can be None.
309
+ mode: "centroid" checks the face centroid, "all" requires all vertices inside.
310
+ """
311
+ dim = int(np.asarray(self.coords).shape[1])
312
+ if ranges is None:
313
+ ranges = [None] * dim
314
+ if len(ranges) != dim:
315
+ raise ValueError("ranges must have length equal to mesh dimension.")
316
+ if mode not in ("centroid", "all"):
317
+ raise ValueError("mode must be 'centroid' or 'all'.")
318
+
319
+ def pred(face: np.ndarray) -> bool:
320
+ if not np.allclose(face[:, axis], value, atol=tol):
321
+ return False
322
+ pts = face.mean(axis=0)[None, :] if mode == "centroid" else face
323
+ for ax, bounds in enumerate(ranges):
324
+ if bounds is None or ax == axis:
325
+ continue
326
+ lo, hi = bounds
327
+ if np.any(pts[:, ax] < lo - tol) or np.any(pts[:, ax] > hi + tol):
328
+ return False
329
+ return True
330
+
331
+ return self.boundary_facets_where(pred, tag=tag)
332
+
333
+ def facets_on_plane_box(
334
+ self,
335
+ axis: int,
336
+ value: float,
337
+ *,
338
+ x: tuple[float, float] | None = None,
339
+ y: tuple[float, float] | None = None,
340
+ z: tuple[float, float] | None = None,
341
+ ranges: Sequence[tuple[float, float] | None] | None = None,
342
+ mode: str = "centroid",
343
+ tol: float = 1e-8,
344
+ tag: int | None = None,
345
+ ):
346
+ """
347
+ Alias for boundary_facets_plane_box with axis-aligned range helpers.
348
+ Provide x/y/z or a full ranges sequence.
349
+ """
350
+ dim = int(np.asarray(self.coords).shape[1])
351
+ if ranges is None:
352
+ ranges = [None] * dim
353
+ if dim > 0:
354
+ ranges[0] = x
355
+ if dim > 1:
356
+ ranges[1] = y
357
+ if dim > 2:
358
+ ranges[2] = z
359
+ cache = getattr(self, "_boundary_facets_cache", None)
360
+ ranges_key = tuple(ranges)
361
+ key = (
362
+ "plane_box",
363
+ int(axis),
364
+ float(value),
365
+ ranges_key,
366
+ str(mode),
367
+ float(tol),
368
+ int(tag) if tag is not None else None,
369
+ )
370
+ if cache is not None and key in cache:
371
+ return cache[key]
372
+ coords = np.asarray(self.coords)
373
+ conn = np.asarray(self.conn)
374
+ patterns = self.face_node_patterns()
375
+ if patterns:
376
+ facets_list = []
377
+ for pattern in patterns:
378
+ face_nodes = conn[:, pattern]
379
+ face_coords = coords[face_nodes]
380
+ mask = np.all(np.isclose(face_coords[..., axis], value, atol=tol), axis=1)
381
+ if np.any(mask):
382
+ if mode == "centroid":
383
+ pts = face_coords[mask].mean(axis=1)
384
+ mask_local = np.ones(pts.shape[0], dtype=bool)
385
+ for ax, bounds in enumerate(ranges):
386
+ if bounds is None or ax == axis:
387
+ continue
388
+ lo, hi = bounds
389
+ mask_local &= (pts[:, ax] >= lo - tol) & (pts[:, ax] <= hi + tol)
390
+ face_nodes = face_nodes[mask][mask_local]
391
+ else:
392
+ face_coords = face_coords[mask]
393
+ mask_local = np.ones(face_coords.shape[0], dtype=bool)
394
+ for ax, bounds in enumerate(ranges):
395
+ if bounds is None or ax == axis:
396
+ continue
397
+ lo, hi = bounds
398
+ in_range = (face_coords[..., ax] >= lo - tol) & (face_coords[..., ax] <= hi + tol)
399
+ mask_local &= np.all(in_range, axis=1)
400
+ face_nodes = face_nodes[mask][mask_local]
401
+ if face_nodes.size:
402
+ facets_list.append(face_nodes)
403
+ if facets_list:
404
+ facets = np.concatenate(facets_list, axis=0)
405
+ keys = np.sort(facets, axis=1)
406
+ _, idx = np.unique(keys, axis=0, return_index=True)
407
+ facets = facets[np.sort(idx)]
408
+ else:
409
+ facets = np.empty((0, len(patterns[0])), dtype=int)
410
+ facets = jnp.asarray(facets, dtype=INDEX_DTYPE)
411
+ if tag is not None:
412
+ tags = jnp.full((facets.shape[0],), int(tag), dtype=INDEX_DTYPE)
413
+ facets = (facets, tags)
414
+ else:
415
+ facets = self.boundary_facets_plane_box(
416
+ axis=axis,
417
+ value=value,
418
+ ranges=ranges,
419
+ mode=mode,
420
+ tol=tol,
421
+ tag=tag,
422
+ )
423
+ if cache is None:
424
+ cache = {}
425
+ setattr(self, "_boundary_facets_cache", cache)
426
+ cache[key] = facets
427
+ return facets
428
+
429
+ def boundary_dofs_plane(
430
+ self,
431
+ axis: int = 2,
432
+ value: float = 0.0,
433
+ *,
434
+ components: Sequence[int] | str = "xyz",
435
+ dof_per_node: Optional[int] = None,
436
+ tol: float = 1e-8,
437
+ ) -> np.ndarray:
438
+ """
439
+ DOF indices for boundary nodes on the plane x[axis] == value (within tol).
440
+ """
441
+ def pred(coords: np.ndarray) -> np.ndarray:
442
+ return np.isclose(coords[:, axis], value, atol=tol)
443
+ return self.boundary_dofs_where(pred, components=components, dof_per_node=dof_per_node)
444
+
445
def elements_touching_nodes(self, nodes: Iterable[int]) -> np.ndarray:
    """
    Return indices of elements whose connectivity references any of ``nodes``.

    ``nodes`` may be any iterable (list, set, generator, array). The result is
    in ascending element order; an empty input yields an empty int array.
    """
    node_ids = np.asarray(list(nodes), dtype=int)
    if not node_ids.size:
        return np.asarray([], dtype=int)
    selected = np.zeros(self.n_nodes, dtype=bool)
    selected[node_ids] = True
    touches = selected[np.asarray(self.conn)].any(axis=1)
    return np.flatnonzero(touches)
456
+
457
+ def nodes_from_facets(self, facets: np.ndarray | jnp.ndarray) -> np.ndarray:
458
+ """
459
+ Return unique node indices contained in the given facets array.
460
+ """
461
+ facets_arr = np.asarray(facets, dtype=int)
462
+ if facets_arr.size == 0:
463
+ return np.asarray([], dtype=int)
464
+ return np.unique(facets_arr.reshape(-1))
465
+
466
def elements_from_nodes(self, nodes: Iterable[int]) -> np.ndarray:
    """
    skfem-style name for ``elements_touching_nodes``.

    Returns the indices of elements referencing at least one of ``nodes``.
    """
    touching = self.elements_touching_nodes
    return touching(nodes)
471
+
472
+ def elements_from_facets(self, facets: np.ndarray | jnp.ndarray, *, mode: str = "touching") -> np.ndarray:
473
+ """
474
+ Return element indices associated with facets.
475
+
476
+ mode:
477
+ - "touching": any element sharing at least one facet node.
478
+ - "adjacent": elements that own the facet (exact face match).
479
+ """
480
+ if mode not in ("touching", "adjacent"):
481
+ raise ValueError("mode must be 'touching' or 'adjacent'.")
482
+ facets_arr = np.asarray(facets, dtype=int)
483
+ if facets_arr.size == 0:
484
+ return np.asarray([], dtype=int)
485
+ if mode == "touching":
486
+ nodes = self.nodes_from_facets(facets_arr)
487
+ return self.elements_touching_nodes(nodes)
488
+
489
+ patterns = self.face_node_patterns()
490
+ facet_keys = {tuple(sorted(face)) for face in facets_arr}
491
+ conn = np.asarray(self.conn)
492
+ elems = []
493
+ for e_idx, elem_conn in enumerate(conn):
494
+ for pattern in patterns:
495
+ face_nodes = tuple(sorted(int(elem_conn[i]) for i in pattern))
496
+ if face_nodes in facet_keys:
497
+ elems.append(e_idx)
498
+ break
499
+ return np.asarray(elems, dtype=int)
500
+
501
+ def elements_touching_facets(self, facets: np.ndarray | jnp.ndarray) -> np.ndarray:
502
+ """
503
+ Return element indices that touch any node in the provided facets.
504
+ """
505
+ facets_arr = np.asarray(facets, dtype=int)
506
+ if facets_arr.size == 0:
507
+ return np.asarray([], dtype=int)
508
+ nodes = np.unique(facets_arr.reshape(-1))
509
+ return self.elements_touching_nodes(nodes)
510
+
511
def surface_from_facets(self, facets, *, facet_tags=None):
    """
    Build a SurfaceMesh from facet connectivity on this mesh's nodes.

    Passes ``self.coords``, ``facets``, ``facet_tags`` and ``self.node_tags``
    to ``SurfaceMesh.from_facets``.
    """
    # Deferred import (presumably to avoid a base <-> surface import cycle).
    from .surface import SurfaceMesh

    build = SurfaceMesh.from_facets
    return build(self.coords, facets, facet_tags=facet_tags, node_tags=self.node_tags)
517
+
518
def surface_with_elem_conn_from_facets(self, facets, *, mode: str = "touching"):
    """
    Build a SurfaceMesh and the matching elem_conn for the given facets.

    Delegates to ``surface_with_elem_conn`` from the sibling ``surface``
    module and forwards ``mode`` unchanged.
    """
    # Deferred import (presumably to avoid a base <-> surface import cycle).
    from .surface import surface_with_elem_conn as _build

    return _build(self, facets, mode=mode)
524
+
525
def surface_with_facet_map_from_facets(self, facets):
    """
    Build a SurfaceMesh and its facet-to-element map for the given facets.

    The mapping helper is chosen from the nodes-per-element count of
    ``self.conn``: 4/10 use the tet mapper, and 8/20/27 use the hex mapper.

    Raises:
        NotImplementedError: for any other element size.
    """
    surface = self.surface_from_facets(facets)
    elem_conn = np.asarray(self.conn, dtype=int)
    from .mortar import map_surface_facets_to_hex_elements, map_surface_facets_to_tet_elements

    nodes_per_elem = elem_conn.shape[1]
    if nodes_per_elem in (4, 10):
        mapper = map_surface_facets_to_tet_elements
    elif nodes_per_elem in (8, 20, 27):
        mapper = map_surface_facets_to_hex_elements
    else:
        raise NotImplementedError("elem_conn must be tet/hex (4/10/8/20/27)")
    return SurfaceWithFacetMap(surface=surface, facet_map=mapper(surface, elem_conn))
234
539
 
235
540
 
236
541
  @jax.tree_util.register_pytree_node_class
@@ -247,3 +552,7 @@ class BaseMeshPytree(BaseMeshClosure):
247
552
 
248
553
 
249
554
  # Public alias: the name BaseMesh refers to the BaseMeshClosure class.
  BaseMesh = BaseMeshClosure
555
@dataclass(frozen=True)
class SurfaceWithFacetMap:
    """
    Immutable pairing of a surface mesh with its facet-to-element map.

    Returned by ``surface_with_facet_map_from_facets``.
    """

    # Surface mesh (typed loosely as ``object``).
    surface: object
    # Facet-to-element map from the mortar mapping helpers.
    facet_map: np.ndarray