fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. fluxfem/__init__.py +69 -13
  2. fluxfem/core/__init__.py +140 -53
  3. fluxfem/core/assembly.py +691 -97
  4. fluxfem/core/basis.py +75 -54
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/dtypes.py +9 -1
  7. fluxfem/core/forms.py +10 -0
  8. fluxfem/core/mixed_assembly.py +263 -0
  9. fluxfem/core/mixed_space.py +382 -0
  10. fluxfem/core/mixed_weakform.py +97 -0
  11. fluxfem/core/solver.py +2 -0
  12. fluxfem/core/space.py +315 -30
  13. fluxfem/core/weakform.py +821 -42
  14. fluxfem/helpers_wf.py +49 -0
  15. fluxfem/mesh/__init__.py +54 -2
  16. fluxfem/mesh/base.py +318 -9
  17. fluxfem/mesh/contact.py +841 -0
  18. fluxfem/mesh/dtypes.py +12 -0
  19. fluxfem/mesh/hex.py +17 -16
  20. fluxfem/mesh/io.py +9 -6
  21. fluxfem/mesh/mortar.py +3970 -0
  22. fluxfem/mesh/supermesh.py +318 -0
  23. fluxfem/mesh/surface.py +104 -26
  24. fluxfem/mesh/tet.py +16 -7
  25. fluxfem/physics/diffusion.py +3 -0
  26. fluxfem/physics/elasticity/hyperelastic.py +35 -3
  27. fluxfem/physics/elasticity/linear.py +22 -4
  28. fluxfem/physics/elasticity/stress.py +9 -5
  29. fluxfem/physics/operators.py +12 -5
  30. fluxfem/physics/postprocess.py +29 -3
  31. fluxfem/solver/__init__.py +47 -2
  32. fluxfem/solver/bc.py +38 -2
  33. fluxfem/solver/block_matrix.py +284 -0
  34. fluxfem/solver/block_system.py +477 -0
  35. fluxfem/solver/cg.py +150 -55
  36. fluxfem/solver/dirichlet.py +358 -5
  37. fluxfem/solver/history.py +15 -3
  38. fluxfem/solver/newton.py +260 -70
  39. fluxfem/solver/petsc.py +445 -0
  40. fluxfem/solver/preconditioner.py +109 -0
  41. fluxfem/solver/result.py +18 -0
  42. fluxfem/solver/solve_runner.py +208 -23
  43. fluxfem/solver/solver.py +35 -12
  44. fluxfem/solver/sparse.py +149 -15
  45. fluxfem/tools/jit.py +19 -7
  46. fluxfem/tools/timer.py +14 -12
  47. fluxfem/tools/visualizer.py +16 -4
  48. fluxfem-0.2.1.dist-info/METADATA +314 -0
  49. fluxfem-0.2.1.dist-info/RECORD +59 -0
  50. fluxfem-0.1.4.dist-info/METADATA +0 -127
  51. fluxfem-0.1.4.dist-info/RECORD +0 -48
  52. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  53. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/mesh/mortar.py ADDED
@@ -0,0 +1,3970 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ import os
5
+ import time
6
+ from typing import Any, Callable, Iterable, Sequence, TYPE_CHECKING, cast
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import numpy as np
11
+
12
+ from .surface import SurfaceMesh
13
+ from ..core.forms import FormFieldLike
14
+ if TYPE_CHECKING:
15
+ from ..core.forms import FieldPair
16
+ from ..core.weakform import Params as WeakParams
17
+
18
+
19
@dataclass(eq=False)
class _SurfaceBasis:
    """Minimal basis descriptor attached to surface form fields.

    Only records how many degrees of freedom live on each node. ``eq=False``
    means no field-wise ``__eq__`` is generated, so instances compare (and
    hash) by identity.
    """

    # Number of degrees of freedom stored per mesh node.
    dofs_per_node: int
22
+
23
+
24
@dataclass(eq=False)
class SurfaceMixedFormField:
    """Surface form field for mixed weak-form evaluation.

    Bundles the shape-function data of one field on a set of surface
    quadrature points together with its value dimension and basis descriptor.
    """

    # Shape-function values; presumably (n_quad, n_nodes) — TODO confirm with callers.
    N: np.ndarray
    # Shape-function gradients, or None when gradients are not provided.
    gradN: np.ndarray | None
    # Number of components of the field value (1 for a scalar field).
    value_dim: int
    # Basis descriptor (degrees of freedom per node).
    basis: _SurfaceBasis
31
+
32
+
33
@dataclass(eq=False)
class SurfaceMixedFormContext:
    """Surface mixed context for weak-form evaluation on supermesh.

    Holds quadrature data (points, weights, Jacobian determinants and an
    optional normal) plus optional per-field form data used when evaluating a
    mixed weak form on the surface.
    """

    # Field pairs keyed by field name.
    fields: dict[str, "FieldPair"]
    # Physical coordinates of the quadrature points.
    x_q: np.ndarray
    # Quadrature weights.
    w: np.ndarray
    # Surface Jacobian determinants; presumably one per quadrature point — TODO confirm.
    detJ: np.ndarray
    # Surface normal, when available.
    normal: np.ndarray | None = None
    # Optional trial-space form data keyed by field name.
    trial_fields: dict[str, SurfaceMixedFormField] | None = None
    # Optional test-space form data keyed by field name.
    test_fields: dict[str, SurfaceMixedFormField] | None = None
    # Optional form data for the unknown (current-state) fields keyed by field name.
    unknown_fields: dict[str, SurfaceMixedFormField] | None = None
44
+
45
+
46
# Surface gradN debugging: on when FLUXFEM_DEBUG_SURFACE_GRADN is set to any
# non-empty value. The _MAX limit (default 8) is only parsed when enabled.
# NOTE(review): int() raises ValueError at import time if the _MAX variable is
# not an integer string.
_DEBUG_SURFACE_GRADN = os.getenv("FLUXFEM_DEBUG_SURFACE_GRADN")
_DEBUG_SURFACE_GRADN_MAX = int(os.getenv("FLUXFEM_DEBUG_SURFACE_GRADN_MAX", "8")) if _DEBUG_SURFACE_GRADN else 0
_DEBUG_SURFACE_GRADN_COUNT = 0
# Presumably one-shot guards for debug printouts elsewhere in this module
# (usage not visible here) — TODO confirm.
_DEBUG_SURFACE_SOURCE_ONCE = False
_DEBUG_CONTACT_MAP_ONCE = False
_DEBUG_CONTACT_N_ONCE = False
# NOTE(review): this import-time flag is truthy for ANY non-empty value
# (including "0"), whereas _proj_diag_enabled() requires FLUXFEM_PROJ_DIAG == "1";
# the two checks disagree.
_DEBUG_PROJECTION_DIAG = os.getenv("FLUXFEM_PROJ_DIAG")
_DEBUG_PROJECTION_DIAG_MAX = int(os.getenv("FLUXFEM_PROJ_DIAG_MAX", "20")) if _DEBUG_PROJECTION_DIAG else 0
# Set by _diag_contact_projection after its single dump.
_DEBUG_CONTACT_PROJ_ONCE = False
# Quadrature override loaded by _diag_quad_override: (quad_pts, quad_w) and a
# "file:<path>" label describing where it came from.
_DEBUG_PROJ_QP_CACHE = None
_DEBUG_PROJ_QP_SOURCE = None
# Set by _diag_quad_dump once the quadrature rule has been written to disk.
_DEBUG_PROJ_QP_DUMPED = False
# Projection-diagnostics session: counters (None = no active session), number
# of failures printed so far, and the facet-pair context attached to logs.
_PROJ_DIAG_STATS: dict[str, Any] | None = None
_PROJ_DIAG_COUNT = 0
_PROJ_DIAG_CONTEXT: dict[str, int | str] = {}
61
+
62
+
63
def _mortar_dbg_enabled() -> bool:
    """Report whether FLUXFEM_MORTAR_DEBUG requests mortar debug output.

    An unset variable, or one of "", "0", "false", "False", means disabled.
    """
    flag = os.getenv("FLUXFEM_MORTAR_DEBUG")
    if flag is None:
        return False
    return flag not in ("0", "", "false", "False")
65
+
66
+
67
def _mortar_dbg(msg: str) -> None:
    """Print ``msg`` (flushed immediately) when mortar debugging is switched on."""
    if not _mortar_dbg_enabled():
        return
    print(msg, flush=True)
70
+
71
+
72
def _env_flag(name: str, default: bool) -> bool:
    """Interpret environment variable ``name`` as a boolean flag.

    Returns ``default`` when the variable is unset; otherwise every value
    except "0", "", "false" and "False" counts as True.
    """
    raw = os.environ.get(name)
    return default if raw is None else raw not in ("0", "", "false", "False")
77
+
78
+
79
@dataclass(eq=False)
class MortarMatrix:
    """COO storage for mortar coupling matrices (can be rectangular).

    Stored entry k has value ``data[k]`` at position ``(rows[k], cols[k])``.
    """

    # Row index of each stored entry.
    rows: np.ndarray
    # Column index of each stored entry.
    cols: np.ndarray
    # Value of each stored entry.
    data: np.ndarray
    # (n_rows, n_cols) of the full matrix; the two sizes may differ.
    shape: tuple[int, int]
86
+
87
+
88
def _tri_area(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> float:
    """Area of triangle (a, b, c): half the length of the edge cross product."""
    edge_ab = b - a
    edge_ac = c - a
    normal = np.cross(edge_ab, edge_ac)
    return float(np.linalg.norm(normal)) * 0.5
90
+
91
+
92
def tri_area(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> float:
    """Area of the triangle with vertices ``a``, ``b``, ``c``.

    Public entry point (used in contact diagnostics) for the module's private
    triangle-area helper.
    """
    area = _tri_area(a, b, c)
    return area
95
+
96
+
97
def tri_quadrature(order: int) -> tuple[np.ndarray, np.ndarray]:
    """Reference-triangle quadrature ``(points, weights)`` for ``order``.

    Public entry point for the module's private rule table; points are (r, s)
    pairs on the triangle (0,0), (1,0), (0,1).
    """
    rule = _tri_quadrature(order=order)
    return rule
100
+
101
+
102
def facet_triangles(coords: np.ndarray, facet_nodes: np.ndarray) -> list[tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """Split facet ``facet_nodes`` (indices into ``coords``) into vertex triangles.

    Public entry point for the module's private facet triangulation.
    """
    triangles = _facet_triangles(coords=coords, facet_nodes=facet_nodes)
    return triangles
105
+
106
+
107
def facet_shape_values(point: np.ndarray, facet_nodes: np.ndarray, coords: np.ndarray, *, tol: float) -> np.ndarray:
    """Nodal shape values of facet ``facet_nodes`` evaluated at ``point``.

    Public entry point for the module's private facet evaluator; ``tol`` is
    forwarded unchanged.
    """
    values = _facet_shape_values(point=point, facet_nodes=facet_nodes, coords=coords, tol=tol)
    return values
110
+
111
+
112
def volume_shape_values_at_points(x_q: np.ndarray, elem_coords: np.ndarray, *, tol: float) -> np.ndarray:
    """Volume-element shape values at the quadrature points ``x_q``.

    Public entry point for the module's private implementation; ``tol`` is
    forwarded unchanged.
    """
    values = _volume_shape_values_at_points(x_q, elem_coords, tol=tol)
    return values
115
+
116
+
117
def quad_shape_and_local(
    point: np.ndarray,
    quad_nodes: np.ndarray,
    corner_coords: np.ndarray,
    *,
    tol: float,
) -> tuple[np.ndarray, float, float]:
    """Bilinear quad4 shape values at ``point`` plus its local (xi, eta).

    Public entry point for the module's private quad inversion. Despite its
    name, ``corner_coords`` is indexed by ``quad_nodes`` (the four corner
    positions are ``corner_coords[quad_nodes]``), so it is normally the full
    nodal coordinate array.
    """
    values, xi, eta = _quad_shape_and_local(point, quad_nodes, corner_coords, tol=tol)
    return values, xi, eta
126
+
127
+
128
def quad9_shape_values(xi: float, eta: float) -> np.ndarray:
    """Lagrange quad9 shape values at reference coordinates (xi, eta).

    Public entry point for the module's private quad9 basis.
    """
    values = _quad9_shape_values(xi=xi, eta=eta)
    return values
131
+
132
+
133
def hex27_gradN(point: np.ndarray, elem_coords: np.ndarray, *, tol: float) -> np.ndarray:
    """Hex27 shape-function gradients at ``point`` (diagnostics).

    Public entry point for the module's private implementation; ``tol`` is
    forwarded unchanged.
    """
    grads = _hex27_gradN(point, elem_coords, tol=tol)
    return grads
136
+
137
+
138
def _quad_quadrature(order: int) -> tuple[np.ndarray, np.ndarray]:
    """Tensor-product Gauss-Legendre rule on the reference square [-1, 1]^2.

    Uses ceil((order + 1) / 2) points per direction; orders <= 1 are promoted
    to 2, so at least a 2x2 rule is returned. Returns ``(points, weights)``
    with points of shape (m, 2), ordered so the first coordinate varies
    fastest.
    """
    effective = 2 if order <= 1 else order
    n_1d = int(np.ceil((effective + 1.0) / 2.0))
    nodes, wts = np.polynomial.legendre.leggauss(n_1d)
    xs, ys = np.meshgrid(nodes, nodes, indexing="xy")
    pts = np.column_stack((xs.ravel(), ys.ravel()))
    weights = np.outer(wts, wts).ravel()
    return pts, weights
150
+
151
+
152
def _facet_area_estimate(facet_nodes: np.ndarray, coords: np.ndarray) -> float:
    """
    Estimate a facet's area from its corner nodes (flat, straight-sided approximation).

    Corners are chosen by node count:
      * 3 or 6 (tri3 / tri6): nodes 0, 1, 2
      * 8 (quad8): nodes 0..3
      * 9 (quad9): nodes 0, 2, 8, 6
      * anything else (incl. quad4): all nodes, taken as an ordered polygon

    The corner polygon is fan-triangulated from its first vertex and the
    triangle areas are summed.
    """
    n = int(len(facet_nodes))
    if n in (3, 6):
        # Corners only: fanning over the tri6 mid-edge nodes as well would
        # double-count the area (a flat tri6 would report twice its area).
        corner_nodes = facet_nodes[:3]
    elif n == 8:
        corner_nodes = facet_nodes[:4]
    elif n == 9:
        # quad9 corners sit at indices 0, 2, 8, 6 of its 3x3 node grid.
        corner_nodes = facet_nodes[[0, 2, 8, 6]]
    else:
        corner_nodes = facet_nodes
    pts = coords[corner_nodes]
    area = 0.0
    for i in range(1, len(pts) - 1):
        area += _tri_area(pts[0], pts[i], pts[i + 1])
    return float(area)
174
+
175
+
176
def _facet_triangles(coords: np.ndarray, facet_nodes: np.ndarray) -> list[tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """
    Split a facet into flat triangles spanned by its corner nodes.

    Corners are chosen by node count:
      * 3 or 6 (tri3 / tri6): nodes 0, 1, 2
      * 8 (quad8): nodes 0..3
      * 9 (quad9): nodes 0, 2, 8, 6
      * anything else (incl. quad4): all nodes, taken as an ordered polygon

    Returns:
        A fan triangulation around the first corner as a list of
        ``(a, b, c)`` vertex-coordinate triples. Each triangle appears exactly
        once, so a polygon with m corners yields m - 2 triangles (none if
        m < 3).
    """
    n = int(len(facet_nodes))
    if n in {3, 6}:
        corner = facet_nodes[:3]
    elif n == 8:
        corner = facet_nodes[:4]
    elif n == 9:
        corner = facet_nodes[[0, 2, 8, 6]]
    else:
        corner = facet_nodes
    pts = coords[corner]
    # Single fan (p0, p_i, p_{i+1}); the previous code appended (p0, p2, p3)
    # twice for polygons with more than four corners.
    return [(pts[0], pts[i], pts[i + 1]) for i in range(1, len(pts) - 1)]
202
+
203
+
204
+
205
+
206
def _tri_centroid(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
    """Centroid of triangle (a, b, c), i.e. the average of its three vertices."""
    vertex_sum = a + b + c
    return vertex_sum / 3.0
208
+
209
+
210
def _tri_quadrature(order: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Quadrature rule on the reference triangle (0,0), (1,0), (0,1).

    Returns ``(points, weights)``: points are (r, s) pairs of shape (m, 2) and
    the weights sum to the reference area 1/2. The rule is the first one whose
    threshold ``order`` does not exceed:

      * <= 0: centroid rule (1 point)
      * <= 2: 3-point rule
      * <= 3: 4-point rule (its centroid weight is negative)
      * <= 4: 6-point rule
      * <= 5: 7-point rule

    Raises:
        NotImplementedError: if ``order`` is above 5.
    """
    if order <= 0:
        return np.array([[1.0 / 3.0, 1.0 / 3.0]]), np.array([0.5])
    if order <= 2:
        sixth = 1.0 / 6.0
        two_thirds = 2.0 / 3.0
        table = [[sixth, sixth], [two_thirds, sixth], [sixth, two_thirds]]
        wts = np.full(3, 1.0 / 6.0)
    elif order <= 3:
        table = [[1.0 / 3.0, 1.0 / 3.0], [0.2, 0.2], [0.6, 0.2], [0.2, 0.6]]
        wts = np.array([-27.0 / 96.0] + [25.0 / 96.0] * 3, dtype=float)
    elif order <= 4:
        # Two symmetric orbits (a, a, 1 - 2a) and (c, c, 1 - 2c); b = 1 - 2a, d = 1 - 2c.
        a, b = 0.445948490915965, 0.108103018168070
        c, d = 0.091576213509771, 0.816847572980459
        table = [[a, a], [a, b], [b, a], [c, c], [c, d], [d, c]]
        wts = np.array([0.111690794839005] * 3 + [0.054975871827661] * 3, dtype=float)
    elif order <= 5:
        a, b = 0.470142064105115, 0.059715871789770
        c, d = 0.101286507323456, 0.797426985353087
        table = [[1.0 / 3.0, 1.0 / 3.0], [a, a], [a, b], [b, a], [c, c], [c, d], [d, c]]
        # Tabulated weights sum to 1; halve them for the area-1/2 reference triangle.
        wts = 0.5 * np.array([0.225] + [0.132394152788506] * 3 + [0.125939180544827] * 3, dtype=float)
    else:
        raise NotImplementedError("triangle quadrature order > 5 is not implemented")
    return np.array(table, dtype=float), wts
303
+
304
+
305
def _proj_diag_enabled() -> bool:
    """Projection diagnostics are on only when FLUXFEM_PROJ_DIAG is exactly "1"."""
    flag = os.getenv("FLUXFEM_PROJ_DIAG", "0")
    return flag == "1"
307
+
308
+
309
def _proj_diag_max() -> int:
    """How many projection failures to print (FLUXFEM_PROJ_DIAG_MAX, default 20)."""
    raw = os.getenv("FLUXFEM_PROJ_DIAG_MAX", "20")
    return int(raw)
311
+
312
+
313
def _proj_diag_reset() -> None:
    """Start a fresh projection-diagnostics session.

    Installs zeroed counters (which also activates counting in the other
    ``_proj_diag_*`` helpers) and resets the printed-failure count.
    """
    global _PROJ_DIAG_STATS, _PROJ_DIAG_COUNT
    fresh_stats = {"total": 0, "fail": 0, "by_code": {}}
    _PROJ_DIAG_STATS = fresh_stats
    _PROJ_DIAG_COUNT = 0
321
+
322
+
323
def _proj_diag_set_context(
    *,
    fa: int,
    fb: int,
    face_a: str,
    face_b: str,
    elem_a: int,
    elem_b: int,
) -> None:
    """Replace, in place, the facet-pair context printed with projection-failure logs.

    Integer ids are coerced with ``int()``; key order (fa, fb, face_a, face_b,
    elem_a, elem_b) is the order they are printed in.
    """
    _PROJ_DIAG_CONTEXT.clear()
    _PROJ_DIAG_CONTEXT.update(
        fa=int(fa),
        fb=int(fb),
        face_a=face_a,
        face_b=face_b,
        elem_a=int(elem_a),
        elem_b=int(elem_b),
    )
343
+
344
+
345
def _proj_diag_attempt() -> None:
    """Count one projection attempt when a diagnostics session is active."""
    stats = _PROJ_DIAG_STATS
    if stats is not None:
        stats["total"] += 1
349
+
350
+
351
+ def _proj_diag_log(
352
+ code: str,
353
+ *,
354
+ iters: int,
355
+ res_norm: float,
356
+ delta_norm: float | None,
357
+ detJ: float | None,
358
+ point: np.ndarray,
359
+ local: np.ndarray,
360
+ in_ref_domain: bool,
361
+ ) -> None:
362
+ global _PROJ_DIAG_COUNT
363
+ if _PROJ_DIAG_STATS is None:
364
+ return
365
+ _PROJ_DIAG_STATS["fail"] += 1
366
+ by_code = cast(dict[str, int], _PROJ_DIAG_STATS["by_code"])
367
+ by_code[code] = by_code.get(code, 0) + 1
368
+ if _PROJ_DIAG_COUNT >= _proj_diag_max():
369
+ return
370
+ _PROJ_DIAG_COUNT += 1
371
+ ctx = " ".join(f"{k}={v}" for k, v in _PROJ_DIAG_CONTEXT.items()) if _PROJ_DIAG_CONTEXT else "ctx=unknown"
372
+ det_str = "None" if detJ is None else f"{detJ:.6e}"
373
+ delta_str = "None" if delta_norm is None else f"{delta_norm:.6e}"
374
+ print(
375
+ "[fluxfem][proj][fail]",
376
+ f"code={code}",
377
+ ctx,
378
+ f"iters={iters}",
379
+ f"res={res_norm:.6e}",
380
+ f"delta={delta_str}",
381
+ f"detJ={det_str}",
382
+ f"in_ref={bool(in_ref_domain)}",
383
+ f"point={point.tolist()}",
384
+ f"local={local.tolist()}",
385
+ )
386
+
387
+
388
def _proj_diag_report() -> None:
    """Print the aggregated projection-diagnostics counters (no-op without a session)."""
    stats = _PROJ_DIAG_STATS
    if stats is None:
        return
    print(
        "[fluxfem][proj][diag] total=", stats["total"],
        "fail=", stats["fail"],
        "by_code=", stats["by_code"],
    )
395
+
396
+
397
def _facet_label(facet: np.ndarray) -> str:
    """Short element-type label for a facet, chosen by its node count.

    Known counts map to "tri3", "quad4", "tri6", "quad8", "quad9"; any other
    count n yields "n<n>".
    """
    n = int(len(facet))
    known = {3: "tri3", 4: "quad4", 6: "tri6", 8: "quad8", 9: "quad9"}
    return known.get(n, f"n{n}")
410
+
411
+
412
+ def _diag_quad_override(diag_force: bool, mode: str, path: str) -> tuple[np.ndarray, np.ndarray] | None:
413
+ global _DEBUG_PROJ_QP_CACHE, _DEBUG_PROJ_QP_SOURCE
414
+ if not diag_force or mode != "load" or not path:
415
+ return None
416
+ if _DEBUG_PROJ_QP_CACHE is None:
417
+ data = np.load(path)
418
+ _DEBUG_PROJ_QP_CACHE = (np.asarray(data["quad_pts"], dtype=float), np.asarray(data["quad_w"], dtype=float))
419
+ _DEBUG_PROJ_QP_SOURCE = f"file:{path}"
420
+ return _DEBUG_PROJ_QP_CACHE
421
+
422
+
423
def _diag_quad_dump(diag_force: bool, mode: str, path: str, quad_pts: np.ndarray, quad_w: np.ndarray) -> None:
    """Save the quadrature rule to ``path`` once, during forced "dump" diagnostics.

    Writes ``quad_pts``/``quad_w`` (as float arrays) with ``np.savez`` the
    first time it runs with a truthy ``diag_force``, ``mode == "dump"`` and a
    non-empty ``path``. Later calls do nothing.
    """
    global _DEBUG_PROJ_QP_DUMPED
    should_dump = diag_force and mode == "dump" and path and not _DEBUG_PROJ_QP_DUMPED
    if not should_dump:
        return
    pts_f = np.asarray(quad_pts, dtype=float)
    w_f = np.asarray(quad_w, dtype=float)
    np.savez(path, quad_pts=pts_f, quad_w=w_f)
    _DEBUG_PROJ_QP_DUMPED = True
429
+
430
+
431
def _volume_local_coords(point: np.ndarray, elem_coords: np.ndarray, *, tol: float):
    """Local coordinates of ``point`` inside a volume element (diagnostics helper).

    * 4 or 10 nodes (tet4/tet10): the four barycentric weights w.r.t. the first
      four (corner) nodes, or None if that 4x4 system is singular.
    * 8, 20 or 27 nodes (hex8/hex20/hex27): (xi, eta, zeta) from the matching
      hex inversion helper.
    * Any other node count: None.
    """
    n_nodes = elem_coords.shape[0]
    if n_nodes in {4, 10}:
        corners = elem_coords[:4]
        # Row i of M is [x_i, y_i, z_i, 1]; solving M^T lam = [p, 1] yields
        # weights that reproduce the point and sum to one.
        M = np.column_stack((corners, np.ones(4)))
        rhs = np.array([point[0], point[1], point[2], 1.0], dtype=float)
        try:
            return np.linalg.solve(M.T, rhs)
        except np.linalg.LinAlgError:
            return None
    if n_nodes == 8:
        invert = _hex8_shape_and_local
    elif n_nodes == 20:
        invert = _hex20_shape_and_local
    elif n_nodes == 27:
        invert = _hex27_shape_and_local
    else:
        return None
    _, xi, eta, zeta = invert(point, elem_coords, tol=tol)
    return np.array([xi, eta, zeta], dtype=float)
452
+
453
+
454
def _diag_contact_projection(
    *,
    fa: int,
    fb: int,
    quad_pts: np.ndarray,
    quad_w: np.ndarray,
    x_q: np.ndarray,
    Na: np.ndarray,
    Nb: np.ndarray,
    nodes_a: np.ndarray,
    nodes_b: np.ndarray,
    dofs_a: np.ndarray,
    dofs_b: np.ndarray,
    elem_coords_a: np.ndarray | None,
    elem_coords_b: np.ndarray | None,
    na: np.ndarray | None,
    nb: np.ndarray | None,
    normal: np.ndarray | None,
    normal_source: str,
    normal_sign: float,
    detJ: float,
    diag_facet: int,
    diag_max_q: int,
    quad_source: str,
    tol: float,
) -> None:
    """
    Print a one-time diagnostic dump of the projection data for a contact facet pair.

    Runs only once: after the first dump ``_DEBUG_CONTACT_PROJ_ONCE`` is set and
    later calls return immediately. When ``diag_facet >= 0``, only the facet
    with ``fa == diag_facet`` is dumped. The dump lists the quadrature rule,
    the normals (with dot products when available), ``detJ``, node/DOF ids,
    and, for up to ``diag_max_q`` quadrature points, the sums of the ``Na`` /
    ``Nb`` rows. When element coordinates are supplied it also prints the
    positions interpolated with those rows, their distance to ``x_q``, and
    the volume local coordinates of ``x_q``.
    """
    global _DEBUG_CONTACT_PROJ_ONCE
    if _DEBUG_CONTACT_PROJ_ONCE:
        return
    if diag_facet >= 0 and fa != diag_facet:
        return
    samples = min(diag_max_q, int(x_q.shape[0]))
    print("[fluxfem][diag][proj] first facet")
    print(f" fa={fa} fb={fb} quad_source={quad_source}")
    print(f" quad_pts={quad_pts.tolist()} quad_w={quad_w.tolist()}")
    print(f" normal_source={normal_source} normal_sign={normal_sign}")
    print(f" n_master={None if na is None else na.tolist()}")
    print(f" n_slave={None if nb is None else nb.tolist()}")
    print(f" n_used={None if normal is None else normal.tolist()}")
    if normal is not None and na is not None:
        print(f" dot(n_used,n_master)={float(np.dot(normal, na)):.6e}")
    if normal is not None and nb is not None:
        print(f" dot(n_used,n_slave)={float(np.dot(normal, nb)):.6e}")
    print(f" detJ={float(detJ):.6e}")
    print(f" nodes_a={nodes_a.tolist()} nodes_b={nodes_b.tolist()}")
    print(f" dofs_a={dofs_a.tolist()} dofs_b={dofs_b.tolist()}")
    for qi in range(samples):
        # Row sums of Na / Nb; a partition of unity should give 1.
        nsum_a = float(np.sum(Na[qi]))
        nsum_b = float(np.sum(Nb[qi]))
        xq = x_q[qi]
        msg = f" q{qi} x={xq.tolist()} sum(Na)={nsum_a:.6e} sum(Nb)={nsum_b:.6e}"
        if elem_coords_a is not None:
            # assumes Na[qi] entries align with the rows of elem_coords_a — TODO confirm
            xa = Na[qi] @ elem_coords_a
            msg += f" x_a={xa.tolist()} |x_a-x_q|={float(np.linalg.norm(xa - xq)):.6e}"
            local_a = _volume_local_coords(xq, elem_coords_a, tol=tol)
            if local_a is not None:
                msg += f" xi_a={local_a.tolist()}"
        if elem_coords_b is not None:
            # assumes Nb[qi] entries align with the rows of elem_coords_b — TODO confirm
            xb = Nb[qi] @ elem_coords_b
            msg += f" x_b={xb.tolist()} |x_b-x_q|={float(np.linalg.norm(xb - xq)):.6e}"
            local_b = _volume_local_coords(xq, elem_coords_b, tol=tol)
            if local_b is not None:
                msg += f" xi_b={local_b.tolist()}"
        print(msg)
    _DEBUG_CONTACT_PROJ_ONCE = True
519
+
520
+
521
def _barycentric(p: np.ndarray, a: np.ndarray, b: np.ndarray, c: np.ndarray):
    """Barycentric coordinates (u, v, w) of ``p`` with respect to triangle (a, b, c).

    Solves the 2x2 Gram (normal-equation) system of the edge vectors, so a
    point off the triangle's plane is effectively projected onto it first.
    Returns None when the triangle is degenerate (|Gram determinant| < 1e-14).
    The coordinates are not clipped; u is defined as 1 - v - w.
    """
    e0, e1, rel = b - a, c - a, p - a
    g00 = float(np.dot(e0, e0))
    g01 = float(np.dot(e0, e1))
    g11 = float(np.dot(e1, e1))
    r0 = float(np.dot(rel, e0))
    r1 = float(np.dot(rel, e1))
    gram_det = g00 * g11 - g01 * g01
    if abs(gram_det) < 1e-14:
        return None
    v = (g11 * r0 - g01 * r1) / gram_det
    w = (g00 * r1 - g01 * r0) / gram_det
    return np.array([1.0 - v - w, v, w], dtype=float)
537
+
538
+
539
def _point_in_tri(lam: np.ndarray, *, tol: float) -> bool:
    """True when every barycentric coordinate lies within [-tol, 1 + tol]."""
    above_lower = np.all(lam >= -tol)
    below_upper = np.all(lam <= 1.0 + tol)
    return bool(above_lower and below_upper)
541
+
542
+
543
def _plane_basis(pts: np.ndarray, *, tol: float):
    """Orthonormal in-plane tangents (t1, t2) for a planar facet.

    t1 points along pts[0] -> pts[1]. t2 is the Gram-Schmidt complement of the
    second spanning edge, pts[0] -> pts[3] when there are more than three
    points and pts[0] -> pts[2] otherwise. Returns (None, None) when the
    spanning edges are degenerate or (nearly) parallel relative to ``tol``.
    """
    origin = pts[0]
    edge_u = pts[1] - origin
    edge_v = (pts[3] if pts.shape[0] > 3 else pts[2]) - origin
    if np.linalg.norm(np.cross(edge_u, edge_v)) < tol:
        return None, None
    t1 = edge_u / np.linalg.norm(edge_u)
    ortho = edge_v - np.dot(edge_v, t1) * t1
    ortho_len = np.linalg.norm(ortho)
    if ortho_len < tol:
        return None, None
    return t1, ortho / ortho_len
558
+
559
+
560
def _quad_shape_and_local(
    point: np.ndarray,
    facet_nodes: np.ndarray,
    coords: np.ndarray,
    *,
    tol: float,
) -> tuple[np.ndarray, float, float]:
    """
    Invert the bilinear quad4 map at ``point``; return shape values and (xi, eta).

    The corners ``coords[facet_nodes]`` are expressed in the orthonormal
    in-plane frame from ``_plane_basis``, and ``point`` is projected into that
    frame (its out-of-plane offset is ignored). Newton's method, started at
    (0, 0) and capped at 12 iterations, then solves x(xi, eta) = point.
    Corner order is (-1,-1), (1,-1), (1,1), (-1,1).

    Returns:
        ``(N, xi, eta)`` where ``N`` holds the four bilinear shape values
        evaluated at the returned ``(xi, eta)``. ``N`` is all zeros when the
        facet is degenerate, the Newton Jacobian is singular, or an iterate
        becomes non-finite; callers treat an all-zero ``N`` as failure.

    With projection diagnostics enabled, every call is counted, and failures
    (SINGULAR_H, NAN_INF, OUTSIDE_DOMAIN, NEWTON_NO_CONVERGE) are logged.
    """

    def bilinear(s: float, t: float) -> np.ndarray:
        # Bilinear quad4 shape values at (s, t), corner order as documented above.
        return np.array(
            [
                0.25 * (1.0 - s) * (1.0 - t),
                0.25 * (1.0 + s) * (1.0 - t),
                0.25 * (1.0 + s) * (1.0 + t),
                0.25 * (1.0 - s) * (1.0 + t),
            ],
            dtype=float,
        )

    if _proj_diag_enabled():
        _proj_diag_attempt()
    pts = coords[facet_nodes]
    t1, t2 = _plane_basis(pts, tol=tol)
    if t1 is None:
        return np.zeros((4,), dtype=float), 0.0, 0.0
    frame = np.stack([t1, t2], axis=1)
    origin = pts[0]
    local = (pts - origin) @ frame
    p_local = (point - origin) @ frame
    x = local[:, 0]
    y = local[:, 1]
    xp = float(p_local[0])
    yp = float(p_local[1])

    def residual(N: np.ndarray) -> tuple[float, float]:
        # Mapped in-plane position minus the target point.
        x_m = N[0] * x[0] + N[1] * x[1] + N[2] * x[2] + N[3] * x[3]
        y_m = N[0] * y[0] + N[1] * y[1] + N[2] * y[2] + N[3] * y[3]
        return x_m - xp, y_m - yp

    xi = 0.0
    eta = 0.0
    res_norm = 0.0
    detJ = None
    iters = 0
    for _ in range(12):
        iters += 1
        N = bilinear(xi, eta)
        rx, ry = residual(N)
        res_norm = float(np.hypot(rx, ry))
        if abs(rx) + abs(ry) < tol:
            break
        dndxi = np.array(
            [
                -0.25 * (1.0 - eta),
                0.25 * (1.0 - eta),
                0.25 * (1.0 + eta),
                -0.25 * (1.0 + eta),
            ],
            dtype=float,
        )
        dndeta = np.array(
            [
                -0.25 * (1.0 - xi),
                -0.25 * (1.0 + xi),
                0.25 * (1.0 + xi),
                0.25 * (1.0 - xi),
            ],
            dtype=float,
        )
        # J = [[dx/dxi, dx/deta], [dy/dxi, dy/deta]]
        j11 = float(np.dot(dndxi, x))
        j12 = float(np.dot(dndeta, x))
        j21 = float(np.dot(dndxi, y))
        j22 = float(np.dot(dndeta, y))
        det = j11 * j22 - j12 * j21
        detJ = float(det)
        if abs(det) < tol:
            if _proj_diag_enabled():
                _proj_diag_log(
                    "SINGULAR_H",
                    iters=iters,
                    res_norm=res_norm,
                    delta_norm=None,
                    detJ=detJ,
                    point=point,
                    local=np.array([xi, eta], dtype=float),
                    in_ref_domain=False,
                )
            return np.zeros((4,), dtype=float), xi, eta
        # Newton step: [dxi, deta] = -J^{-1} [rx, ry] via the explicit 2x2 inverse.
        dxi = (-j22 * rx + j12 * ry) / det
        deta = (j21 * rx - j11 * ry) / det
        xi += dxi
        eta += deta
        if not np.isfinite(xi) or not np.isfinite(eta):
            if _proj_diag_enabled():
                _proj_diag_log(
                    "NAN_INF",
                    iters=iters,
                    res_norm=res_norm,
                    delta_norm=float(np.hypot(dxi, deta)),
                    detJ=detJ,
                    point=point,
                    local=np.array([xi, eta], dtype=float),
                    in_ref_domain=False,
                )
            return np.zeros((4,), dtype=float), 0.0, 0.0
    else:
        # Iteration budget exhausted: the last Newton step moved (xi, eta), so
        # re-evaluate N and the residual there. Otherwise the returned shape
        # values would belong to the previous iterate while (xi, eta) and the
        # domain check below use the new one.
        N = bilinear(xi, eta)
        rx, ry = residual(N)
        res_norm = float(np.hypot(rx, ry))

    in_ref = max(abs(xi), abs(eta)) <= 1.0 + tol
    if _proj_diag_enabled() and (not in_ref or res_norm > tol):
        code = "OUTSIDE_DOMAIN" if not in_ref else "NEWTON_NO_CONVERGE"
        _proj_diag_log(
            code,
            iters=iters,
            res_norm=res_norm,
            delta_norm=None,
            detJ=detJ,
            point=point,
            local=np.array([xi, eta], dtype=float),
            in_ref_domain=in_ref,
        )

    return N, xi, eta
670
+
671
+
672
def _quad_shape_values(
    point: np.ndarray,
    facet_nodes: np.ndarray,
    coords: np.ndarray,
    *,
    tol: float,
) -> np.ndarray:
    """Bilinear quad4 shape values at ``point``; the local (xi, eta) is discarded."""
    return _quad_shape_and_local(point, facet_nodes, coords, tol=tol)[0]
681
+
682
+
683
def _quad8_shape_values(xi: float, eta: float) -> np.ndarray:
    """Serendipity quad8 shape functions at (xi, eta) in [-1, 1]^2.

    Node order: corners (-1,-1), (1,-1), (1,1), (-1,1), then the mid-edge
    nodes on eta = -1, xi = +1, eta = +1, xi = -1.
    """
    xm = 1.0 - xi
    xp = 1.0 + xi
    em = 1.0 - eta
    ep = 1.0 + eta
    corners = [
        -0.25 * xm * em * (1.0 + xi + eta),
        -0.25 * xp * em * (1.0 - xi + eta),
        -0.25 * xp * ep * (1.0 - xi - eta),
        -0.25 * xm * ep * (1.0 + xi - eta),
    ]
    edges = [
        0.5 * (1.0 - xi * xi) * em,
        0.5 * xp * (1.0 - eta * eta),
        0.5 * (1.0 - xi * xi) * ep,
        0.5 * xm * (1.0 - eta * eta),
    ]
    return np.array(corners + edges, dtype=float)
693
+
694
+
695
def _quad9_shape_values(xi: float, eta: float) -> np.ndarray:
    """Lagrange quad9 shape functions at (xi, eta) in [-1, 1]^2.

    Nodes are numbered row by row, with eta = -1, 0, 1 (outer) and
    xi = -1, 0, 1 (inner). Corners are therefore indices 0, 2, 8, 6 and
    index 4 is the centre.
    """
    def lagrange_1d(t):
        return (0.5 * t * (t - 1.0), 1.0 - t * t, 0.5 * t * (t + 1.0))

    along_xi = lagrange_1d(xi)
    along_eta = lagrange_1d(eta)
    return np.array([along_xi[i] * along_eta[j] for j in range(3) for i in range(3)], dtype=float)
712
+
713
+
714
def _quad9_shape_grad_ref(xi: float, eta: float) -> np.ndarray:
    """Reference gradients of the quad9 shape functions, shape (9, 2).

    Row k is [dN_k/dxi, dN_k/deta], with node order matching
    ``_quad9_shape_values`` (eta outer, xi inner).
    """
    def values_and_slopes(t):
        values = (0.5 * t * (t - 1.0), 1.0 - t * t, 0.5 * t * (t + 1.0))
        slopes = (t - 0.5, -2.0 * t, t + 0.5)
        return values, slopes

    vx, sx = values_and_slopes(xi)
    vy, sy = values_and_slopes(eta)
    rows = [[sx[i] * vy[j], vx[i] * sy[j]] for j in range(3) for i in range(3)]
    return np.array(rows, dtype=float)
742
+
743
+
744
def _quad9_map_and_jacobian(pts: np.ndarray, xi: float, eta: float) -> tuple[np.ndarray, np.ndarray]:
    """Mapped position and tangent Jacobian of a quad9 patch at (xi, eta).

    ``pts`` holds the nine node positions in ``_quad9_shape_values`` order.
    Column 0 of the Jacobian is dx/dxi and column 1 is dx/deta; it is (3, 2)
    for 3-D node positions.
    """
    shape = _quad9_shape_values(xi, eta)
    grad_ref = _quad9_shape_grad_ref(xi, eta)
    position = shape @ pts
    jac = (grad_ref.T @ pts).T
    return position, jac
750
+
751
+
752
def _project_point_to_quad9(
    point: np.ndarray,
    pts: np.ndarray,
    *,
    tol: float,
    max_iter: int = 15,
) -> tuple[float, float, bool, np.ndarray, np.ndarray, dict]:
    """
    Project ``point`` onto a quad9 surface patch by Gauss-Newton iteration.

    Minimises ||x(xi, eta) - point||^2, where x is the quad9 map of the node
    positions ``pts`` (order as in ``_quad9_shape_values``), starting from
    (0, 0). Each step solves (J^T J) delta = -J^T (x - point) and is scaled
    down so that max(|delta|) <= 1. Iteration stops when both the (scaled)
    step and the gradient J^T (x - point) have norm below ``tol``, or after
    ``max_iter`` steps.

    Returns:
        ``(xi, eta, ok, x, J, info)``. ``ok`` is True only when the status is
        "OK", i.e. converged with (xi, eta) inside [-1-tol, 1+tol]^2. ``x`` and
        ``J`` are evaluated at the returned coordinates; on an early exit they
        are those of the iterate where the failure was detected. ``info`` is
        the ``_projection_info`` dict, with status one of "OK", "SINGULAR_H",
        "NAN_INF", "OUTSIDE_DOMAIN", "NEWTON_NO_CONVERGE".
    """
    xi0 = 0.0
    eta0 = 0.0
    xi = xi0
    eta = eta0
    last_delta = np.array([np.nan, np.nan], dtype=float)
    last_r = np.array([np.nan, np.nan], dtype=float)
    last_det = np.nan
    status = "OK"
    for _ in range(max_iter):
        x, J = _quad9_map_and_jacobian(pts, xi, eta)
        JTJ = J.T @ J
        det = float(np.linalg.det(JTJ))
        last_det = det
        # NOTE(review): det(J^T J) is compared against the absolute ``tol``; it
        # scales like length^4, so very small patches may be flagged SINGULAR_H.
        if abs(det) < tol:
            status = "SINGULAR_H"
            return xi, eta, False, x, J, _projection_info(status, xi0, eta0, xi, eta, last_r, last_delta, det, JTJ)
        # Gradient of 0.5 * ||x - point||^2 with respect to (xi, eta).
        r = J.T @ (x - point)
        last_r = r
        try:
            delta = -np.linalg.solve(JTJ, r)
        except np.linalg.LinAlgError:
            status = "SINGULAR_H"
            return xi, eta, False, x, J, _projection_info(status, xi0, eta0, xi, eta, last_r, last_delta, det, JTJ)
        if not np.all(np.isfinite(delta)):
            status = "NAN_INF"
            return xi, eta, False, x, J, _projection_info(status, xi0, eta0, xi, eta, last_r, last_delta, det, JTJ)
        # last_delta keeps the unscaled step; the scaled one is applied below.
        last_delta = delta
        step = float(np.max(np.abs(delta)))
        if step > 1.0:
            delta = delta / step
        xi += float(delta[0])
        eta += float(delta[1])
        if float(np.linalg.norm(delta)) < tol and float(np.linalg.norm(r)) < tol:
            break
    # Re-evaluate at the final coordinates so x and J match the returned (xi, eta).
    x, J = _quad9_map_and_jacobian(pts, xi, eta)
    ok = abs(xi) <= 1.0 + tol and abs(eta) <= 1.0 + tol
    if not ok:
        status = "OUTSIDE_DOMAIN"
    if status == "OK" and (float(np.linalg.norm(last_delta)) >= tol or float(np.linalg.norm(last_r)) >= tol):
        status = "NEWTON_NO_CONVERGE"
    return xi, eta, ok and status == "OK", x, J, _projection_info(status, xi0, eta0, xi, eta, last_r, last_delta, last_det, J.T @ J)
800
+
801
+
802
def _tri6_shape_values(xi: float, eta: float) -> np.ndarray:
    """Quadratic tri6 shape functions at reference point (xi, eta).

    Node order: corners (0,0), (1,0), (0,1), then the mid-edge nodes of edges
    0-1, 1-2 and 0-2.
    """
    l1 = 1.0 - xi - eta
    corner_values = [lam * (2.0 * lam - 1.0) for lam in (l1, xi, eta)]
    edge_values = [4.0 * l1 * xi, 4.0 * xi * eta, 4.0 * l1 * eta]
    return np.array(corner_values + edge_values, dtype=float)
817
+
818
+
819
def _tri6_shape_grad_ref(xi: float, eta: float) -> np.ndarray:
    """Reference gradients of the tri6 shape functions, shape (6, 2).

    Row k is [dN_k/dxi, dN_k/deta], with node order matching
    ``_tri6_shape_values``.
    """
    l1 = 1.0 - xi - eta
    g_first = -(4.0 * l1 - 1.0)
    rows = [
        [g_first, g_first],
        [4.0 * xi - 1.0, 0.0],
        [0.0, 4.0 * eta - 1.0],
        [4.0 * (l1 - xi), -4.0 * xi],
        [4.0 * eta, 4.0 * xi],
        [-4.0 * eta, 4.0 * (l1 - eta)],
    ]
    return np.array(rows, dtype=float)
830
+
831
+
832
def _tri6_map_and_jacobian(pts: np.ndarray, xi: float, eta: float) -> tuple[np.ndarray, np.ndarray]:
    """Mapped position and tangent Jacobian of a tri6 patch at (xi, eta).

    ``pts`` holds the six node positions in ``_tri6_shape_values`` order.
    Column 0 of the Jacobian is dx/dxi and column 1 is dx/deta; it is (3, 2)
    for 3-D node positions.
    """
    shape = _tri6_shape_values(xi, eta)
    grad_ref = _tri6_shape_grad_ref(xi, eta)
    position = shape @ pts
    jac = (grad_ref.T @ pts).T
    return position, jac
838
+
839
+
840
def _projection_info(
    status: str,
    xi0: float,
    eta0: float,
    xi: float,
    eta: float,
    r: np.ndarray,
    delta: np.ndarray,
    det: float,
    JTJ: np.ndarray,
) -> dict:
    """Summarise a point-projection attempt as a plain dict.

    Keys: status, start/final local coordinates, the norms of the last
    residual and step (NaN for empty arrays), det, and the condition number
    of ``JTJ`` (NaN when ``JTJ`` is empty or contains non-finite values).
    """
    def norm_or_nan(vec: np.ndarray) -> float:
        return float(np.linalg.norm(vec)) if vec.size else float("nan")

    usable = JTJ.size and np.isfinite(JTJ).all()
    condition = float(np.linalg.cond(JTJ)) if usable else float("nan")
    info = {
        "status": status,
        "xi0": float(xi0),
        "eta0": float(eta0),
        "xi": float(xi),
        "eta": float(eta),
    }
    info["r_norm"] = norm_or_nan(r)
    info["d_norm"] = norm_or_nan(delta)
    info["det"] = float(det)
    info["cond"] = condition
    return info
865
+
866
+
867
def _project_point_to_tri6(
    point: np.ndarray,
    pts: np.ndarray,
    *,
    tol: float,
    max_iter: int = 15,
) -> tuple[float, float, bool, np.ndarray, np.ndarray, dict]:
    """
    Project ``point`` onto a tri6 surface patch by Gauss-Newton iteration.

    Minimises ||x(xi, eta) - point||^2, where x is the tri6 map of the node
    positions ``pts`` (order as in ``_tri6_shape_values``), starting from the
    centroid (1/3, 1/3). Each step solves (J^T J) delta = -J^T (x - point) and
    is scaled down so that max(|delta|) <= 1. Iteration stops when both the
    (scaled) step and the gradient J^T (x - point) have norm below ``tol``, or
    after ``max_iter`` steps.

    Returns:
        ``(xi, eta, ok, x, J, info)``. ``ok`` is True only when the status is
        "OK", i.e. converged with xi >= -tol, eta >= -tol and
        xi + eta <= 1 + tol. ``x`` and ``J`` are evaluated at the returned
        coordinates; on an early exit they are those of the iterate where the
        failure was detected. ``info`` is the ``_projection_info`` dict, with
        status one of "OK", "SINGULAR_H", "NAN_INF", "OUTSIDE_DOMAIN",
        "NEWTON_NO_CONVERGE".
    """
    xi0 = 1.0 / 3.0
    eta0 = 1.0 / 3.0
    xi = xi0
    eta = eta0
    last_delta = np.array([np.nan, np.nan], dtype=float)
    last_r = np.array([np.nan, np.nan], dtype=float)
    last_det = np.nan
    status = "OK"
    for _ in range(max_iter):
        x, J = _tri6_map_and_jacobian(pts, xi, eta)
        JTJ = J.T @ J
        det = float(np.linalg.det(JTJ))
        last_det = det
        # NOTE(review): det(J^T J) is compared against the absolute ``tol``; it
        # scales like length^4, so very small patches may be flagged SINGULAR_H.
        if abs(det) < tol:
            status = "SINGULAR_H"
            return xi, eta, False, x, J, _projection_info(status, xi0, eta0, xi, eta, last_r, last_delta, det, JTJ)
        # Gradient of 0.5 * ||x - point||^2 with respect to (xi, eta).
        r = J.T @ (x - point)
        last_r = r
        try:
            delta = -np.linalg.solve(JTJ, r)
        except np.linalg.LinAlgError:
            status = "SINGULAR_H"
            return xi, eta, False, x, J, _projection_info(status, xi0, eta0, xi, eta, last_r, last_delta, det, JTJ)
        if not np.all(np.isfinite(delta)):
            status = "NAN_INF"
            return xi, eta, False, x, J, _projection_info(status, xi0, eta0, xi, eta, last_r, last_delta, det, JTJ)
        # last_delta keeps the unscaled step; the scaled one is applied below.
        last_delta = delta
        step = float(np.max(np.abs(delta)))
        if step > 1.0:
            delta = delta / step
        xi += float(delta[0])
        eta += float(delta[1])
        if float(np.linalg.norm(delta)) < tol and float(np.linalg.norm(r)) < tol:
            break
    # Re-evaluate at the final coordinates so x and J match the returned (xi, eta).
    x, J = _tri6_map_and_jacobian(pts, xi, eta)
    ok = xi >= -tol and eta >= -tol and (xi + eta) <= 1.0 + tol
    if not ok:
        status = "OUTSIDE_DOMAIN"
    if status == "OK" and (float(np.linalg.norm(last_delta)) >= tol or float(np.linalg.norm(last_r)) >= tol):
        status = "NEWTON_NO_CONVERGE"
    return xi, eta, ok and status == "OK", x, J, _projection_info(status, xi0, eta0, xi, eta, last_r, last_delta, last_det, J.T @ J)
915
+
916
+
917
def _facet_shape_values(
    point: np.ndarray,
    facet_nodes: np.ndarray,
    coords: np.ndarray,
    *,
    tol: float,
) -> np.ndarray:
    """
    Evaluate nodal shape values on a facet at a point.

    Supported facets, by node count:
      * 3 (tri3): barycentric coordinates of ``point`` w.r.t. the corners
        (all zeros if the triangle is degenerate; no inside check).
      * 6 (tri6): quadratic tri6 values built from the corner barycentric
        coordinates, so the facet is treated as straight-sided. All zeros if
        the triangle is degenerate or any coordinate is below ``-tol``.
      * 4 (quad4): bilinear values from a Newton inversion of the quad map.
      * 8 (quad8) / 9 (quad9): local (xi, eta) from the bilinear inversion of
        the four corners, then serendipity / Lagrange values there. All zeros
        if that inversion failed.

    Raises:
        ValueError: for any other node count.
    """
    pts = coords[facet_nodes]
    n_nodes = len(facet_nodes)
    if n_nodes in (3, 6):
        lam = _barycentric(point, pts[0], pts[1], pts[2])
        if n_nodes == 3:
            return np.zeros((3,), dtype=float) if lam is None else lam
        if lam is None or np.any(lam < -tol):
            return np.zeros((6,), dtype=float)
        # _tri6_shape_values rebuilds L1 = 1 - L2 - L3, exactly as _barycentric did.
        return _tri6_shape_values(lam[1], lam[2])
    if n_nodes == 4:
        return _quad_shape_values(point, facet_nodes, coords, tol=tol)
    if n_nodes in (8, 9):
        # quad9 corners sit at indices 0, 2, 8, 6 of its 3x3 node grid.
        corner_nodes = facet_nodes[:4] if n_nodes == 8 else facet_nodes[[0, 2, 8, 6]]
        bilinear, xi, eta = _quad_shape_and_local(point, corner_nodes, coords, tol=tol)
        if np.allclose(bilinear, 0.0):
            return np.zeros((n_nodes,), dtype=float)
        if n_nodes == 8:
            return _quad8_shape_values(xi, eta)
        return _quad9_shape_values(xi, eta)
    raise ValueError("facet must be a triangle or quad")
964
+
965
+
966
+ def _gather_u_local(u_field: np.ndarray, nodes: np.ndarray, value_dim: int) -> np.ndarray:
967
+ if value_dim == 1:
968
+ return u_field[nodes]
969
+ idx = np.repeat(nodes * value_dim, value_dim) + np.tile(np.arange(value_dim), len(nodes))
970
+ return u_field[idx]
971
+
972
+
973
+ def _global_dof_indices(nodes: np.ndarray, value_dim: int, offset: int) -> np.ndarray:
974
+ if value_dim == 1:
975
+ return offset + nodes
976
+ idx = np.repeat(nodes * value_dim, value_dim) + np.tile(np.arange(value_dim), len(nodes))
977
+ return offset + idx
978
+
979
+
980
def map_surface_facets_to_tet_elements(surface: SurfaceMesh, tet_conn: np.ndarray) -> np.ndarray:
    """
    Map surface triangle facets to parent tet elements by node matching (tet4/tet10).

    Each tetrahedron registers its four faces under sorted node-id tuples:
    once by the three corner nodes and, for tet10, once by all six face nodes
    (corners plus mid-edge nodes). A surface facet is assigned the first
    element, in connectivity order, that registered its sorted node tuple.

    Returns:
        Integer array with one entry per row of ``surface.conn``; ``-1`` marks
        facets without a matching element face.

    Raises:
        NotImplementedError: if ``tet_conn`` rows do not have 4 or 10 nodes.
    """
    tet_conn = np.asarray(tet_conn, dtype=int)
    if tet_conn.shape[1] not in {4, 10}:
        raise NotImplementedError("Only tet4 and tet10 are supported.")

    corner_faces = ((0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3))
    # Corner ids followed by the local ids of the three mid-edge nodes on that face.
    quadratic_faces = (
        (0, 1, 2, 4, 5, 6),
        (0, 1, 3, 4, 8, 7),
        (0, 2, 3, 6, 9, 7),
        (1, 2, 3, 5, 9, 8),
    )

    def sorted_key(node_ids) -> tuple[int, ...]:
        return tuple(sorted(int(v) for v in node_ids))

    by_corners: dict[tuple[int, ...], int] = {}
    by_all_nodes: dict[tuple[int, ...], int] = {}
    for elem_id, elem in enumerate(tet_conn):
        for face in corner_faces:
            by_corners.setdefault(sorted_key(elem[list(face)]), elem_id)
        if elem.shape[0] == 10:
            for face in quadratic_faces:
                by_all_nodes.setdefault(sorted_key(elem[list(face)]), elem_id)

    result = np.full((surface.conn.shape[0],), -1, dtype=int)
    for facet_id, facet in enumerate(np.asarray(surface.conn, dtype=int)):
        key = sorted_key(facet)
        if len(facet) == 6 and key in by_all_nodes:
            result[facet_id] = by_all_nodes[key]
        elif key in by_corners:
            result[facet_id] = by_corners[key]
    return result
1019
+
1020
+
1021
def map_surface_facets_to_hex_elements(surface: SurfaceMesh, hex_conn: np.ndarray) -> np.ndarray:
    """
    Map surface quad facets to parent hex elements by node matching (hex8/hex20/hex27).

    Each hexahedron registers its six faces under sorted node-id tuples: once
    by the four corner nodes and, for hex20/hex27, once by all face nodes
    (8 or 9). Because keys are sorted, facet orientation is irrelevant; when
    several elements share a face, the first one in connectivity order wins.

    Returns:
        Integer array with one entry per row of ``surface.conn``; ``-1`` marks
        facets without a matching element face.

    Raises:
        NotImplementedError: if ``hex_conn`` rows do not have 8, 20 or 27 nodes.
    """
    hex_conn = np.asarray(hex_conn, dtype=int)
    nodes_per_elem = hex_conn.shape[1]
    if nodes_per_elem not in {8, 20, 27}:
        raise NotImplementedError("Only hex8/hex20/hex27 are supported.")

    hex8_corner_faces = (
        (0, 1, 2, 3),
        (4, 5, 6, 7),
        (0, 1, 5, 4),
        (1, 2, 6, 5),
        (2, 3, 7, 6),
        (3, 0, 4, 7),
    )
    # hex27 nodes appear to follow a 3x3x3 grid, so its corners are 0, 2, 6, 8, 18, 20, 24, 26.
    hex27_corner_faces = (
        (0, 2, 8, 6),
        (18, 20, 26, 24),
        (0, 2, 20, 18),
        (6, 8, 26, 24),
        (0, 6, 24, 18),
        (2, 8, 26, 20),
    )
    hex20_full_faces = (
        (0, 1, 2, 3, 8, 9, 10, 11),
        (4, 5, 6, 7, 12, 13, 14, 15),
        (0, 1, 5, 4, 8, 17, 12, 16),
        (1, 2, 6, 5, 9, 18, 13, 17),
        (2, 3, 7, 6, 10, 19, 14, 18),
        (3, 0, 4, 7, 11, 16, 15, 19),
    )
    hex27_full_faces = (
        (0, 1, 2, 3, 4, 5, 6, 7, 8),
        (18, 19, 20, 21, 22, 23, 24, 25, 26),
        (0, 1, 2, 9, 10, 11, 18, 19, 20),
        (6, 7, 8, 15, 16, 17, 24, 25, 26),
        (0, 3, 6, 9, 12, 15, 18, 21, 24),
        (2, 5, 8, 11, 14, 17, 20, 23, 26),
    )
    corner_faces, full_faces = {
        8: (hex8_corner_faces, ()),
        20: (hex8_corner_faces, hex20_full_faces),
        27: (hex27_corner_faces, hex27_full_faces),
    }[nodes_per_elem]

    def face_key(elem, local_ids) -> tuple[int, ...]:
        return tuple(sorted(int(elem[i]) for i in local_ids))

    corner_lookup: dict[tuple[int, ...], int] = {}
    full_lookup: dict[tuple[int, ...], int] = {}
    for elem_id, elem in enumerate(hex_conn):
        for local_ids in corner_faces:
            corner_lookup.setdefault(face_key(elem, local_ids), elem_id)
        for local_ids in full_faces:
            full_lookup.setdefault(face_key(elem, local_ids), elem_id)

    facet_map = np.full((surface.conn.shape[0],), -1, dtype=int)
    for facet_id, facet in enumerate(np.asarray(surface.conn, dtype=int)):
        key = tuple(sorted(int(v) for v in facet))
        if len(facet) in (8, 9) and key in full_lookup:
            facet_map[facet_id] = full_lookup[key]
        elif key in corner_lookup:
            facet_map[facet_id] = corner_lookup[key]
    return facet_map
1090
+
1091
+
1092
+ def _tet_shape_values(point: np.ndarray, elem_coords: np.ndarray, *, tol: float) -> np.ndarray:
1093
+ corner_coords = elem_coords[:4]
1094
+ M = np.stack([corner_coords[:, 0], corner_coords[:, 1], corner_coords[:, 2], np.ones(4)], axis=1)
1095
+ rhs = np.array([point[0], point[1], point[2], 1.0], dtype=float)
1096
+ try:
1097
+ lam = np.linalg.solve(M.T, rhs)
1098
+ except np.linalg.LinAlgError:
1099
+ return np.zeros((elem_coords.shape[0],), dtype=float)
1100
+ if np.any(lam < -tol):
1101
+ return np.zeros((elem_coords.shape[0],), dtype=float)
1102
+ if elem_coords.shape[0] == 4:
1103
+ return lam
1104
+ if elem_coords.shape[0] != 10:
1105
+ raise NotImplementedError("tet shape evaluation supports tet4/tet10 only")
1106
+ L1, L2, L3, L4 = lam
1107
+ N1 = L1 * (2.0 * L1 - 1.0)
1108
+ N2 = L2 * (2.0 * L2 - 1.0)
1109
+ N3 = L3 * (2.0 * L3 - 1.0)
1110
+ N4 = L4 * (2.0 * L4 - 1.0)
1111
+ N5 = 4.0 * L1 * L2
1112
+ N6 = 4.0 * L2 * L3
1113
+ N7 = 4.0 * L1 * L3
1114
+ N8 = 4.0 * L1 * L4
1115
+ N9 = 4.0 * L2 * L4
1116
+ N10 = 4.0 * L3 * L4
1117
+ return np.array([N1, N2, N3, N4, N5, N6, N7, N8, N9, N10], dtype=float)
1118
+
1119
+
1120
+ def _tet_gradN(elem_coords: np.ndarray, *, point: np.ndarray | None = None, tol: float) -> np.ndarray:
1121
+ corner_coords = elem_coords[:4]
1122
+ M = np.stack([corner_coords[:, 0], corner_coords[:, 1], corner_coords[:, 2], np.ones(4)], axis=1)
1123
+ try:
1124
+ invM = np.linalg.inv(M)
1125
+ except np.linalg.LinAlgError:
1126
+ return np.zeros((elem_coords.shape[0], 3), dtype=float)
1127
+ dL = invM[:3, :].T
1128
+ if elem_coords.shape[0] == 4:
1129
+ return dL
1130
+ if elem_coords.shape[0] != 10:
1131
+ raise NotImplementedError("tet grad evaluation supports tet4/tet10 only")
1132
+ if point is None:
1133
+ raise ValueError("tet10 grad evaluation requires point")
1134
+ rhs = np.array([point[0], point[1], point[2], 1.0], dtype=float)
1135
+ try:
1136
+ lam = np.linalg.solve(M.T, rhs)
1137
+ except np.linalg.LinAlgError:
1138
+ return np.zeros((10, 3), dtype=float)
1139
+ if np.any(lam < -tol):
1140
+ return np.zeros((10, 3), dtype=float)
1141
+ L1, L2, L3, L4 = lam
1142
+ dL1, dL2, dL3, dL4 = dL
1143
+ dN1 = (4.0 * L1 - 1.0) * dL1
1144
+ dN2 = (4.0 * L2 - 1.0) * dL2
1145
+ dN3 = (4.0 * L3 - 1.0) * dL3
1146
+ dN4 = (4.0 * L4 - 1.0) * dL4
1147
+ dN5 = 4.0 * (L2 * dL1 + L1 * dL2)
1148
+ dN6 = 4.0 * (L3 * dL2 + L2 * dL3)
1149
+ dN7 = 4.0 * (L3 * dL1 + L1 * dL3)
1150
+ dN8 = 4.0 * (L4 * dL1 + L1 * dL4)
1151
+ dN9 = 4.0 * (L4 * dL2 + L2 * dL4)
1152
+ dN10 = 4.0 * (L4 * dL3 + L3 * dL4)
1153
+ return np.vstack([dN1, dN2, dN3, dN4, dN5, dN6, dN7, dN8, dN9, dN10])
1154
+
1155
+
1156
+ def _tet_gradN_at_points(
1157
+ points: np.ndarray,
1158
+ elem_coords: np.ndarray,
1159
+ *,
1160
+ local: np.ndarray | None = None,
1161
+ tol: float,
1162
+ ) -> np.ndarray:
1163
+ n_nodes = elem_coords.shape[0]
1164
+ if n_nodes == 4:
1165
+ grad = _tet_gradN(elem_coords, tol=tol)
1166
+ grad_q = np.repeat(grad[None, :, :], points.shape[0], axis=0)
1167
+ elif n_nodes == 10:
1168
+ grad_q = np.array([_tet_gradN(elem_coords, point=pt, tol=tol) for pt in points], dtype=float)
1169
+ elif n_nodes == 8:
1170
+ grad_q = np.array([_hex8_gradN(pt, elem_coords, tol=tol) for pt in points], dtype=float)
1171
+ elif n_nodes == 20:
1172
+ grad_q = np.array([_hex20_gradN(pt, elem_coords, tol=tol) for pt in points], dtype=float)
1173
+ elif n_nodes == 27:
1174
+ grad_q = np.array([_hex27_gradN(pt, elem_coords, tol=tol) for pt in points], dtype=float)
1175
+ else:
1176
+ raise NotImplementedError("volume grad evaluation supports tet4/tet10/hex8/hex20/hex27 only")
1177
+ if local is not None:
1178
+ grad_q = grad_q[:, local, :]
1179
+ return grad_q
1180
+
1181
+
1182
def _hex8_shape_and_local(
    point: np.ndarray,
    elem_coords: np.ndarray,
    *,
    tol: float,
) -> tuple[np.ndarray, float, float, float]:
    """
    Locate ``point`` in a trilinear hex8 element and evaluate its shape functions.

    Newton's method inverts x(xi, eta, zeta) = sum_i N_i * X_i. It starts from
    the reference centre, runs at most 12 iterations, and stops once the
    physical residual norm drops below ``tol``.

    Returns:
        ``(N, xi, eta, zeta)`` with ``N`` of shape (8,). The failure cases are:

        * Singular Jacobian or non-finite iterate: ``N`` is zeros and all local
          coordinates are 0.0.
        * An iterate leaves ``[-1 - tol, 1 + tol]^3``: ``N`` is zeros and the
          offending local coordinates are returned.

    Failures are reported through the ``_proj_diag_*`` hooks whenever
    ``_proj_diag_enabled()`` is true.

    NOTE(review): if the iteration budget runs out without reaching ``tol``,
    the failure is only logged, using the residual from the start of the last
    iteration. The shape values at the last iterate are still returned.
    """
    if _proj_diag_enabled():
        _proj_diag_attempt()

    corner_signs = np.array(
        [
            [-1.0, -1.0, -1.0],
            [1.0, -1.0, -1.0],
            [1.0, 1.0, -1.0],
            [-1.0, 1.0, -1.0],
            [-1.0, -1.0, 1.0],
            [1.0, -1.0, 1.0],
            [1.0, 1.0, 1.0],
            [-1.0, 1.0, 1.0],
        ],
        dtype=float,
    )
    sx, sy, sz = corner_signs[:, 0], corner_signs[:, 1], corner_signs[:, 2]

    def trilinear(a: float, b: float, c: float) -> np.ndarray:
        return 0.125 * (1.0 + a * sx) * (1.0 + b * sy) * (1.0 + c * sz)

    def report(code, iters, res_norm, delta_norm, det_j, ref, inside) -> None:
        if _proj_diag_enabled():
            _proj_diag_log(
                code,
                iters=iters,
                res_norm=res_norm,
                delta_norm=delta_norm,
                detJ=det_j,
                point=point,
                local=np.array(ref, dtype=float),
                in_ref_domain=inside,
            )

    xi = eta = zeta = 0.0
    res_norm = 0.0
    det_j = None
    iters = 0
    while iters < 12:
        iters += 1
        residual = trilinear(xi, eta, zeta) @ elem_coords - point
        res_norm = float(np.linalg.norm(residual))
        if res_norm < tol:
            break
        d_xi = 0.125 * sx * (1.0 + eta * sy) * (1.0 + zeta * sz)
        d_eta = 0.125 * sy * (1.0 + xi * sx) * (1.0 + zeta * sz)
        d_zeta = 0.125 * sz * (1.0 + xi * sx) * (1.0 + eta * sy)
        # Column a holds dx/d(xi_a).
        jac = np.column_stack([d_xi @ elem_coords, d_eta @ elem_coords, d_zeta @ elem_coords])
        det_j = float(np.linalg.det(jac))
        try:
            step = np.linalg.solve(jac, residual)
        except np.linalg.LinAlgError:
            report("SINGULAR_H", iters, res_norm, None, det_j, (xi, eta, zeta), False)
            return np.zeros((8,), dtype=float), 0.0, 0.0, 0.0
        step_norm = float(np.linalg.norm(step))
        xi -= float(step[0])
        eta -= float(step[1])
        zeta -= float(step[2])
        if not (np.isfinite(xi) and np.isfinite(eta) and np.isfinite(zeta)):
            report("NAN_INF", iters, res_norm, step_norm, det_j, (xi, eta, zeta), False)
            return np.zeros((8,), dtype=float), 0.0, 0.0, 0.0
        if max(abs(xi), abs(eta), abs(zeta)) > 1.0 + tol:
            report("OUTSIDE_DOMAIN", iters, res_norm, None, det_j, (xi, eta, zeta), False)
            return np.zeros((8,), dtype=float), xi, eta, zeta

    if _proj_diag_enabled() and res_norm > tol:
        _proj_diag_log(
            "NEWTON_NO_CONVERGE",
            iters=iters,
            res_norm=res_norm,
            delta_norm=None,
            detJ=det_j,
            point=point,
            local=np.array([xi, eta, zeta], dtype=float),
            in_ref_domain=True,
        )
    return trilinear(xi, eta, zeta), xi, eta, zeta
1287
+
1288
+
1289
def _hex8_shape_values(point: np.ndarray, elem_coords: np.ndarray, *, tol: float) -> np.ndarray:
    """Return the (8,) trilinear shape values at ``point``; zeros if the point cannot be located."""
    return _hex8_shape_and_local(point, elem_coords, tol=tol)[0]
1292
+
1293
+
1294
def _hex8_gradN(point: np.ndarray, elem_coords: np.ndarray, *, tol: float) -> np.ndarray:
    """
    Physical gradients of the hex8 shape functions at ``point``.

    Returns:
        (8, 3) array where row i is grad N_i. It is zeros if the point cannot be
        located in the element or the Jacobian is singular.
    """
    shape_vals, xi, eta, zeta = _hex8_shape_and_local(point, elem_coords, tol=tol)
    if np.allclose(shape_vals, 0.0):
        return np.zeros((8, 3), dtype=float)
    corner_signs = np.array(
        [
            [-1.0, -1.0, -1.0],
            [1.0, -1.0, -1.0],
            [1.0, 1.0, -1.0],
            [-1.0, 1.0, -1.0],
            [-1.0, -1.0, 1.0],
            [1.0, -1.0, 1.0],
            [1.0, 1.0, 1.0],
            [-1.0, 1.0, 1.0],
        ],
        dtype=float,
    )
    sx, sy, sz = corner_signs.T
    grad_xi = 0.125 * sx * (1.0 + eta * sy) * (1.0 + zeta * sz)
    grad_eta = 0.125 * sy * (1.0 + xi * sx) * (1.0 + zeta * sz)
    grad_zeta = 0.125 * sz * (1.0 + xi * sx) * (1.0 + eta * sy)
    # Column a of the Jacobian is dx/d(xi_a); the chain rule gives grad N = dN_ref @ J^{-1}.
    jacobian = np.column_stack([grad_xi @ elem_coords, grad_eta @ elem_coords, grad_zeta @ elem_coords])
    try:
        inv_jacobian = np.linalg.inv(jacobian)
    except np.linalg.LinAlgError:
        return np.zeros((8, 3), dtype=float)
    ref_grads = np.column_stack([grad_xi, grad_eta, grad_zeta])  # (8, 3)
    return ref_grads @ inv_jacobian
1328
+
1329
+
1330
+ def _hex20_shape_ref(xi: float, eta: float, zeta: float) -> np.ndarray:
1331
+ s = np.array(
1332
+ [
1333
+ [-1.0, -1.0, -1.0],
1334
+ [1.0, -1.0, -1.0],
1335
+ [1.0, 1.0, -1.0],
1336
+ [-1.0, 1.0, -1.0],
1337
+ [-1.0, -1.0, 1.0],
1338
+ [1.0, -1.0, 1.0],
1339
+ [1.0, 1.0, 1.0],
1340
+ [-1.0, 1.0, 1.0],
1341
+ ],
1342
+ dtype=float,
1343
+ )
1344
+ sx, sy, sz = s[:, 0], s[:, 1], s[:, 2]
1345
+ term = xi * sx + eta * sy + zeta * sz - 2.0
1346
+ n_corner = 0.125 * (1.0 + sx * xi) * (1.0 + sy * eta) * (1.0 + sz * zeta) * term
1347
+
1348
+ def edge_x(sy, sz):
1349
+ return 0.25 * (1.0 - xi * xi) * (1.0 + sy * eta) * (1.0 + sz * zeta)
1350
+
1351
+ def edge_y(sx, sz):
1352
+ return 0.25 * (1.0 - eta * eta) * (1.0 + sx * xi) * (1.0 + sz * zeta)
1353
+
1354
+ def edge_z(sx, sy):
1355
+ return 0.25 * (1.0 - zeta * zeta) * (1.0 + sx * xi) * (1.0 + sy * eta)
1356
+
1357
+ n_edges = [
1358
+ edge_x(-1, -1),
1359
+ edge_y(1, -1),
1360
+ edge_x(1, -1),
1361
+ edge_y(-1, -1),
1362
+ edge_x(-1, 1),
1363
+ edge_y(1, 1),
1364
+ edge_x(1, 1),
1365
+ edge_y(-1, 1),
1366
+ edge_z(-1, -1),
1367
+ edge_z(1, -1),
1368
+ edge_z(1, 1),
1369
+ edge_z(-1, 1),
1370
+ ]
1371
+
1372
+ return np.concatenate([n_corner, np.array(n_edges, dtype=float)], axis=0)
1373
+
1374
+
1375
+ def _hex20_grad_ref(xi: float, eta: float, zeta: float) -> np.ndarray:
1376
+ s = np.array(
1377
+ [
1378
+ [-1.0, -1.0, -1.0],
1379
+ [1.0, -1.0, -1.0],
1380
+ [1.0, 1.0, -1.0],
1381
+ [-1.0, 1.0, -1.0],
1382
+ [-1.0, -1.0, 1.0],
1383
+ [1.0, -1.0, 1.0],
1384
+ [1.0, 1.0, 1.0],
1385
+ [-1.0, 1.0, 1.0],
1386
+ ],
1387
+ dtype=float,
1388
+ )
1389
+ sx, sy, sz = s[:, 0], s[:, 1], s[:, 2]
1390
+ term = xi * sx + eta * sy + zeta * sz - 2.0
1391
+
1392
+ dN_dxi_corner = (sx / 8.0) * (1.0 + sy * eta) * (1.0 + sz * zeta) * (term + (1.0 + sx * xi))
1393
+ dN_deta_corner = (sy / 8.0) * (1.0 + sx * xi) * (1.0 + sz * zeta) * (term + (1.0 + sy * eta))
1394
+ dN_dzeta_corner = (sz / 8.0) * (1.0 + sx * xi) * (1.0 + sy * eta) * (term + (1.0 + sz * zeta))
1395
+ d_corner = np.stack([dN_dxi_corner, dN_deta_corner, dN_dzeta_corner], axis=1)
1396
+
1397
+ def d_edge_x(sy_val, sz_val):
1398
+ dxi = -0.5 * xi * (1.0 + sy_val * eta) * (1.0 + sz_val * zeta)
1399
+ deta = 0.25 * (1.0 - xi * xi) * sy_val * (1.0 + sz_val * zeta)
1400
+ dzeta = 0.25 * (1.0 - xi * xi) * (1.0 + sy_val * eta) * sz_val
1401
+ return np.array([dxi, deta, dzeta], dtype=float)
1402
+
1403
+ def d_edge_y(sx_val, sz_val):
1404
+ dxi = 0.25 * (1.0 - eta * eta) * sx_val * (1.0 + sz_val * zeta)
1405
+ deta = -0.5 * eta * (1.0 + sx_val * xi) * (1.0 + sz_val * zeta)
1406
+ dzeta = 0.25 * (1.0 - eta * eta) * (1.0 + sx_val * xi) * sz_val
1407
+ return np.array([dxi, deta, dzeta], dtype=float)
1408
+
1409
+ def d_edge_z(sx_val, sy_val):
1410
+ dxi = 0.25 * (1.0 - zeta * zeta) * sx_val * (1.0 + sy_val * eta)
1411
+ deta = 0.25 * (1.0 - zeta * zeta) * (1.0 + sx_val * xi) * sy_val
1412
+ dzeta = -0.5 * zeta * (1.0 + sx_val * xi) * (1.0 + sy_val * eta)
1413
+ return np.array([dxi, deta, dzeta], dtype=float)
1414
+
1415
+ d_list = [
1416
+ d_edge_x(-1, -1),
1417
+ d_edge_y(1, -1),
1418
+ d_edge_x(1, -1),
1419
+ d_edge_y(-1, -1),
1420
+ d_edge_x(-1, 1),
1421
+ d_edge_y(1, 1),
1422
+ d_edge_x(1, 1),
1423
+ d_edge_y(-1, 1),
1424
+ d_edge_z(-1, -1),
1425
+ d_edge_z(1, -1),
1426
+ d_edge_z(1, 1),
1427
+ d_edge_z(-1, 1),
1428
+ ]
1429
+
1430
+ d_edges = np.stack(d_list, axis=0)
1431
+ return np.concatenate([d_corner, d_edges], axis=0)
1432
+
1433
+
1434
+ def _hex27_shape_ref(xi: float, eta: float, zeta: float) -> np.ndarray:
1435
+ def q1(t):
1436
+ return 0.5 * t * (t - 1.0)
1437
+
1438
+ def q2(t):
1439
+ return 1.0 - t * t
1440
+
1441
+ def q3(t):
1442
+ return 0.5 * t * (t + 1.0)
1443
+
1444
+ Nx = [q1(xi), q2(xi), q3(xi)]
1445
+ Ny = [q1(eta), q2(eta), q3(eta)]
1446
+ Nz = [q1(zeta), q2(zeta), q3(zeta)]
1447
+ out = []
1448
+ for k in range(3):
1449
+ for j in range(3):
1450
+ for i in range(3):
1451
+ out.append(Nx[i] * Ny[j] * Nz[k])
1452
+ return np.array(out, dtype=float)
1453
+
1454
+
1455
+ def _hex27_grad_ref(xi: float, eta: float, zeta: float) -> np.ndarray:
1456
+ def q1(t):
1457
+ return 0.5 * t * (t - 1.0)
1458
+
1459
+ def q2(t):
1460
+ return 1.0 - t * t
1461
+
1462
+ def q3(t):
1463
+ return 0.5 * t * (t + 1.0)
1464
+
1465
+ def dq1(t):
1466
+ return t - 0.5
1467
+
1468
+ def dq2(t):
1469
+ return -2.0 * t
1470
+
1471
+ def dq3(t):
1472
+ return t + 0.5
1473
+
1474
+ Nx = [q1(xi), q2(xi), q3(xi)]
1475
+ Ny = [q1(eta), q2(eta), q3(eta)]
1476
+ Nz = [q1(zeta), q2(zeta), q3(zeta)]
1477
+ dNx = [dq1(xi), dq2(xi), dq3(xi)]
1478
+ dNy = [dq1(eta), dq2(eta), dq3(eta)]
1479
+ dNz = [dq1(zeta), dq2(zeta), dq3(zeta)]
1480
+ out = []
1481
+ for k in range(3):
1482
+ for j in range(3):
1483
+ for i in range(3):
1484
+ dxi = dNx[i] * Ny[j] * Nz[k]
1485
+ deta = Nx[i] * dNy[j] * Nz[k]
1486
+ dzeta = Nx[i] * Ny[j] * dNz[k]
1487
+ out.append([dxi, deta, dzeta])
1488
+ return np.array(out, dtype=float)
1489
+
1490
+
1491
def _hex20_shape_and_local(
    point: np.ndarray,
    elem_coords: np.ndarray,
    *,
    tol: float,
) -> tuple[np.ndarray, float, float, float]:
    """
    Evaluate hex20 shape functions at ``point`` together with its local coordinates.

    The local coordinates are found by inverting the trilinear map of the
    eight corner nodes (``elem_coords[:8]``). This is exact only when the
    mid-edge nodes sit on straight edges (TODO confirm this holds for the
    meshes in use).

    Returns:
        ``(N, xi, eta, zeta)`` with ``N`` of shape (20,). If the corner lookup
        fails, the result is zeros with local coordinates 0.0. If the local
        point is outside ``[-1 - tol, 1 + tol]^3``, the result is zeros with
        the computed local coordinates.
    """
    corner_vals, xi, eta, zeta = _hex8_shape_and_local(point, elem_coords[:8], tol=tol)
    located = not np.allclose(corner_vals, 0.0)
    if not located:
        return np.zeros((20,), dtype=float), 0.0, 0.0, 0.0
    if max(abs(xi), abs(eta), abs(zeta)) > 1.0 + tol:
        return np.zeros((20,), dtype=float), xi, eta, zeta
    return _hex20_shape_ref(xi, eta, zeta), xi, eta, zeta
1504
+
1505
+
1506
def _hex20_shape_values(point: np.ndarray, elem_coords: np.ndarray, *, tol: float) -> np.ndarray:
    """Return the (20,) hex20 shape values at ``point``; zeros if the point cannot be located."""
    return _hex20_shape_and_local(point, elem_coords, tol=tol)[0]
1509
+
1510
+
1511
def _hex20_gradN(point: np.ndarray, elem_coords: np.ndarray, *, tol: float) -> np.ndarray:
    """
    Physical-space gradients of the hex20 shape functions at ``point``.

    Returns:
        (20, 3) array where row i is grad N_i. It is zeros when ``point``
        cannot be located in the element or the Jacobian is singular.
    """
    n, xi, eta, zeta = _hex20_shape_and_local(point, elem_coords, tol=tol)
    if np.allclose(n, 0.0):
        return np.zeros((20, 3), dtype=float)
    dN = _hex20_grad_ref(xi, eta, zeta)  # (20, 3): columns d/dxi, d/deta, d/dzeta
    # Isoparametric Jacobian J[b, a] = dx_b / dxi_a, the same convention as
    # _hex8_gradN. Note that ``dN.T @ elem_coords`` would be J^T, which gives
    # wrong gradients whenever J is not symmetric (rotated or sheared elements).
    J = elem_coords.T @ dN
    try:
        invJ = np.linalg.inv(J)
    except np.linalg.LinAlgError:
        return np.zeros((20, 3), dtype=float)
    # Chain rule: dN_ref = gradN @ J, hence gradN = dN_ref @ J^{-1}.
    return dN @ invJ
1522
+
1523
+
1524
def _hex27_shape_and_local(
    point: np.ndarray,
    elem_coords: np.ndarray,
    *,
    tol: float,
) -> tuple[np.ndarray, float, float, float]:
    """
    Evaluate hex27 shape functions at ``point`` together with its local coordinates.

    The local coordinates come from inverting the trilinear map through the
    eight grid corners. These are indices 0, 2, 8, 6, 18, 20, 26, 24, taken
    in hex8 corner order.

    Returns:
        ``(N, xi, eta, zeta)`` with ``N`` of shape (27,). If the corner lookup
        fails, the result is zeros with local coordinates 0.0. If the local
        point is outside ``[-1 - tol, 1 + tol]^3``, the result is zeros with
        the computed local coordinates.
    """
    grid_corners = np.array([0, 2, 8, 6, 18, 20, 26, 24], dtype=int)
    trilinear_vals, xi, eta, zeta = _hex8_shape_and_local(point, elem_coords[grid_corners], tol=tol)
    if np.allclose(trilinear_vals, 0.0):
        return np.zeros((27,), dtype=float), 0.0, 0.0, 0.0
    if max(abs(xi), abs(eta), abs(zeta)) > 1.0 + tol:
        return np.zeros((27,), dtype=float), xi, eta, zeta
    return _hex27_shape_ref(xi, eta, zeta), xi, eta, zeta
1539
+
1540
+
1541
def _hex27_shape_values(point: np.ndarray, elem_coords: np.ndarray, *, tol: float) -> np.ndarray:
    """Return the (27,) hex27 shape values at ``point``; zeros if the point cannot be located."""
    return _hex27_shape_and_local(point, elem_coords, tol=tol)[0]
1544
+
1545
+
1546
def _hex27_gradN(point: np.ndarray, elem_coords: np.ndarray, *, tol: float) -> np.ndarray:
    """
    Physical-space gradients of the hex27 shape functions at ``point``.

    Returns:
        (27, 3) array where row i is grad N_i. It is zeros when ``point``
        cannot be located in the element or the Jacobian is singular.
    """
    n, xi, eta, zeta = _hex27_shape_and_local(point, elem_coords, tol=tol)
    if np.allclose(n, 0.0):
        return np.zeros((27, 3), dtype=float)
    dN = _hex27_grad_ref(xi, eta, zeta)  # (27, 3): columns d/dxi, d/deta, d/dzeta
    # Isoparametric Jacobian J[b, a] = dx_b / dxi_a, the same convention as
    # _hex8_gradN. Note that ``dN.T @ elem_coords`` would be J^T, which gives
    # wrong gradients whenever J is not symmetric (rotated or sheared elements).
    J = elem_coords.T @ dN
    try:
        invJ = np.linalg.inv(J)
    except np.linalg.LinAlgError:
        return np.zeros((27, 3), dtype=float)
    # Chain rule: dN_ref = gradN @ J, hence gradN = dN_ref @ J^{-1}.
    return dN @ invJ
1557
+
1558
+
1559
+ def _volume_shape_values_at_points(
1560
+ points: np.ndarray,
1561
+ elem_coords: np.ndarray,
1562
+ *,
1563
+ tol: float,
1564
+ ) -> np.ndarray:
1565
+ n_nodes = elem_coords.shape[0]
1566
+ if n_nodes in {4, 10}:
1567
+ return np.array([_tet_shape_values(pt, elem_coords, tol=tol) for pt in points], dtype=float)
1568
+ if n_nodes == 20:
1569
+ return np.array([_hex20_shape_values(pt, elem_coords, tol=tol) for pt in points], dtype=float)
1570
+ if n_nodes == 8:
1571
+ return np.array([_hex8_shape_values(pt, elem_coords, tol=tol) for pt in points], dtype=float)
1572
+ if n_nodes == 27:
1573
+ return np.array([_hex27_shape_values(pt, elem_coords, tol=tol) for pt in points], dtype=float)
1574
+ raise NotImplementedError("volume shape evaluation supports tet4/tet10/hex8/hex20/hex27 only")
1575
+
1576
+
1577
+ def _local_indices(elem_nodes: np.ndarray, facet_nodes: np.ndarray) -> np.ndarray:
1578
+ index = {int(n): i for i, n in enumerate(elem_nodes)}
1579
+ try:
1580
+ return np.array([index[int(n)] for n in facet_nodes], dtype=int)
1581
+ except KeyError as exc:
1582
+ raise ValueError("facet nodes are not part of the element connectivity") from exc
1583
+
1584
+
1585
def _surface_gradN_bilinear_ref_grads(xi: float, eta: float) -> tuple[np.ndarray, np.ndarray]:
    """
    Return ``(dN/dxi, dN/deta)`` of the 4-node bilinear quad at ``(xi, eta)``.

    Corner order is (-1, -1), (1, -1), (1, 1), (-1, 1).
    """
    dN_dxi = np.array(
        [-0.25 * (1.0 - eta), 0.25 * (1.0 - eta), 0.25 * (1.0 + eta), -0.25 * (1.0 + eta)],
        dtype=float,
    )
    dN_deta = np.array(
        [-0.25 * (1.0 - xi), -0.25 * (1.0 + xi), 0.25 * (1.0 + xi), 0.25 * (1.0 - xi)],
        dtype=float,
    )
    return dN_dxi, dN_deta


def _surface_gradN_debug_print(
    tag: str,
    point: np.ndarray,
    xi: float,
    eta: float,
    values: np.ndarray,
    pts: np.ndarray,
    dN: np.ndarray,
    dX_dxi: np.ndarray,
    dX_deta: np.ndarray,
) -> None:
    """Print the quad-facet diagnostic line used when ``_surface_gradN`` debugging is on."""
    n_sum = float(values.sum())
    x_phys = values @ pts
    j_surf = float(np.linalg.norm(np.cross(dX_dxi, dX_deta)))
    print(
        f"[fluxfem][surface_gradN][{tag}]",
        f"pt={np.array2string(point, precision=6)}",
        f"xi={xi:.6f}",
        f"eta={eta:.6f}",
        f"N_sum={n_sum:.6e}",
        f"dN_dxi_sum={float(dN[:, 0].sum()):.6e}",
        f"dN_deta_sum={float(dN[:, 1].sum()):.6e}",
        f"x_phys={np.array2string(x_phys, precision=6)}",
        f"t1={np.array2string(dX_dxi, precision=6)}",
        f"t2={np.array2string(dX_deta, precision=6)}",
        f"J_surf={j_surf:.6e}",
    )


def _surface_gradN(
    point: np.ndarray,
    facet_nodes: np.ndarray,
    coords: np.ndarray,
    *,
    tol: float,
) -> np.ndarray:
    """
    Evaluate surface (tangential) gradients of a facet's shape functions at ``point``.

    The facet tangents t1 = dX/dxi and t2 = dX/deta form J = [t1 t2] (3x2).
    Reference gradients are mapped as grad N = J (J^T J)^{-1} dN_ref.

    Supported facets, by node count:
        3 -- linear triangle; gradients are constant and ``point`` is unused.
        4 -- bilinear quad; local (xi, eta) from ``_quad_shape_and_local``.
        6 -- quadratic triangle. Tangents come from the three corner nodes,
             and quadratic gradients use the chain rule on barycentrics.
        8 -- serendipity quad. Tangents come from the bilinear map of
             nodes 0-3.
        9 -- Lagrange quad with corners at indices 0, 2, 8, 6. Tangents come
             from the bilinear map of those corners.
    For 6/8/9 the facet geometry is therefore treated as straight-sided.

    Returns:
        (n_nodes, 3) array. It is zeros when the local lookup fails (tri6:
        ``_barycentric`` returns None; quad8/quad9: all-zero corner values)
        or when |det(J^T J)| < ``tol``.

    Raises:
        ValueError: for unsupported node counts.
    """
    global _DEBUG_SURFACE_GRADN_COUNT
    pts = coords[facet_nodes]
    n = len(facet_nodes)
    debug = bool(_DEBUG_SURFACE_GRADN) and _DEBUG_SURFACE_GRADN_COUNT < _DEBUG_SURFACE_GRADN_MAX
    # Linear-triangle reference derivatives: rows = corner nodes, columns = (xi, eta).
    tri_dN = np.array([[-1.0, -1.0], [1.0, 0.0], [0.0, 1.0]], dtype=float)
    lam = None

    if n == 3:
        dN = tri_dN
        dX_dxi = dN[:, 0] @ pts
        dX_deta = dN[:, 1] @ pts
    elif n == 4:
        values, xi, eta = _quad_shape_and_local(point, facet_nodes, coords, tol=tol)
        dN_dxi, dN_deta = _surface_gradN_bilinear_ref_grads(xi, eta)
        dX_dxi = dN_dxi @ pts
        dX_deta = dN_deta @ pts
        dN = np.stack([dN_dxi, dN_deta], axis=1)
        if debug:
            _surface_gradN_debug_print("quad4", point, xi, eta, values, pts, dN, dX_dxi, dX_deta)
            _DEBUG_SURFACE_GRADN_COUNT += 1
    elif n == 6:
        lam = _barycentric(point, pts[0], pts[1], pts[2])
        if lam is None:
            return np.zeros((6, 3), dtype=float)
        dN = tri_dN
        dX_dxi = dN[:, 0] @ pts[:3]
        dX_deta = dN[:, 1] @ pts[:3]
    elif n == 8:
        values, xi, eta = _quad_shape_and_local(point, facet_nodes[:4], coords, tol=tol)
        if np.allclose(values, 0.0):
            return np.zeros((8, 3), dtype=float)
        dN_dxi_c, dN_deta_c = _surface_gradN_bilinear_ref_grads(xi, eta)
        dX_dxi = dN_dxi_c @ pts[:4]
        dX_deta = dN_deta_c @ pts[:4]
        dN = np.array(
            [
                [
                    -0.25 * (1.0 - eta) * ((1.0 - xi) - (1.0 + xi + eta)),
                    -0.25 * (1.0 - xi) * ((1.0 - eta) - (1.0 + xi + eta)),
                ],
                [
                    0.25 * (1.0 - eta) * ((1.0 + xi) - (1.0 - xi + eta)),
                    -0.25 * (1.0 + xi) * ((1.0 - eta) - (1.0 - xi + eta)),
                ],
                [
                    0.25 * (1.0 + eta) * ((1.0 + xi) - (1.0 - xi - eta)),
                    0.25 * (1.0 + xi) * ((1.0 + eta) - (1.0 - xi - eta)),
                ],
                [
                    -0.25 * (1.0 + eta) * ((1.0 - xi) - (1.0 + xi - eta)),
                    0.25 * (1.0 - xi) * ((1.0 + eta) - (1.0 + xi - eta)),
                ],
                [-xi * (1.0 - eta), -0.5 * (1.0 - xi * xi)],
                [0.5 * (1.0 - eta * eta), -(1.0 + xi) * eta],
                [-xi * (1.0 + eta), 0.5 * (1.0 - xi * xi)],
                [-0.5 * (1.0 - eta * eta), -(1.0 - xi) * eta],
            ],
            dtype=float,
        )
        if debug:
            _surface_gradN_debug_print(
                "quad8", point, xi, eta, _quad8_shape_values(xi, eta), pts, dN, dX_dxi, dX_deta
            )
            _DEBUG_SURFACE_GRADN_COUNT += 1
    elif n == 9:
        corner_idx = [0, 2, 8, 6]
        values, xi, eta = _quad_shape_and_local(point, facet_nodes[corner_idx], coords, tol=tol)
        if np.allclose(values, 0.0):
            return np.zeros((9, 3), dtype=float)
        dN_dxi_c, dN_deta_c = _surface_gradN_bilinear_ref_grads(xi, eta)
        # The tangents must use the same corner nodes as the (xi, eta) solve.
        # In the 3x3 node layout, pts[:4] would mix corners with mid-side nodes.
        corner_pts = pts[corner_idx]
        dX_dxi = dN_dxi_c @ corner_pts
        dX_deta = dN_deta_c @ corner_pts
        Nx = (0.5 * xi * (xi - 1.0), 1.0 - xi * xi, 0.5 * xi * (xi + 1.0))
        Ny = (0.5 * eta * (eta - 1.0), 1.0 - eta * eta, 0.5 * eta * (eta + 1.0))
        dNx = (xi - 0.5, -2.0 * xi, xi + 0.5)
        dNy = (eta - 0.5, -2.0 * eta, eta + 0.5)
        # Node 3*j + i pairs the i-th xi basis with the j-th eta basis.
        dN = np.array(
            [[dNx[i] * Ny[j], Nx[i] * dNy[j]] for j in range(3) for i in range(3)],
            dtype=float,
        )
        if debug:
            _surface_gradN_debug_print(
                "quad9", point, xi, eta, _quad9_shape_values(xi, eta), pts, dN, dX_dxi, dX_deta
            )
            _DEBUG_SURFACE_GRADN_COUNT += 1
    else:
        raise ValueError("facet must be a triangle or quad")

    J = np.stack([dX_dxi, dX_deta], axis=1)  # (3, 2)
    JTJ = J.T @ J
    if abs(np.linalg.det(JTJ)) < tol:
        return np.zeros((n, 3), dtype=float)
    M = J @ np.linalg.inv(JTJ)  # (3, 2) pseudo-inverse transpose of J
    gradN = (M @ dN.T).T  # (n, 3)
    if n == 6:
        # Quadratic triangle: chain rule on the barycentric gradients g1..g3.
        L1, L2, L3 = lam
        g1, g2, g3 = gradN[:3]
        gradN = np.array(
            [
                (4.0 * L1 - 1.0) * g1,
                (4.0 * L2 - 1.0) * g2,
                (4.0 * L3 - 1.0) * g3,
                4.0 * (L1 * g2 + L2 * g1),
                4.0 * (L2 * g3 + L3 * g2),
                4.0 * (L1 * g3 + L3 * g1),
            ],
            dtype=float,
        )
    return gradN
1839
+
1840
+
1841
+ def _iter_supermesh_tris(coords: np.ndarray, conn: np.ndarray):
1842
+ for tri in conn:
1843
+ a, b, c = coords[tri]
1844
+ yield tri, a, b, c
1845
+
1846
+
1847
def _projection_surface_batches(
    source_facets_a: Iterable[int],
    source_facets_b: Iterable[int],
    surface_a: SurfaceMesh,
    surface_b: SurfaceMesh,
    *,
    elem_conn_a: np.ndarray | None,
    elem_conn_b: np.ndarray | None,
    facet_to_elem_a: np.ndarray | None,
    facet_to_elem_b: np.ndarray | None,
    quad_order: int,
    grad_source: str,
    dof_source: str,
    normal_source: str,
    normal_sign: float,
    tol: float,
):
    """
    Build projection-based quadrature batches for paired tri6/quad9 facets.

    Quadrature points are placed on each master facet (surface A), projected
    onto the paired slave facet (surface B), and the owning elements' volume
    shape values and gradients are evaluated at both point sets.

    Returns ``(batches, fallback)``:

    * ``(None, False)`` when this path does not apply: ``dof_source`` or
      ``grad_source`` is not ``"volume"``, the facet node counts differ or are
      not 6/9, or any element mapping is missing.
    * ``(None, True)`` when a paired facet has a negative facet-to-element id.
    * otherwise a list of per-facet-pair dicts (keys ``x_q``, ``w``, ``detJ``,
      ``Na``, ``Nb``, ``gradNa``, ``gradNb``, ``nodes_a``, ``nodes_b``,
      ``normal``) and a flag that is ``True`` if any quadrature point was
      rejected (degenerate master/slave Jacobian or failed projection).

    ``w`` already holds ``quad_weight * |dX/dxi x dX/deta|``, so ``detJ`` is
    a vector of ones.
    """
    unsupported = (None, False)
    if dof_source != "volume" or grad_source != "volume":
        return unsupported

    conn_a = np.asarray(surface_a.conn, dtype=int)
    conn_b = np.asarray(surface_b.conn, dtype=int)
    xyz_a = np.asarray(surface_a.coords, dtype=float)
    xyz_b = np.asarray(surface_b.coords, dtype=float)

    n_facet_nodes = conn_a.shape[1]
    if n_facet_nodes != conn_b.shape[1] or n_facet_nodes not in {6, 9}:
        return unsupported
    element_maps = (elem_conn_a, elem_conn_b, facet_to_elem_a, facet_to_elem_b)
    if any(m is None for m in element_maps):
        return unsupported

    diag = bool(_DEBUG_PROJECTION_DIAG)
    diag_max = _DEBUG_PROJECTION_DIAG_MAX if diag else 0
    total_points = 0
    fail_by_code: dict[str, int] = {}
    fail_samples: list[dict] = []

    # Each distinct (master, slave) facet pair is integrated once.
    pairs = {(int(fa), int(fb)) for fa, fb in zip(source_facets_a, source_facets_b)}
    if n_facet_nodes == 9:
        quad_pts, quad_w = _quad_quadrature(quad_order if quad_order > 0 else 2)
        face_type = "quad9"
        map_master = _quad9_map_and_jacobian
        project_to_slave = _project_point_to_quad9
    else:
        quad_pts, quad_w = _tri_quadrature(quad_order if quad_order > 0 else 1)
        face_type = "tri6"
        map_master = _tri6_map_and_jacobian
        project_to_slave = _project_point_to_tri6

    def _note_failure(code, info, fa, fb, elem_a, elem_b, xm):
        """Count one rejected quadrature point; keep up to ``diag_max`` samples."""
        fail_by_code[code] = fail_by_code.get(code, 0) + 1
        if diag and len(fail_samples) < diag_max:
            sample = {
                "code": code,
                "face_type": face_type,
                "fa": int(fa),
                "fb": int(fb),
                "elem_a": int(elem_a),
                "elem_b": int(elem_b),
                "xm": None if xm is None else np.array(xm, dtype=float),
            }
            if info:
                sample.update(info)
            fail_samples.append(sample)

    batches = []
    fallback = False
    for fa, fb in pairs:
        pts_a = xyz_a[conn_a[fa]]
        pts_b = xyz_b[conn_b[fb]]

        elem_a = int(facet_to_elem_a[fa])
        elem_b = int(facet_to_elem_b[fb])
        if elem_a < 0 or elem_b < 0:
            # Without an owning element the volume basis cannot be evaluated.
            return None, True
        elem_nodes_a = np.asarray(elem_conn_a[elem_a], dtype=int)
        elem_nodes_b = np.asarray(elem_conn_b[elem_b], dtype=int)
        elem_xyz_a = xyz_a[elem_nodes_a]
        elem_xyz_b = xyz_b[elem_nodes_b]

        kept_xm, kept_xs, kept_w, kept_n = [], [], [], []
        for (xi, eta), w in zip(quad_pts, quad_w):
            x_m, Jm = map_master(pts_a, xi, eta)
            _, _, ok, x_s, Js, info = project_to_slave(x_m, pts_b, tol=tol)
            total_points += 1

            n_raw = np.cross(Jm[:, 0], Jm[:, 1])
            j_surf = float(np.linalg.norm(n_raw))
            if j_surf <= tol:
                fallback = True
                _note_failure("DEGENERATE_MASTER", None, fa, fb, elem_a, elem_b, x_m)
                continue
            if not ok:
                fallback = True
                _note_failure(info.get("status", "PROJECTION_FAIL"), info, fa, fb, elem_a, elem_b, x_m)
                continue

            n_m = n_raw / j_surf
            n_use = n_m
            if normal_source in {"b", "slave", "avg"}:
                # Slave-side normal from the projected point's tangent vectors.
                n_raw_b = np.cross(Js[:, 0], Js[:, 1])
                n_norm_b = float(np.linalg.norm(n_raw_b))
                if n_norm_b <= tol:
                    fallback = True
                    _note_failure("DEGENERATE_SLAVE", None, fa, fb, elem_a, elem_b, x_m)
                    continue
                n_b = n_raw_b / n_norm_b
                if normal_source == "avg":
                    avg = n_m + n_b
                    avg_norm = float(np.linalg.norm(avg))
                    # Opposite normals cancel; keep the master normal then.
                    n_use = avg / avg_norm if avg_norm > tol else n_m
                else:
                    n_use = n_b

            kept_xm.append(x_m)
            kept_xs.append(x_s)
            kept_w.append(float(w * j_surf))
            kept_n.append(n_use)

        if not kept_xm:
            continue
        x_m_arr = np.array(kept_xm, dtype=float)
        x_s_arr = np.array(kept_xs, dtype=float)
        w_arr = np.array(kept_w, dtype=float)
        n_arr = normal_sign * np.array(kept_n, dtype=float)

        batches.append(
            {
                "x_q": x_m_arr,
                "w": w_arr,
                "detJ": np.ones_like(w_arr),
                "Na": _volume_shape_values_at_points(x_m_arr, elem_xyz_a, tol=tol),
                "Nb": _volume_shape_values_at_points(x_s_arr, elem_xyz_b, tol=tol),
                "gradNa": _tet_gradN_at_points(x_m_arr, elem_xyz_a, tol=tol),
                "gradNb": _tet_gradN_at_points(x_s_arr, elem_xyz_b, tol=tol),
                "nodes_a": elem_nodes_a,
                "nodes_b": elem_nodes_b,
                "normal": n_arr,
            }
        )

    # Every failure bumps exactly one per-code counter.
    fail_points = sum(fail_by_code.values())
    if diag and fail_points:
        print(
            "[fluxfem][proj][diag]",
            f"total={total_points}",
            f"fail={fail_points}",
            f"fallback={fallback}",
            f"face_type={face_type}",
            f"fail_by_code={fail_by_code}",
        )
        nan = float("nan")
        for idx, sample in enumerate(fail_samples):
            xm = sample.get("xm")
            xm_str = "None" if xm is None else np.array2string(xm, precision=6)
            print(
                "[fluxfem][proj][diag]",
                f"sample={idx}",
                f"code={sample.get('code')}",
                f"face={sample.get('face_type')}",
                f"fa={sample.get('fa')}",
                f"fb={sample.get('fb')}",
                f"elem_a={sample.get('elem_a')}",
                f"elem_b={sample.get('elem_b')}",
                f"xm={xm_str}",
                f"xi0={sample.get('xi0', nan):.6f}",
                f"eta0={sample.get('eta0', nan):.6f}",
                f"xi={sample.get('xi', nan):.6f}",
                f"eta={sample.get('eta', nan):.6f}",
                f"r={sample.get('r_norm', nan):.3e}",
                f"d={sample.get('d_norm', nan):.3e}",
                f"det={sample.get('det', nan):.3e}",
                f"cond={sample.get('cond', nan):.3e}",
            )

    return batches, fallback
2072
+
2073
+
2074
def assemble_mortar_matrices(
    supermesh_coords: np.ndarray,
    supermesh_conn: np.ndarray,
    source_facets_a: Iterable[int],
    source_facets_b: Iterable[int],
    surface_a: SurfaceMesh,
    surface_b: SurfaceMesh,
    *,
    tol: float = 1e-8,
) -> tuple[MortarMatrix, MortarMatrix]:
    """
    Assemble mortar coupling matrices M_aa and M_ab using centroid quadrature.

    Supermesh triangle ``k`` is paired with facet ``source_facets_a[k]`` of
    ``surface_a`` and facet ``source_facets_b[k]`` of ``surface_b``. Every
    triangle whose area exceeds ``tol`` contributes, evaluated at its centroid::

        M_aa[i, j] += area * Na_i * Na_j   (i, j: nodes of the A facet)
        M_ab[i, j] += area * Na_i * Nb_j   (i: A facet node, j: B facet node)

    Returns
    -------
    (M_aa, M_ab)
        Triplet (rows/cols/data) matrices of shape ``(n_nodes_a, n_nodes_a)``
        and ``(n_nodes_a, n_nodes_b)``. Contributions to the same (row, col)
        are emitted as separate triplets; they are not summed here.

    Raises
    ------
    ValueError
        If a facet's shape-value vector is shorter than its node list.
    """
    coords_a = np.asarray(surface_a.coords, dtype=float)
    coords_b = np.asarray(surface_b.coords, dtype=float)
    facets_a = np.asarray(surface_a.conn, dtype=int)
    facets_b = np.asarray(surface_b.conn, dtype=int)

    rows_aa: list[np.ndarray] = []
    cols_aa: list[np.ndarray] = []
    data_aa: list[np.ndarray] = []
    rows_ab: list[np.ndarray] = []
    cols_ab: list[np.ndarray] = []
    data_ab: list[np.ndarray] = []

    for (_tri, a, b, c), fa, fb in zip(
        _iter_supermesh_tris(supermesh_coords, supermesh_conn),
        source_facets_a,
        source_facets_b,
    ):
        weight = _tri_area(a, b, c)
        if weight <= tol:
            continue
        centroid = _tri_centroid(a, b, c)

        facet_a = facets_a[int(fa)]
        facet_b = facets_b[int(fb)]
        n_la = facet_a.shape[0]
        n_lb = facet_b.shape[0]
        Na = np.asarray(_facet_shape_values(centroid, facet_a, coords_a, tol=tol), dtype=float).reshape(-1)
        Nb = np.asarray(_facet_shape_values(centroid, facet_b, coords_b, tol=tol), dtype=float).reshape(-1)
        if Na.shape[0] < n_la or Nb.shape[0] < n_lb:
            raise ValueError("facet shape values do not cover every facet node")
        Na = Na[:n_la]
        Nb = Nb[:n_lb]

        # Row-major (i, j) blocks: rows repeat each node, cols tile the node list.
        # The product is formed as (weight * N_i) * N_j, the same rounding as an
        # element-by-element scalar loop.
        wNa = weight * Na
        rows_aa.append(np.repeat(facet_a, n_la))
        cols_aa.append(np.tile(facet_a, n_la))
        data_aa.append(np.outer(wNa, Na).ravel())
        rows_ab.append(np.repeat(facet_a, n_lb))
        cols_ab.append(np.tile(facet_b, n_la))
        data_ab.append(np.outer(wNa, Nb).ravel())

    def _stack(parts: list[np.ndarray], dtype) -> np.ndarray:
        """Concatenate per-triangle pieces; an empty list yields a length-0 array."""
        if not parts:
            return np.zeros((0,), dtype=dtype)
        return np.asarray(np.concatenate(parts), dtype=dtype)

    n_a = int(coords_a.shape[0])
    n_b = int(coords_b.shape[0])
    M_aa = MortarMatrix(
        rows=_stack(rows_aa, int),
        cols=_stack(cols_aa, int),
        data=_stack(data_aa, float),
        shape=(n_a, n_a),
    )
    M_ab = MortarMatrix(
        rows=_stack(rows_ab, int),
        cols=_stack(cols_ab, int),
        data=_stack(data_ab, float),
        shape=(n_a, n_b),
    )
    return M_aa, M_ab
2142
+
2143
+
2144
def assemble_mixed_surface_residual(
    supermesh_coords: np.ndarray,
    supermesh_conn: np.ndarray,
    source_facets_a: Iterable[int],
    source_facets_b: Iterable[int],
    surface_a: SurfaceMesh,
    surface_b: SurfaceMesh,
    res_form,
    u_a: np.ndarray,
    u_b: np.ndarray,
    params,
    *,
    value_dim_a: int = 1,
    value_dim_b: int = 1,
    offset_a: int = 0,
    offset_b: int | None = None,
    field_a: str = "a",
    field_b: str = "b",
    elem_conn_a: np.ndarray | None = None,
    elem_conn_b: np.ndarray | None = None,
    facet_to_elem_a: np.ndarray | None = None,
    facet_to_elem_b: np.ndarray | None = None,
    normal_source: str = "master",
    normal_from: str | None = None,
    master_field: str | None = None,
    normal_sign: float = 1.0,
    grad_source: str = "volume",
    dof_source: str = "surface",
    quad_order: int = 0,
    tol: float = 1e-8,
) -> np.ndarray:
    """
    Assemble mixed surface residual over a supermesh.

    Supermesh triangle ``k`` is paired with ``source_facets_a[k]`` and
    ``source_facets_b[k]``. A one-point centroid rule is used per triangle by
    default; ``quad_order > 0`` selects ``_tri_quadrature(quad_order)``.
    When ``FLUXFEM_MORTAR_MODE=projection`` and projection batches can be
    built without fallback, facet-pair projection quadrature is used instead
    of the supermesh triangles.

    normal_source can be "master", "slave", "a", "b", or "avg"; use master_field
    to pick which field acts as the master when normal_source is "master"/"slave".
    dof_source="volume" assembles into element nodes (requires elem_conn_* mappings).

    Returns
    -------
    np.ndarray
        Residual vector of length ``offset_b + n_nodes_b * value_dim_b``.
        Field A dofs start at ``offset_a``; field B dofs start at ``offset_b``
        (default: right after field A).

    Raises
    ------
    ValueError
        For invalid ``grad_source``/``dof_source``/``normal_source``/
        ``normal_from``/``master_field`` values, missing or negative
        facet-to-element mappings, or a residual of the wrong shape.
    NotImplementedError
        If an owning element has an unsupported node count.
    """
    from ..core.forms import FieldPair

    # Materialize the facet index sequences once. The projection path consumes
    # them, and when it falls back the supermesh loop iterates them again; a
    # one-shot iterator (e.g. a generator) would already be exhausted there and
    # the residual would silently come out as all zeros. The Jacobian assembler
    # in this module materializes them the same way.
    source_facets_a = list(source_facets_a)
    source_facets_b = list(source_facets_b)

    # Validate the basis options before doing any geometric work.
    if grad_source not in {"volume", "surface"}:
        raise ValueError("grad_source must be 'volume' or 'surface'")
    if dof_source not in {"surface", "volume"}:
        raise ValueError("dof_source must be 'surface' or 'volume'")
    if dof_source == "volume" and grad_source == "surface":
        raise ValueError("dof_source 'volume' requires grad_source 'volume'")

    coords_a = np.asarray(surface_a.coords, dtype=float)
    coords_b = np.asarray(surface_b.coords, dtype=float)
    facets_a = np.asarray(surface_a.conn, dtype=int)
    facets_b = np.asarray(surface_b.conn, dtype=int)
    n_a = int(coords_a.shape[0] * value_dim_a)
    n_b = int(coords_b.shape[0] * value_dim_b)
    if offset_b is None:
        # Field B dofs follow field A dofs by default.
        offset_b = offset_a + n_a
    n_total = int(offset_b + n_b)
    R: np.ndarray = np.zeros((n_total,), dtype=float)

    trace = os.getenv("FLUXFEM_MORTAR_TRACE", "0") not in ("0", "", "false", "False")

    def _trace_time(msg: str, t0: float) -> None:
        if trace:
            print(f"{msg} dt={time.perf_counter() - t0:.3e}s", flush=True)

    t_norm = time.perf_counter()
    normals_a = None
    normals_b = None
    if hasattr(surface_a, "facet_normals"):
        normals_a = surface_a.facet_normals()
    if hasattr(surface_b, "facet_normals"):
        normals_b = surface_b.facet_normals()
    if trace:
        _trace_time("[CONTACT] normals_done", t_norm)

    # Optional skipping of sliver supermesh triangles relative to facet size.
    area_scale = float(os.getenv("FLUXFEM_SMALL_TRI_EPS_SCALE", "0.0"))
    skip_small_tri = os.getenv("FLUXFEM_SKIP_SMALL_TRI", "0") == "1" and area_scale > 0.0
    facet_area_a = None
    facet_area_b = None
    if area_scale > 0.0:
        t_area = time.perf_counter()
        facet_area_a = np.array([_facet_area_estimate(fa, coords_a) for fa in facets_a], dtype=float)
        facet_area_b = np.array([_facet_area_estimate(fb, coords_b) for fb in facets_b], dtype=float)
        if trace:
            _trace_time("[CONTACT] facet_area_done", t_area)

    # Per field: True means res_form output is summed over q as-is (no w*detJ).
    includes_measure = getattr(res_form, "_includes_measure", {})

    use_elem_a = elem_conn_a is not None and facet_to_elem_a is not None
    use_elem_b = elem_conn_b is not None and facet_to_elem_b is not None

    global _DEBUG_SURFACE_SOURCE_ONCE
    if grad_source == "surface" and not _DEBUG_SURFACE_SOURCE_ONCE:
        print("[fluxfem] using surface gradN in mortar")
        _DEBUG_SURFACE_SOURCE_ONCE = True
    proj_diag = _proj_diag_enabled()
    if proj_diag:
        _proj_diag_reset()
    diag_force = os.getenv("FLUXFEM_PROJ_DIAG_FORCE", "0") == "1"
    diag_qp_mode = os.getenv("FLUXFEM_PROJ_DIAG_QP_MODE", "").strip().lower()
    diag_qp_path = os.getenv("FLUXFEM_PROJ_DIAG_QP_PATH", "").strip()
    diag_normal = os.getenv("FLUXFEM_PROJ_DIAG_NORMAL", "").strip().lower()
    diag_facet = int(os.getenv("FLUXFEM_PROJ_DIAG_FACET", "-1"))
    diag_max_q = int(os.getenv("FLUXFEM_PROJ_DIAG_MAX_Q", "3"))
    diag_abs_detj = os.getenv("FLUXFEM_PROJ_DIAG_ABS_DETJ", "1") == "1"

    # Resolve the normal source down to "a", "b" or "avg".
    if normal_from is not None:
        if normal_from not in {"master", "slave"}:
            raise ValueError("normal_from must be 'master' or 'slave'")
        master_name = field_a if master_field is None else master_field
        if master_name not in {field_a, field_b}:
            raise ValueError("master_field must match field_a or field_b")
        if normal_from == "master":
            normal_source = "a" if master_name == field_a else "b"
        else:
            normal_source = "b" if master_name == field_a else "a"
    if diag_force and diag_normal:
        normal_source = diag_normal
    if normal_source not in {"a", "b", "avg", "master", "slave"}:
        raise ValueError("normal_source must be 'a', 'b', 'avg', 'master', or 'slave'")
    if normal_source == "master":
        normal_source = "a" if (master_field is None or master_field == field_a) else "b"
    if normal_source == "slave":
        normal_source = "b" if (master_field is None or master_field == field_a) else "a"

    def _accumulate(*, Na, Nb, gradNa, gradNb, nodes_a, nodes_b, x_q, w, detJ, normal_q, shape_msg):
        """Evaluate ``res_form`` on one quadrature batch and scatter it into ``R``."""
        field_a_obj = SurfaceMixedFormField(
            N=Na,
            gradN=gradNa,
            value_dim=value_dim_a,
            basis=_SurfaceBasis(dofs_per_node=value_dim_a),
        )
        field_b_obj = SurfaceMixedFormField(
            N=Nb,
            gradN=gradNb,
            value_dim=value_dim_b,
            basis=_SurfaceBasis(dofs_per_node=value_dim_b),
        )
        fields = {
            field_a: FieldPair(test=cast("FormFieldLike", field_a_obj), trial=cast("FormFieldLike", field_a_obj)),
            field_b: FieldPair(test=cast("FormFieldLike", field_b_obj), trial=cast("FormFieldLike", field_b_obj)),
        }
        ctx = SurfaceMixedFormContext(
            fields=fields,
            x_q=x_q,
            w=w,
            detJ=detJ,
            normal=normal_q,
            trial_fields={field_a: field_a_obj, field_b: field_b_obj},
            test_fields={field_a: field_a_obj, field_b: field_b_obj},
            unknown_fields={field_a: field_a_obj, field_b: field_b_obj},
        )
        u_elem = {
            field_a: _gather_u_local(u_a, nodes_a, value_dim_a),
            field_b: _gather_u_local(u_b, nodes_b, value_dim_b),
        }
        fe_q = res_form(ctx, u_elem, params)
        for name, nodes, value_dim, offset in (
            (field_a, nodes_a, value_dim_a, offset_a),
            (field_b, nodes_b, value_dim_b, offset_b),
        ):
            fe_field = fe_q[name]
            if fe_field.ndim != 2 or fe_field.shape[0] != ctx.x_q.shape[0]:
                raise ValueError(shape_msg)
            # Reduce with numpy on every path so the accumulation stays in
            # float64 regardless of JAX's default dtype.
            if includes_measure.get(name, False):
                fe = np.sum(np.asarray(fe_field), axis=0)
            else:
                wJ = ctx.w * ctx.detJ
                fe = np.einsum("qi,q->i", np.asarray(fe_field), wJ)
            dofs = _global_dof_indices(nodes, value_dim, int(offset))
            R[dofs] += fe

    mortar_mode = os.getenv("FLUXFEM_MORTAR_MODE", "supermesh").lower()
    if mortar_mode == "projection":
        batches, fallback = _projection_surface_batches(
            source_facets_a,
            source_facets_b,
            surface_a,
            surface_b,
            elem_conn_a=elem_conn_a,
            elem_conn_b=elem_conn_b,
            facet_to_elem_a=facet_to_elem_a,
            facet_to_elem_b=facet_to_elem_b,
            quad_order=quad_order,
            grad_source=grad_source,
            dof_source=dof_source,
            normal_source=normal_source,
            normal_sign=normal_sign,
            tol=tol,
        )
        if batches is not None and not fallback:
            for batch in batches:
                _accumulate(
                    Na=batch["Na"],
                    Nb=batch["Nb"],
                    gradNa=batch["gradNa"],
                    gradNb=batch["gradNb"],
                    nodes_a=batch["nodes_a"],
                    nodes_b=batch["nodes_b"],
                    x_q=batch["x_q"],
                    w=batch["w"],
                    detJ=batch["detJ"],
                    normal_q=batch["normal"],
                    shape_msg="mixed surface residual must return (n_q, n_ldofs)",
                )
            return R
        # Otherwise fall through to the supermesh path.

    # Base rule on the reference triangle (0,0)-(1,0)-(0,1); it is loop-invariant.
    if quad_order <= 0:
        # One-point centroid rule; the weight equals the reference triangle area.
        base_quad_pts = np.array([[1.0 / 3.0, 1.0 / 3.0]], dtype=float)
        base_quad_w = np.array([0.5], dtype=float)
    else:
        base_quad_pts, base_quad_w = _tri_quadrature(quad_order)

    for (_tri, a, b, c), fa, fb in zip(
        _iter_supermesh_tris(supermesh_coords, supermesh_conn),
        source_facets_a,
        source_facets_b,
    ):
        area = _tri_area(a, b, c)
        if area <= tol:
            continue
        if skip_small_tri and facet_area_a is not None and facet_area_b is not None:
            area_ref = max(float(facet_area_a[int(fa)]), float(facet_area_b[int(fb)]))
            if area_ref > 0.0 and area < area_scale * area_ref:
                continue
        # Determinant of the affine map x = a + r (b - a) + s (c - a).
        detJ = 2.0 * area
        if diag_force and diag_abs_detj:
            detJ = abs(detJ)
        quad_pts, quad_w = base_quad_pts, base_quad_w
        quad_source = "fluxfem"
        quad_override = _diag_quad_override(diag_force, diag_qp_mode, diag_qp_path)
        if quad_override is not None:
            quad_pts, quad_w = quad_override
            quad_source = _DEBUG_PROJ_QP_SOURCE or "override"
        _diag_quad_dump(diag_force, diag_qp_mode, diag_qp_path, quad_pts, quad_w)

        facet_a = facets_a[int(fa)]
        facet_b = facets_b[int(fb)]
        x_q = np.array([a + r * (b - a) + s * (c - a) for r, s in quad_pts], dtype=float)

        gradNa = None
        gradNb = None
        nodes_a = facet_a
        nodes_b = facet_b

        elem_id_a = -1
        elem_nodes_a = None
        elem_coords_a = None
        if use_elem_a:
            assert elem_conn_a is not None
            assert facet_to_elem_a is not None
            elem_id_a = int(facet_to_elem_a[int(fa)])
            if elem_id_a < 0:
                raise ValueError("facet_to_elem_a has invalid mapping")
            elem_nodes_a = np.asarray(elem_conn_a[elem_id_a], dtype=int)
            elem_coords_a = coords_a[elem_nodes_a]
            if elem_coords_a.shape[0] not in {4, 8, 10, 20, 27}:
                raise NotImplementedError("surface sym_grad is implemented for tet4/tet10/hex8/hex20/hex27 only")

        elem_id_b = -1
        elem_nodes_b = None
        elem_coords_b = None
        if use_elem_b:
            assert elem_conn_b is not None
            assert facet_to_elem_b is not None
            elem_id_b = int(facet_to_elem_b[int(fb)])
            if elem_id_b < 0:
                raise ValueError("facet_to_elem_b has invalid mapping")
            elem_nodes_b = np.asarray(elem_conn_b[elem_id_b], dtype=int)
            elem_coords_b = coords_b[elem_nodes_b]
            if elem_coords_b.shape[0] not in {4, 8, 10, 20, 27}:
                raise NotImplementedError("surface sym_grad is implemented for tet4/tet10/hex8/hex20/hex27 only")
        if proj_diag:
            _proj_diag_set_context(
                fa=int(fa),
                fb=int(fb),
                face_a=_facet_label(facet_a),
                face_b=_facet_label(facet_b),
                elem_a=elem_id_a,
                elem_b=elem_id_b,
            )

        if dof_source == "volume":
            if not use_elem_a or elem_nodes_a is None or elem_coords_a is None:
                raise ValueError("dof_source 'volume' requires elem_conn_a and facet_to_elem_a")
            if not use_elem_b or elem_nodes_b is None or elem_coords_b is None:
                raise ValueError("dof_source 'volume' requires elem_conn_b and facet_to_elem_b")
            nodes_a = elem_nodes_a
            nodes_b = elem_nodes_b
            # grad_source is necessarily "volume" here (validated above), so only
            # the full-element gradients are needed; facet-local ones would be
            # discarded anyway.
            gradNa = _tet_gradN_at_points(x_q, elem_coords_a, tol=tol)
            gradNb = _tet_gradN_at_points(x_q, elem_coords_b, tol=tol)
            Na = _volume_shape_values_at_points(x_q, elem_coords_a, tol=tol)
            Nb = _volume_shape_values_at_points(x_q, elem_coords_b, tol=tol)
        else:
            if grad_source == "surface":
                gradNa = np.array(
                    [_surface_gradN(pt, facet_a, coords_a, tol=tol) for pt in x_q],
                    dtype=float,
                )
                gradNb = np.array(
                    [_surface_gradN(pt, facet_b, coords_b, tol=tol) for pt in x_q],
                    dtype=float,
                )
            else:
                # Volume gradients restricted to the facet's local nodes.
                if use_elem_a:
                    assert elem_nodes_a is not None
                    assert elem_coords_a is not None
                    local = _local_indices(elem_nodes_a, facet_a)
                    gradNa = _tet_gradN_at_points(x_q, elem_coords_a, local=local, tol=tol)
                if use_elem_b:
                    assert elem_nodes_b is not None
                    assert elem_coords_b is not None
                    local = _local_indices(elem_nodes_b, facet_b)
                    gradNb = _tet_gradN_at_points(x_q, elem_coords_b, local=local, tol=tol)
            Na = np.array([_facet_shape_values(pt, facet_a, coords_a, tol=tol) for pt in x_q], dtype=float)
            Nb = np.array([_facet_shape_values(pt, facet_b, coords_b, tol=tol) for pt in x_q], dtype=float)

        normal = None
        na = normals_a[int(fa)] if normals_a is not None else None
        nb = normals_b[int(fb)] if normals_b is not None else None
        if normal_source == "a":
            normal = na
        elif normal_source == "b":
            normal = nb
        else:
            if na is not None and nb is not None:
                avg = na + nb
                norm = np.linalg.norm(avg)
                # Opposite facet normals cancel; keep the A normal then.
                normal = avg / norm if norm > tol else na
            else:
                normal = na if na is not None else nb
        if normal is not None:
            normal = normal_sign * normal
        if diag_force:
            dofs_a = _global_dof_indices(nodes_a, value_dim_a, int(offset_a))
            dofs_b = _global_dof_indices(nodes_b, value_dim_b, int(offset_b))
            _diag_contact_projection(
                fa=int(fa),
                fb=int(fb),
                quad_pts=quad_pts,
                quad_w=quad_w,
                x_q=x_q,
                Na=Na,
                Nb=Nb,
                nodes_a=nodes_a,
                nodes_b=nodes_b,
                dofs_a=dofs_a,
                dofs_b=dofs_b,
                elem_coords_a=elem_coords_a if dof_source == "volume" else None,
                elem_coords_b=elem_coords_b if dof_source == "volume" else None,
                na=na,
                nb=nb,
                normal=normal,
                normal_source=normal_source,
                normal_sign=normal_sign,
                detJ=detJ,
                diag_facet=diag_facet,
                diag_max_q=diag_max_q,
                quad_source=quad_source,
                tol=tol,
            )

        # One facet normal per supermesh triangle, repeated for every quadrature point.
        normal_q = None if normal is None else np.repeat(normal[None, :], quad_pts.shape[0], axis=0)
        _accumulate(
            Na=Na,
            Nb=Nb,
            gradNa=gradNa,
            gradNb=gradNb,
            nodes_a=nodes_a,
            nodes_b=nodes_b,
            x_q=x_q,
            w=quad_w,
            detJ=np.array([detJ], dtype=float),
            normal_q=normal_q,
            shape_msg="surface residual must return shape (n_q, n_ldofs) per field",
        )
    if proj_diag:
        _proj_diag_report()
    return R
2555
+
2556
+
2557
+ def assemble_mixed_surface_jacobian(
2558
+ supermesh_coords: np.ndarray,
2559
+ supermesh_conn: np.ndarray,
2560
+ source_facets_a: Iterable[int],
2561
+ source_facets_b: Iterable[int],
2562
+ surface_a: SurfaceMesh,
2563
+ surface_b: SurfaceMesh,
2564
+ res_form,
2565
+ u_a: np.ndarray,
2566
+ u_b: np.ndarray,
2567
+ params,
2568
+ *,
2569
+ value_dim_a: int = 1,
2570
+ value_dim_b: int = 1,
2571
+ offset_a: int = 0,
2572
+ offset_b: int | None = None,
2573
+ field_a: str = "a",
2574
+ field_b: str = "b",
2575
+ elem_conn_a: np.ndarray | None = None,
2576
+ elem_conn_b: np.ndarray | None = None,
2577
+ facet_to_elem_a: np.ndarray | None = None,
2578
+ facet_to_elem_b: np.ndarray | None = None,
2579
+ normal_source: str = "master",
2580
+ normal_from: str | None = None,
2581
+ master_field: str | None = None,
2582
+ normal_sign: float = 1.0,
2583
+ grad_source: str = "volume",
2584
+ dof_source: str = "surface",
2585
+ quad_order: int = 0,
2586
+ tol: float = 1e-8,
2587
+ sparse: bool = False,
2588
+ backend: str = "jax",
2589
+ batch_jac: bool | None = None,
2590
+ fd_eps: float = 1e-6,
2591
+ fd_mode: str = "central",
2592
+ fd_block_size: int = 1,
2593
+ ):
2594
+ """
2595
+ Assemble mixed surface Jacobian over a supermesh (centroid quadrature).
2596
+
2597
+ normal_source can be "master", "slave", "a", "b", or "avg"; use master_field
2598
+ to pick which field acts as the master when normal_source is "master"/"slave".
2599
+ dof_source="volume" assembles into element nodes (requires elem_conn_* mappings).
2600
+ """
2601
+ source_facets_a = list(source_facets_a)
2602
+ source_facets_b = list(source_facets_b)
2603
+ from ..core.forms import FieldPair
2604
+ _mortar_dbg(
2605
+ f"[mortar] enter assemble_mixed_surface_jacobian quad_order={quad_order} backend={backend}"
2606
+ )
2607
+ trace = os.getenv("FLUXFEM_MORTAR_TRACE", "0") not in ("0", "", "false", "False")
2608
+ trace_max = int(os.getenv("FLUXFEM_MORTAR_TRACE_MAX", "5"))
2609
+ trace_every = int(os.getenv("FLUXFEM_MORTAR_TRACE_EVERY", "50"))
2610
+ trace_fd_max = int(os.getenv("FLUXFEM_MORTAR_TRACE_FD_MAX", "5"))
2611
+ def _trace(msg: str) -> None:
2612
+ if trace:
2613
+ print(msg, flush=True)
2614
+ def _trace_time(msg: str, t0: float) -> None:
2615
+ if trace:
2616
+ print(f"{msg}: {time.perf_counter() - t0:.6f}s", flush=True)
2617
+ t_prep = time.perf_counter()
2618
+ coords_a = np.asarray(surface_a.coords, dtype=float)
2619
+ coords_b = np.asarray(surface_b.coords, dtype=float)
2620
+ facets_a = np.asarray(surface_a.conn, dtype=int)
2621
+ facets_b = np.asarray(surface_b.conn, dtype=int)
2622
+ n_a = int(coords_a.shape[0] * value_dim_a)
2623
+ n_b = int(coords_b.shape[0] * value_dim_b)
2624
+ if offset_b is None:
2625
+ offset_b = offset_a + n_a
2626
+ n_total = int(offset_b + n_b)
2627
+ if trace:
2628
+ _trace("[CONTACT] assemble_mixed_surface_jacobian ENTER")
2629
+ _trace(f"[CONTACT] shapes: coords_a={coords_a.shape} coords_b={coords_b.shape} supermesh={supermesh_conn.shape}")
2630
+ _trace(f"[CONTACT] dtypes: coords_a={coords_a.dtype} coords_b={coords_b.dtype} supermesh={supermesh_conn.dtype}")
2631
+ _trace(f"[CONTACT] finite: coords_a={np.isfinite(coords_a).all()} coords_b={np.isfinite(coords_b).all()}")
2632
+ _trace_time("[CONTACT] prep_done", t_prep)
2633
+
2634
+ guard = os.getenv("FLUXFEM_CONTACT_GUARD", "0") == "1"
2635
+ detj_eps = float(os.getenv("FLUXFEM_CONTACT_DETJ_EPS", "0.0"))
2636
+ tri_timeout = float(os.getenv("FLUXFEM_CONTACT_TRI_TIMEOUT_S", "0.0"))
2637
+ skip_nonfinite = os.getenv("FLUXFEM_CONTACT_SKIP_NONFINITE", "1") == "1"
2638
+ if guard:
2639
+ if not (np.isfinite(coords_a).all() and np.isfinite(coords_b).all()):
2640
+ raise RuntimeError("[CONTACT] non-finite coords in contact surfaces")
2641
+ if not np.isfinite(supermesh_coords).all():
2642
+ raise RuntimeError("[CONTACT] non-finite supermesh coords")
2643
+ if supermesh_conn.size:
2644
+ min_idx = int(supermesh_conn.min())
2645
+ max_idx = int(supermesh_conn.max())
2646
+ if min_idx < 0 or max_idx >= supermesh_coords.shape[0]:
2647
+ raise RuntimeError(
2648
+ f"[CONTACT] supermesh_conn index out of range: min={min_idx} max={max_idx} n={supermesh_coords.shape[0]}"
2649
+ )
2650
+ if len(supermesh_conn) != len(source_facets_a) or len(supermesh_conn) != len(source_facets_b):
2651
+ raise RuntimeError(
2652
+ "[CONTACT] supermesh_conn and source_facets lengths mismatch "
2653
+ f"conn={len(supermesh_conn)} fa={len(source_facets_a)} fb={len(source_facets_b)}"
2654
+ )
2655
+
2656
+ normals_a = None
2657
+ normals_b = None
2658
+ if hasattr(surface_a, "facet_normals"):
2659
+ normals_a = surface_a.facet_normals()
2660
+ if hasattr(surface_b, "facet_normals"):
2661
+ normals_b = surface_b.facet_normals()
2662
+
2663
+ area_scale = float(os.getenv("FLUXFEM_SMALL_TRI_EPS_SCALE", "0.0"))
2664
+ skip_small_tri = os.getenv("FLUXFEM_SKIP_SMALL_TRI", "0") == "1" and area_scale > 0.0
2665
+ facet_area_a = None
2666
+ facet_area_b = None
2667
+ if area_scale > 0.0:
2668
+ facet_area_a = np.array([_facet_area_estimate(fa, coords_a) for fa in facets_a], dtype=float)
2669
+ facet_area_b = np.array([_facet_area_estimate(fb, coords_b) for fb in facets_b], dtype=float)
2670
+
2671
+ includes_measure = getattr(res_form, "_includes_measure", {})
2672
+
2673
+ rows: list[int] = []
2674
+ cols: list[int] = []
2675
+ data: list[float] = []
2676
+ K_dense: np.ndarray | None = np.zeros((n_total, n_total), dtype=float) if not sparse else None
2677
+
2678
+ use_elem_a = elem_conn_a is not None and facet_to_elem_a is not None
2679
+ use_elem_b = elem_conn_b is not None and facet_to_elem_b is not None
2680
+
2681
+ if grad_source not in {"volume", "surface"}:
2682
+ raise ValueError("grad_source must be 'volume' or 'surface'")
2683
+ if dof_source not in {"surface", "volume"}:
2684
+ raise ValueError("dof_source must be 'surface' or 'volume'")
2685
+ if dof_source == "volume" and grad_source == "surface":
2686
+ raise ValueError("dof_source 'volume' requires grad_source 'volume'")
2687
+ global _DEBUG_SURFACE_SOURCE_ONCE
2688
+ if grad_source == "surface" and not _DEBUG_SURFACE_SOURCE_ONCE:
2689
+ print("[fluxfem] using surface gradN in mortar")
2690
+ _DEBUG_SURFACE_SOURCE_ONCE = True
2691
+ diag_map = os.getenv("FLUXFEM_DIAG_CONTACT_MAP", "0") == "1"
2692
+ diag_n = os.getenv("FLUXFEM_DIAG_CONTACT_N", "0") == "1"
2693
+ proj_diag = _proj_diag_enabled()
2694
+ if proj_diag:
2695
+ _proj_diag_reset()
2696
+ diag_force = os.getenv("FLUXFEM_PROJ_DIAG_FORCE", "0") == "1"
2697
+ diag_qp_mode = os.getenv("FLUXFEM_PROJ_DIAG_QP_MODE", "").strip().lower()
2698
+ diag_qp_path = os.getenv("FLUXFEM_PROJ_DIAG_QP_PATH", "").strip()
2699
+ diag_normal = os.getenv("FLUXFEM_PROJ_DIAG_NORMAL", "").strip().lower()
2700
+ diag_facet = int(os.getenv("FLUXFEM_PROJ_DIAG_FACET", "-1"))
2701
+ diag_max_q = int(os.getenv("FLUXFEM_PROJ_DIAG_MAX_Q", "3"))
2702
+ diag_abs_detj = os.getenv("FLUXFEM_PROJ_DIAG_ABS_DETJ", "1") == "1"
2703
+ if backend not in {"jax", "numpy"}:
2704
+ raise ValueError("backend must be 'jax' or 'numpy'")
2705
+ if backend == "numpy":
2706
+ if fd_eps <= 0.0:
2707
+ raise ValueError("fd_eps must be positive for numpy backend")
2708
+ if fd_mode not in {"central", "forward"}:
2709
+ raise ValueError("fd_mode must be 'central' or 'forward' for numpy backend")
2710
+ if batch_jac is None:
2711
+ batch_jac = _env_flag("FLUXFEM_MORTAR_BATCH_JAC", True)
2712
+
2713
+ if normal_from is not None:
2714
+ if normal_from not in {"master", "slave"}:
2715
+ raise ValueError("normal_from must be 'master' or 'slave'")
2716
+ master_name = field_a if master_field is None else master_field
2717
+ if master_name not in {field_a, field_b}:
2718
+ raise ValueError("master_field must match field_a or field_b")
2719
+ if normal_from == "master":
2720
+ normal_source = "a" if master_name == field_a else "b"
2721
+ else:
2722
+ normal_source = "b" if master_name == field_a else "a"
2723
+ if diag_force and diag_normal:
2724
+ normal_source = diag_normal
2725
+ if normal_source not in {"a", "b", "avg", "master", "slave"}:
2726
+ raise ValueError("normal_source must be 'a', 'b', 'avg', 'master', or 'slave'")
2727
+ if normal_source == "master":
2728
+ normal_source = "a" if (master_field is None or master_field == field_a) else "b"
2729
+ if normal_source == "slave":
2730
+ normal_source = "b" if (master_field is None or master_field == field_a) else "a"
2731
+
2732
+ mortar_mode = os.getenv("FLUXFEM_MORTAR_MODE", "supermesh").lower()
2733
+ _mortar_dbg(f"[mortar] mode={mortar_mode}")
2734
+ if mortar_mode == "projection":
2735
+ batches, fallback = _projection_surface_batches(
2736
+ source_facets_a,
2737
+ source_facets_b,
2738
+ surface_a,
2739
+ surface_b,
2740
+ elem_conn_a=elem_conn_a,
2741
+ elem_conn_b=elem_conn_b,
2742
+ facet_to_elem_a=facet_to_elem_a,
2743
+ facet_to_elem_b=facet_to_elem_b,
2744
+ quad_order=quad_order,
2745
+ grad_source=grad_source,
2746
+ dof_source=dof_source,
2747
+ normal_source=normal_source,
2748
+ normal_sign=normal_sign,
2749
+ tol=tol,
2750
+ )
2751
+ if batches is not None and not fallback:
2752
+ for batch in batches:
2753
+ Na = batch["Na"]
2754
+ Nb = batch["Nb"]
2755
+ gradNa = batch["gradNa"]
2756
+ gradNb = batch["gradNb"]
2757
+ nodes_a = batch["nodes_a"]
2758
+ nodes_b = batch["nodes_b"]
2759
+ normal_q = batch["normal"]
2760
+
2761
+ field_a_obj = SurfaceMixedFormField(
2762
+ N=Na,
2763
+ gradN=gradNa,
2764
+ value_dim=value_dim_a,
2765
+ basis=_SurfaceBasis(dofs_per_node=value_dim_a),
2766
+ )
2767
+ field_b_obj = SurfaceMixedFormField(
2768
+ N=Nb,
2769
+ gradN=gradNb,
2770
+ value_dim=value_dim_b,
2771
+ basis=_SurfaceBasis(dofs_per_node=value_dim_b),
2772
+ )
2773
+ fields = {
2774
+ field_a: FieldPair(test=cast("FormFieldLike", field_a_obj), trial=cast("FormFieldLike", field_a_obj)),
2775
+ field_b: FieldPair(test=cast("FormFieldLike", field_b_obj), trial=cast("FormFieldLike", field_b_obj)),
2776
+ }
2777
+ ctx = SurfaceMixedFormContext(
2778
+ fields=fields,
2779
+ x_q=batch["x_q"],
2780
+ w=batch["w"],
2781
+ detJ=batch["detJ"],
2782
+ normal=normal_q,
2783
+ trial_fields={field_a: field_a_obj, field_b: field_b_obj},
2784
+ test_fields={field_a: field_a_obj, field_b: field_b_obj},
2785
+ unknown_fields={field_a: field_a_obj, field_b: field_b_obj},
2786
+ )
2787
+
2788
+ u_elem = {
2789
+ field_a: _gather_u_local(u_a, nodes_a, value_dim_a),
2790
+ field_b: _gather_u_local(u_b, nodes_b, value_dim_b),
2791
+ }
2792
+ u_local = np.concatenate([u_elem[field_a], u_elem[field_b]], axis=0)
2793
+ sizes = (u_elem[field_a].shape[0], u_elem[field_b].shape[0])
2794
+ slices = {
2795
+ field_a: slice(0, sizes[0]),
2796
+ field_b: slice(sizes[0], sizes[0] + sizes[1]),
2797
+ }
2798
+
2799
+                 def _res_local_np(u_vec):  # numpy local residual of this projection batch; used by the finite-difference Jacobian below
2800
+                     u_dict = {name: u_vec[slices[name]] for name in (field_a, field_b)}  # split the stacked local vector per field
2801
+                     fe_q = res_form(ctx, u_dict, params)
2802
+                     res_parts = []
2803
+                     for name in (field_a, field_b):
2804
+                         fe_field = fe_q[name]
2805
+                         if includes_measure.get(name, False):  # presumably the form output already carries w*detJ (TODO confirm): only sum over q
2806
+                             fe = np.sum(np.asarray(fe_field), axis=0)
2807
+                         else:
2808
+                             wJ = np.asarray(ctx.w) * np.asarray(ctx.detJ)  # quadrature weight times detJ
2809
+                             fe = np.einsum("qi,q->i", np.asarray(fe_field), wJ)  # fe[i] = sum_q fe_field[q, i] * wJ[q]
2810
+                         res_parts.append(np.asarray(fe))
2811
+                     return np.concatenate(res_parts, axis=0)  # same [field_a, field_b] ordering as u_local
2812
+
2813
+ if backend == "jax":
2814
+                     def _res_local(u_vec):  # jax local residual of this projection batch; differentiated with jax.jacrev below
2815
+                         u_dict = {name: u_vec[slices[name]] for name in (field_a, field_b)}  # split the stacked local vector per field
2816
+                         fe_q = res_form(ctx, u_dict, params)
2817
+                         res_parts = []
2818
+                         for name in (field_a, field_b):
2819
+                             fe_field = fe_q[name]
2820
+                             if includes_measure.get(name, False):  # presumably the form output already carries w*detJ (TODO confirm): only sum over q
2821
+                                 fe = jnp.sum(jnp.asarray(fe_field), axis=0)
2822
+                             else:
2823
+                                 wJ = jnp.asarray(ctx.w) * jnp.asarray(ctx.detJ)  # quadrature weight times detJ
2824
+                                 fe = jnp.einsum("qi,q->i", jnp.asarray(fe_field), wJ)  # fe[i] = sum_q fe_field[q, i] * wJ[q]
2825
+                             res_parts.append(fe)
2826
+                         return jnp.concatenate(res_parts, axis=0)  # same [field_a, field_b] ordering as u_local
2827
+
2828
+ J_local = jax.jacrev(_res_local)(jnp.asarray(u_local))
2829
+ J_local_np = np.asarray(J_local)
2830
+ else:
2831
+ n_ldofs = int(u_local.shape[0])
2832
+ J_local_np = np.zeros((n_ldofs, n_ldofs), dtype=float)
2833
+ u_base = np.asarray(u_local, dtype=float)
2834
+ r0 = _res_local_np(u_base) if fd_mode == "forward" else None
2835
+ for i in range(n_ldofs):
2836
+ u_p = u_base.copy()
2837
+ u_p[i] += fd_eps
2838
+ r_p = _res_local_np(u_p)
2839
+ if fd_mode == "central":
2840
+ u_m = u_base.copy()
2841
+ u_m[i] -= fd_eps
2842
+ r_m = _res_local_np(u_m)
2843
+ col = (r_p - r_m) / (2.0 * fd_eps)
2844
+ else:
2845
+ col = (r_p - r0) / fd_eps
2846
+ J_local_np[:, i] = np.asarray(col, dtype=float)
2847
+
2848
+ dofs_a = _global_dof_indices(nodes_a, value_dim_a, int(offset_a))
2849
+ dofs_b = _global_dof_indices(nodes_b, value_dim_b, int(offset_b))
2850
+ dofs = np.concatenate([dofs_a, dofs_b], axis=0)
2851
+ for i, gi in enumerate(dofs):
2852
+ for j, gj in enumerate(dofs):
2853
+ val = float(J_local_np[i, j])
2854
+ if sparse:
2855
+ rows.append(int(gi))
2856
+ cols.append(int(gj))
2857
+ data.append(val)
2858
+ else:
2859
+ assert K_dense is not None
2860
+ K_dense[int(gi), int(gj)] += val
2861
+ if sparse:
2862
+ return np.asarray(rows, dtype=int), np.asarray(cols, dtype=int), np.asarray(data, dtype=float), n_total
2863
+ assert K_dense is not None
2864
+ return K_dense
2865
+
2866
+ if (
2867
+ batch_jac
2868
+ and backend == "jax"
2869
+ and dof_source == "volume"
2870
+ and grad_source == "volume"
2871
+ and use_elem_a
2872
+ and use_elem_b
2873
+ and not proj_diag
2874
+ and not diag_force
2875
+ ):
2876
+ if trace:
2877
+ _trace("[CONTACT] batch_jac_enter")
2878
+ batch_items = []
2879
+ dofs_batch = []
2880
+ u_local_batch = []
2881
+ batch_rows: list[np.ndarray] = []
2882
+ batch_cols: list[np.ndarray] = []
2883
+ batch_data: list[np.ndarray] = []
2884
+ batch_size = int(os.getenv("FLUXFEM_MORTAR_BATCH_SIZE", "128"))
2885
+ if batch_size <= 0:
2886
+ batch_size = 0
2887
+ n_q = None
2888
+ n_nodes_a = None
2889
+ n_nodes_b = None
2890
+ n_a_local_const = None
2891
+ n_b_local_const = None
2892
+ batch_failed = False
2893
+ jit_batch = _env_flag("FLUXFEM_MORTAR_BATCH_JIT", False)
2894
+
2895
+         def _make_jac_fun(n_a_local: int, n_b_local: int):  # build a vmapped jacrev of the per-triangle residual (optionally jitted)
2896
+             def _res_local_batch(u_vec, Na, Nb, gradNa, gradNb, x_q, w, detJ, normal):  # residual of one batch item; vmap maps over the leading axis of every argument
2897
+                 field_a_obj = SurfaceMixedFormField(
2898
+                     N=Na,
2899
+                     gradN=gradNa,
2900
+                     value_dim=value_dim_a,
2901
+                     basis=_SurfaceBasis(dofs_per_node=value_dim_a),
2902
+                 )
2903
+                 field_b_obj = SurfaceMixedFormField(
2904
+                     N=Nb,
2905
+                     gradN=gradNb,
2906
+                     value_dim=value_dim_b,
2907
+                     basis=_SurfaceBasis(dofs_per_node=value_dim_b),
2908
+                 )
2909
+                 fields = {
2910
+                     field_a: FieldPair(test=cast("FormFieldLike", field_a_obj), trial=cast("FormFieldLike", field_a_obj)),
2911
+                     field_b: FieldPair(test=cast("FormFieldLike", field_b_obj), trial=cast("FormFieldLike", field_b_obj)),
2912
+                 }
2913
+                 normal_q = jnp.repeat(normal[None, :], x_q.shape[0], axis=0)  # broadcast the per-triangle normal to every quadrature point
2914
+                 ctx = SurfaceMixedFormContext(
2915
+                     fields=fields,
2916
+                     x_q=x_q,
2917
+                     w=w,
2918
+                     detJ=detJ,
2919
+                     normal=normal_q,
2920
+                     trial_fields={field_a: field_a_obj, field_b: field_b_obj},
2921
+                     test_fields={field_a: field_a_obj, field_b: field_b_obj},
2922
+                     unknown_fields={field_a: field_a_obj, field_b: field_b_obj},
2923
+                 )
2924
+                 u_dict = {  # first n_a_local entries belong to field_a, the remainder to field_b (n_b_local is only used for tracing)
2925
+                     field_a: u_vec[:n_a_local],
2926
+                     field_b: u_vec[n_a_local:],
2927
+                 }
2928
+                 fe_q = res_form(ctx, u_dict, params)
2929
+                 res_parts = []
2930
+                 for name in (field_a, field_b):
2931
+                     fe_field = fe_q[name]
2932
+                     if includes_measure.get(name, False):  # presumably the form output already carries w*detJ (TODO confirm): only sum over q
2933
+                         fe = jnp.sum(jnp.asarray(fe_field), axis=0)
2934
+                     else:
2935
+                         wJ = jnp.asarray(ctx.w) * jnp.asarray(ctx.detJ)  # quadrature weight times detJ
2936
+                         fe = jnp.einsum("qi,q->i", jnp.asarray(fe_field), wJ)  # fe[i] = sum_q fe_field[q, i] * wJ[q]
2937
+                     res_parts.append(fe)
2938
+                 return jnp.concatenate(res_parts, axis=0)  # [field_a, field_b] ordering, matching u_vec
2939
+
2940
+             if trace:
2941
+                 _trace(f"[CONTACT] batch_jac_build n_a={n_a_local} n_b={n_b_local} jit={jit_batch}")
2942
+             jac_fun = jax.vmap(jax.jacrev(_res_local_batch))  # per-item Jacobian w.r.t. u_vec (jacrev differentiates the first argument)
2943
+             return jax.jit(jac_fun) if jit_batch else jac_fun  # jit only when FLUXFEM_MORTAR_BATCH_JIT is enabled
2944
+
2945
+ jac_fun_cache: dict[tuple[int, int], Callable[..., jnp.ndarray]] = {}
2946
+
2947
+ def _emit_batch(
2948
+ Na_b,
2949
+ Nb_b,
2950
+ gradNa_b,
2951
+ gradNb_b,
2952
+ x_q_b,
2953
+ w_b,
2954
+ detJ_b,
2955
+ normal_b,
2956
+ u_local_b,
2957
+ dofs_batch_np,
2958
+ n_a_local,
2959
+ n_b_local,
2960
+ batch_n,
2961
+ ) -> None:
2962
+ if trace:
2963
+ _trace(f"[CONTACT] batch_emit start n={int(Na_b.shape[0])}")
2964
+ if batch_size and batch_n < batch_size:
2965
+ pad = int(batch_size - batch_n)
2966
+ if trace:
2967
+ _trace(f"[CONTACT] batch_pad n={batch_n} target={batch_size}")
2968
+
2969
+ def _pad_batch(x, pad_value: float = 0.0):
2970
+ pad_width = [(0, pad)] + [(0, 0)] * (x.ndim - 1)
2971
+ return jnp.pad(jnp.asarray(x), pad_width, mode="constant", constant_values=pad_value)
2972
+
2973
+ Na_b = _pad_batch(Na_b)
2974
+ Nb_b = _pad_batch(Nb_b)
2975
+ gradNa_b = _pad_batch(gradNa_b)
2976
+ gradNb_b = _pad_batch(gradNb_b)
2977
+ x_q_b = _pad_batch(x_q_b)
2978
+ w_b = _pad_batch(w_b)
2979
+ detJ_b = _pad_batch(detJ_b)
2980
+ normal_b = _pad_batch(normal_b)
2981
+ u_local_b = _pad_batch(u_local_b)
2982
+ key = (n_a_local, n_b_local)
2983
+ jac_fun = jac_fun_cache.get(key)
2984
+ if jac_fun is None:
2985
+ jac_fun = _make_jac_fun(n_a_local, n_b_local)
2986
+ jac_fun_cache[key] = jac_fun
2987
+ t_batch = time.perf_counter()
2988
+ J_b = jac_fun(u_local_b, Na_b, Nb_b, gradNa_b, gradNb_b, x_q_b, w_b, detJ_b, normal_b)
2989
+ J_b_np = np.asarray(J_b)[:batch_n]
2990
+ if trace:
2991
+ _trace_time("[CONTACT] batch_emit jac_done", t_batch)
2992
+ n_ldofs = dofs_batch_np.shape[1]
2993
+ rows = np.repeat(dofs_batch_np, n_ldofs, axis=1).reshape(-1)
2994
+ cols = np.tile(dofs_batch_np, (1, n_ldofs)).reshape(-1)
2995
+ data = J_b_np.reshape(-1)
2996
+ if sparse:
2997
+ batch_rows.append(rows)
2998
+ batch_cols.append(cols)
2999
+ batch_data.append(data)
3000
+ else:
3001
+ assert K_dense is not None
3002
+ K_dense[rows, cols] += data
3003
+ for (tri, a, b, c), fa, fb in zip(
3004
+ _iter_supermesh_tris(supermesh_coords, supermesh_conn),
3005
+ source_facets_a,
3006
+ source_facets_b,
3007
+ ):
3008
+ area = _tri_area(a, b, c)
3009
+ if area <= tol:
3010
+ continue
3011
+ if skip_small_tri and facet_area_a is not None and facet_area_b is not None:
3012
+ area_ref = max(float(facet_area_a[int(fa)]), float(facet_area_b[int(fb)]))
3013
+ if area_ref > 0.0 and area < area_scale * area_ref:
3014
+ continue
3015
+ detJ = 2.0 * area
3016
+ if diag_force and diag_abs_detj:
3017
+ detJ = abs(detJ)
3018
+ if quad_order <= 0:
3019
+ quad_pts = np.array([[1.0 / 3.0, 1.0 / 3.0]], dtype=float)
3020
+ quad_w = np.array([0.5], dtype=float)
3021
+ else:
3022
+ quad_pts, quad_w = _tri_quadrature(quad_order)
3023
+
3024
+ facet_a = facets_a[int(fa)]
3025
+ facet_b = facets_b[int(fb)]
3026
+ x_q = np.array([a + r * (b - a) + s * (c - a) for r, s in quad_pts], dtype=float)
3027
+
3028
+ assert facet_to_elem_a is not None
3029
+ assert elem_conn_a is not None
3030
+ elem_id_a = int(facet_to_elem_a[int(fa)])
3031
+ elem_nodes_a = np.asarray(elem_conn_a[elem_id_a], dtype=int)
3032
+ elem_coords_a = coords_a[elem_nodes_a]
3033
+ assert facet_to_elem_b is not None
3034
+ assert elem_conn_b is not None
3035
+ elem_id_b = int(facet_to_elem_b[int(fb)])
3036
+ elem_nodes_b = np.asarray(elem_conn_b[elem_id_b], dtype=int)
3037
+ elem_coords_b = coords_b[elem_nodes_b]
3038
+
3039
+ Na = _volume_shape_values_at_points(x_q, elem_coords_a, tol=tol)
3040
+ Nb = _volume_shape_values_at_points(x_q, elem_coords_b, tol=tol)
3041
+ gradNa = _tet_gradN_at_points(x_q, elem_coords_a, tol=tol)
3042
+ gradNb = _tet_gradN_at_points(x_q, elem_coords_b, tol=tol)
3043
+
3044
+ na = normals_a[int(fa)] if normals_a is not None else None
3045
+ nb = normals_b[int(fb)] if normals_b is not None else None
3046
+ if normal_source == "a":
3047
+ normal = na
3048
+ elif normal_source == "b":
3049
+ normal = nb
3050
+ else:
3051
+ if na is not None and nb is not None:
3052
+ avg = na + nb
3053
+ norm = np.linalg.norm(avg)
3054
+ normal = avg / norm if norm > tol else na
3055
+ else:
3056
+ normal = na if na is not None else nb
3057
+ if normal is not None:
3058
+ normal = normal_sign * normal
3059
+ if normal is None:
3060
+ batch_failed = True
3061
+ break
3062
+
3063
+ u_elem = {
3064
+ field_a: _gather_u_local(u_a, elem_nodes_a, value_dim_a),
3065
+ field_b: _gather_u_local(u_b, elem_nodes_b, value_dim_b),
3066
+ }
3067
+ u_local = np.concatenate([u_elem[field_a], u_elem[field_b]], axis=0)
3068
+
3069
+ dofs_a = _global_dof_indices(elem_nodes_a, value_dim_a, int(offset_a))
3070
+ dofs_b = _global_dof_indices(elem_nodes_b, value_dim_b, int(offset_b))
3071
+ dofs = np.concatenate([dofs_a, dofs_b], axis=0)
3072
+
3073
+ batch_items.append((Na, Nb, gradNa, gradNb, x_q, quad_w, detJ, normal))
3074
+ dofs_batch.append(dofs)
3075
+ u_local_batch.append(u_local)
3076
+
3077
+ if n_q is None:
3078
+ n_q = Na.shape[0]
3079
+ n_nodes_a = Na.shape[1]
3080
+ n_nodes_b = Nb.shape[1]
3081
+ n_a_local_const = dofs_a.shape[0]
3082
+ n_b_local_const = dofs_b.shape[0]
3083
+ else:
3084
+ shape_mismatch = (
3085
+ Na.shape[0] != n_q
3086
+ or Nb.shape[0] != n_q
3087
+ or Na.shape[1] != n_nodes_a
3088
+ or Nb.shape[1] != n_nodes_b
3089
+ or dofs_a.shape[0] != n_a_local_const
3090
+ or dofs_b.shape[0] != n_b_local_const
3091
+ )
3092
+ if shape_mismatch:
3093
+ if batch_items:
3094
+ Na_b, Nb_b, gradNa_b, gradNb_b, x_q_b, w_b, detJ_b, normal_b = zip(*batch_items)
3095
+ Na_b = jnp.asarray(np.stack(Na_b, axis=0))
3096
+ Nb_b = jnp.asarray(np.stack(Nb_b, axis=0))
3097
+ gradNa_b = jnp.asarray(np.stack(gradNa_b, axis=0))
3098
+ gradNb_b = jnp.asarray(np.stack(gradNb_b, axis=0))
3099
+ x_q_b = jnp.asarray(np.stack(x_q_b, axis=0))
3100
+ w_b = jnp.asarray(np.stack(w_b, axis=0))
3101
+ detJ_b = jnp.asarray(np.array(detJ_b, dtype=float)).reshape(-1, 1)
3102
+ normal_b = jnp.asarray(np.stack(normal_b, axis=0))
3103
+ u_local_b = jnp.asarray(np.stack(u_local_batch, axis=0))
3104
+ dofs_batch_np = np.asarray(dofs_batch, dtype=int)
3105
+ assert n_a_local_const is not None
3106
+ assert n_b_local_const is not None
3107
+ _emit_batch(
3108
+ Na_b,
3109
+ Nb_b,
3110
+ gradNa_b,
3111
+ gradNb_b,
3112
+ x_q_b,
3113
+ w_b,
3114
+ detJ_b,
3115
+ normal_b,
3116
+ u_local_b,
3117
+ dofs_batch_np,
3118
+ int(n_a_local_const),
3119
+ int(n_b_local_const),
3120
+ int(Na_b.shape[0]),
3121
+ )
3122
+ batch_items = [(Na, Nb, gradNa, gradNb, x_q, quad_w, detJ, normal)]
3123
+ dofs_batch = [dofs]
3124
+ u_local_batch = [u_local]
3125
+ n_q = Na.shape[0]
3126
+ n_nodes_a = Na.shape[1]
3127
+ n_nodes_b = Nb.shape[1]
3128
+ n_a_local_const = dofs_a.shape[0]
3129
+ n_b_local_const = dofs_b.shape[0]
3130
+
3131
+ if batch_size and len(batch_items) >= batch_size:
3132
+ Na_b, Nb_b, gradNa_b, gradNb_b, x_q_b, w_b, detJ_b, normal_b = zip(*batch_items)
3133
+ Na_b = jnp.asarray(np.stack(Na_b, axis=0))
3134
+ Nb_b = jnp.asarray(np.stack(Nb_b, axis=0))
3135
+ gradNa_b = jnp.asarray(np.stack(gradNa_b, axis=0))
3136
+ gradNb_b = jnp.asarray(np.stack(gradNb_b, axis=0))
3137
+ x_q_b = jnp.asarray(np.stack(x_q_b, axis=0))
3138
+ w_b = jnp.asarray(np.stack(w_b, axis=0))
3139
+ detJ_b = jnp.asarray(np.array(detJ_b, dtype=float)).reshape(-1, 1)
3140
+ normal_b = jnp.asarray(np.stack(normal_b, axis=0))
3141
+ u_local_b = jnp.asarray(np.stack(u_local_batch, axis=0))
3142
+ dofs_batch_np = np.asarray(dofs_batch, dtype=int)
3143
+ assert n_a_local_const is not None
3144
+ assert n_b_local_const is not None
3145
+ _emit_batch(
3146
+ Na_b,
3147
+ Nb_b,
3148
+ gradNa_b,
3149
+ gradNb_b,
3150
+ x_q_b,
3151
+ w_b,
3152
+ detJ_b,
3153
+ normal_b,
3154
+ u_local_b,
3155
+ dofs_batch_np,
3156
+ int(n_a_local_const),
3157
+ int(n_b_local_const),
3158
+ int(Na_b.shape[0]),
3159
+ )
3160
+ batch_items = []
3161
+ dofs_batch = []
3162
+ u_local_batch = []
3163
+
3164
+ if not batch_failed and batch_items:
3165
+ Na_b, Nb_b, gradNa_b, gradNb_b, x_q_b, w_b, detJ_b, normal_b = zip(*batch_items)
3166
+ Na_b = jnp.asarray(np.stack(Na_b, axis=0))
3167
+ Nb_b = jnp.asarray(np.stack(Nb_b, axis=0))
3168
+ gradNa_b = jnp.asarray(np.stack(gradNa_b, axis=0))
3169
+ gradNb_b = jnp.asarray(np.stack(gradNb_b, axis=0))
3170
+ x_q_b = jnp.asarray(np.stack(x_q_b, axis=0))
3171
+ w_b = jnp.asarray(np.stack(w_b, axis=0))
3172
+ detJ_b = jnp.asarray(np.array(detJ_b, dtype=float)).reshape(-1, 1)
3173
+ normal_b = jnp.asarray(np.stack(normal_b, axis=0))
3174
+ u_local_b = jnp.asarray(np.stack(u_local_batch, axis=0))
3175
+ dofs_batch_np = np.asarray(dofs_batch, dtype=int)
3176
+ assert n_a_local_const is not None
3177
+ assert n_b_local_const is not None
3178
+ _emit_batch(
3179
+ Na_b,
3180
+ Nb_b,
3181
+ gradNa_b,
3182
+ gradNb_b,
3183
+ x_q_b,
3184
+ w_b,
3185
+ detJ_b,
3186
+ normal_b,
3187
+ u_local_b,
3188
+ dofs_batch_np,
3189
+ int(n_a_local_const),
3190
+ int(n_b_local_const),
3191
+ int(Na_b.shape[0]),
3192
+ )
3193
+
3194
+ if not batch_failed and (batch_rows or (not sparse and K_dense is not None)):
3195
+ if sparse:
3196
+ if batch_rows:
3197
+ rows_np = np.concatenate(batch_rows)
3198
+ cols_np = np.concatenate(batch_cols)
3199
+ data_np = np.concatenate(batch_data)
3200
+ else:
3201
+ rows_np = np.zeros((0,), dtype=int)
3202
+ cols_np = np.zeros((0,), dtype=int)
3203
+ data_np = np.zeros((0,), dtype=float)
3204
+ return rows_np, cols_np, data_np, n_total
3205
+ assert K_dense is not None
3206
+ return K_dense
3207
+
3208
+ if trace:
3209
+ _trace("[CONTACT] batch_jac_fallback")
3210
+
3211
+ if trace:
3212
+ _trace("[CONTACT] supermesh_loop_enter")
3213
+ _mortar_dbg("[mortar] step: supermesh loop START")
3214
+ t_loop = time.perf_counter()
3215
+ for it, ((tri, a, b, c), fa, fb) in enumerate(
3216
+ zip(
3217
+ _iter_supermesh_tris(supermesh_coords, supermesh_conn),
3218
+ source_facets_a,
3219
+ source_facets_b,
3220
+ )
3221
+ ):
3222
+ log_tri = trace and (it < trace_max or it % trace_every == 0)
3223
+ t_tri0 = time.perf_counter()
3224
+         def _tri_check(stage: str) -> None:  # abort once this triangle has spent more than tri_timeout seconds (check disabled when tri_timeout <= 0)
3225
+             if tri_timeout > 0.0 and (time.perf_counter() - t_tri0) > tri_timeout:  # t_tri0/it are read from the current loop iteration, where it is always called
3226
+                 raise RuntimeError(f"[CONTACT] tri {it} timeout at {stage}")
3227
+ if log_tri:
3228
+ _trace(f"[CONTACT] tri {it} start fa={int(fa)} fb={int(fb)}")
3229
+ t_geom = time.perf_counter()
3230
+ area = _tri_area(a, b, c)
3231
+ if area <= tol:
3232
+ continue
3233
+ if skip_small_tri and facet_area_a is not None and facet_area_b is not None:
3234
+ area_ref = max(float(facet_area_a[int(fa)]), float(facet_area_b[int(fb)]))
3235
+ if area_ref > 0.0 and area < area_scale * area_ref:
3236
+ continue
3237
+ detJ = 2.0 * area
3238
+ if diag_force and diag_abs_detj:
3239
+ detJ = abs(detJ)
3240
+ if guard:
3241
+ if not np.isfinite(detJ):
3242
+ if log_tri:
3243
+ _trace(f"[CONTACT] tri {it} detJ non-finite; skip")
3244
+ if skip_nonfinite:
3245
+ continue
3246
+ raise RuntimeError(f"[CONTACT] tri {it} detJ non-finite")
3247
+ if detj_eps > 0.0 and abs(detJ) < detj_eps:
3248
+ if log_tri:
3249
+ _trace(f"[CONTACT] tri {it} detJ too small {detJ:.3e}; skip")
3250
+ continue
3251
+ if quad_order <= 0:
3252
+ quad_pts = np.array([[1.0 / 3.0, 1.0 / 3.0]], dtype=float)
3253
+ quad_w = np.array([0.5], dtype=float)
3254
+ else:
3255
+ quad_pts, quad_w = _tri_quadrature(quad_order)
3256
+ quad_source = "fluxfem"
3257
+ quad_override = _diag_quad_override(diag_force, diag_qp_mode, diag_qp_path)
3258
+ if quad_override is not None:
3259
+ quad_pts, quad_w = quad_override
3260
+ quad_source = _DEBUG_PROJ_QP_SOURCE or "override"
3261
+ _diag_quad_dump(diag_force, diag_qp_mode, diag_qp_path, quad_pts, quad_w)
3262
+
3263
+ facet_a = facets_a[int(fa)]
3264
+ facet_b = facets_b[int(fb)]
3265
+ x_q = np.array([a + r * (b - a) + s * (c - a) for r, s in quad_pts], dtype=float)
3266
+ if guard and not np.isfinite(x_q).all():
3267
+ if log_tri:
3268
+ _trace(f"[CONTACT] tri {it} x_q non-finite; skip")
3269
+ if skip_nonfinite:
3270
+ continue
3271
+ raise RuntimeError(f"[CONTACT] tri {it} x_q non-finite")
3272
+ if log_tri:
3273
+ _trace_time(f"[CONTACT] tri {it} geom_done", t_geom)
3274
+ _tri_check("geom_done")
3275
+
3276
+ gradNa = None
3277
+ gradNb = None
3278
+ nodes_a = facet_a
3279
+ nodes_b = facet_b
3280
+
3281
+ Na = None
3282
+ Nb = None
3283
+
3284
+ elem_id_a = -1
3285
+ elem_nodes_a = None
3286
+ elem_coords_a = None
3287
+ local_a = None
3288
+ if use_elem_a:
3289
+ assert elem_conn_a is not None
3290
+ assert facet_to_elem_a is not None
3291
+ elem_id_a = int(facet_to_elem_a[int(fa)])
3292
+ if elem_id_a < 0:
3293
+ raise ValueError("facet_to_elem_a has invalid mapping")
3294
+ elem_nodes_a = np.asarray(elem_conn_a[elem_id_a], dtype=int)
3295
+ elem_coords_a = coords_a[elem_nodes_a]
3296
+ if elem_coords_a.shape[0] not in {4, 8, 10, 20, 27}:
3297
+ raise NotImplementedError("surface sym_grad is implemented for tet4/tet10/hex8/hex20/hex27 only")
3298
+
3299
+ elem_id_b = -1
3300
+ elem_nodes_b = None
3301
+ elem_coords_b = None
3302
+ local_b = None
3303
+ if use_elem_b:
3304
+ assert elem_conn_b is not None
3305
+ assert facet_to_elem_b is not None
3306
+ elem_id_b = int(facet_to_elem_b[int(fb)])
3307
+ if elem_id_b < 0:
3308
+ raise ValueError("facet_to_elem_b has invalid mapping")
3309
+ elem_nodes_b = np.asarray(elem_conn_b[elem_id_b], dtype=int)
3310
+ elem_coords_b = coords_b[elem_nodes_b]
3311
+ if elem_coords_b.shape[0] not in {4, 8, 10, 20, 27}:
3312
+ raise NotImplementedError("surface sym_grad is implemented for tet4/tet10/hex8/hex20/hex27 only")
3313
+ if proj_diag:
3314
+ _proj_diag_set_context(
3315
+ fa=int(fa),
3316
+ fb=int(fb),
3317
+ face_a=_facet_label(facet_a),
3318
+ face_b=_facet_label(facet_b),
3319
+ elem_a=elem_id_a,
3320
+ elem_b=elem_id_b,
3321
+ )
3322
+
3323
+ t_basis = time.perf_counter()
3324
+ if grad_source == "surface":
3325
+ gradNa = np.array(
3326
+ [_surface_gradN(pt, facet_a, coords_a, tol=tol) for pt in x_q],
3327
+ dtype=float,
3328
+ )
3329
+ gradNb = np.array(
3330
+ [_surface_gradN(pt, facet_b, coords_b, tol=tol) for pt in x_q],
3331
+ dtype=float,
3332
+ )
3333
+ if use_elem_a and grad_source == "volume":
3334
+ assert elem_nodes_a is not None
3335
+ assert elem_coords_a is not None
3336
+ local_a = _local_indices(elem_nodes_a, facet_a)
3337
+ gradNa = _tet_gradN_at_points(x_q, elem_coords_a, local=local_a, tol=tol)
3338
+
3339
+ if use_elem_b and grad_source == "volume":
3340
+ assert elem_nodes_b is not None
3341
+ assert elem_coords_b is not None
3342
+ local_b = _local_indices(elem_nodes_b, facet_b)
3343
+ gradNb = _tet_gradN_at_points(x_q, elem_coords_b, local=local_b, tol=tol)
3344
+
3345
+ if dof_source == "volume":
3346
+ if not use_elem_a or elem_nodes_a is None or elem_coords_a is None:
3347
+ raise ValueError("dof_source 'volume' requires elem_conn_a and facet_to_elem_a")
3348
+ if not use_elem_b or elem_nodes_b is None or elem_coords_b is None:
3349
+ raise ValueError("dof_source 'volume' requires elem_conn_b and facet_to_elem_b")
3350
+ nodes_a = elem_nodes_a
3351
+ nodes_b = elem_nodes_b
3352
+ Na = _volume_shape_values_at_points(x_q, elem_coords_a, tol=tol)
3353
+ Nb = _volume_shape_values_at_points(x_q, elem_coords_b, tol=tol)
3354
+ if grad_source == "volume":
3355
+ gradNa = _tet_gradN_at_points(x_q, elem_coords_a, tol=tol)
3356
+ gradNb = _tet_gradN_at_points(x_q, elem_coords_b, tol=tol)
3357
+ else:
3358
+ Na = np.array([_facet_shape_values(pt, facet_a, coords_a, tol=tol) for pt in x_q], dtype=float)
3359
+ Nb = np.array([_facet_shape_values(pt, facet_b, coords_b, tol=tol) for pt in x_q], dtype=float)
3360
+ if guard and (not np.isfinite(Na).all() or not np.isfinite(Nb).all()):
3361
+ if log_tri:
3362
+ _trace(f"[CONTACT] tri {it} N non-finite; skip")
3363
+ if skip_nonfinite:
3364
+ continue
3365
+ raise RuntimeError(f"[CONTACT] tri {it} N non-finite")
3366
+ if log_tri:
3367
+ _trace_time(f"[CONTACT] tri {it} basis_done", t_basis)
3368
+ _tri_check("basis_done")
3369
+
3370
+ global _DEBUG_CONTACT_MAP_ONCE
3371
+ if diag_map and not _DEBUG_CONTACT_MAP_ONCE:
3372
+ if use_elem_a:
3373
+ assert facet_to_elem_a is not None
3374
+ elem_id_a = int(facet_to_elem_a[int(fa)])
3375
+ else:
3376
+ elem_id_a = -1
3377
+ if use_elem_b:
3378
+ assert facet_to_elem_b is not None
3379
+ elem_id_b = int(facet_to_elem_b[int(fb)])
3380
+ else:
3381
+ elem_id_b = -1
3382
+ print("[fluxfem][diag][contact-map] first facet")
3383
+ print(f" fa={int(fa)} fb={int(fb)} elem_a={elem_id_a} elem_b={elem_id_b}")
3384
+ print(f" facet_nodes_a={facet_a.tolist()}")
3385
+ print(f" facet_nodes_b={facet_b.tolist()}")
3386
+ print(f" facet_coords_a={coords_a[facet_a].tolist()}")
3387
+ print(f" facet_coords_b={coords_b[facet_b].tolist()}")
3388
+ if elem_nodes_a is not None:
3389
+ if local_a is None:
3390
+ local_a = _local_indices(elem_nodes_a, facet_a)
3391
+ match_a = np.all(elem_nodes_a[local_a] == facet_a)
3392
+ print(f" elem_nodes_a={elem_nodes_a.tolist()}")
3393
+ print(f" local_indices_a={local_a.tolist()} match={bool(match_a)}")
3394
+ if elem_nodes_b is not None:
3395
+ if local_b is None:
3396
+ local_b = _local_indices(elem_nodes_b, facet_b)
3397
+ match_b = np.all(elem_nodes_b[local_b] == facet_b)
3398
+ print(f" elem_nodes_b={elem_nodes_b.tolist()}")
3399
+ print(f" local_indices_b={local_b.tolist()} match={bool(match_b)}")
3400
+ _DEBUG_CONTACT_MAP_ONCE = True
3401
+
3402
+ global _DEBUG_CONTACT_N_ONCE
3403
+ if diag_n and not _DEBUG_CONTACT_N_ONCE:
3404
+ dofs_a = _global_dof_indices(nodes_a, value_dim_a, int(offset_a))
3405
+ dofs_b = _global_dof_indices(nodes_b, value_dim_b, int(offset_b))
3406
+ samples = min(3, Na.shape[0])
3407
+ print("[fluxfem][diag][contact-n] first facet q-points")
3408
+ print(f" nodes_a={nodes_a.tolist()} nodes_b={nodes_b.tolist()}")
3409
+ print(f" dofs_a={dofs_a.tolist()} dofs_b={dofs_b.tolist()}")
3410
+ for qi in range(samples):
3411
+ print(f" q{qi} x={x_q[qi].tolist()} Na={Na[qi].tolist()} Nb={Nb[qi].tolist()}")
3412
+ _DEBUG_CONTACT_N_ONCE = True
3413
+
3414
+ normal = None
3415
+ na = normals_a[int(fa)] if normals_a is not None else None
3416
+ nb = normals_b[int(fb)] if normals_b is not None else None
3417
+ if normal_source == "a":
3418
+ normal = na
3419
+ elif normal_source == "b":
3420
+ normal = nb
3421
+ else:
3422
+ if na is not None and nb is not None:
3423
+ avg = na + nb
3424
+ norm = np.linalg.norm(avg)
3425
+ normal = avg / norm if norm > tol else na
3426
+ else:
3427
+ normal = na if na is not None else nb
3428
+ if normal is not None:
3429
+ normal = normal_sign * normal
3430
+
3431
+ field_a_obj = SurfaceMixedFormField(
3432
+ N=Na,
3433
+ gradN=gradNa,
3434
+ value_dim=value_dim_a,
3435
+ basis=_SurfaceBasis(dofs_per_node=value_dim_a),
3436
+ )
3437
+ field_b_obj = SurfaceMixedFormField(
3438
+ N=Nb,
3439
+ gradN=gradNb,
3440
+ value_dim=value_dim_b,
3441
+ basis=_SurfaceBasis(dofs_per_node=value_dim_b),
3442
+ )
3443
+ fields = {
3444
+ field_a: FieldPair(test=cast("FormFieldLike", field_a_obj), trial=cast("FormFieldLike", field_a_obj)),
3445
+ field_b: FieldPair(test=cast("FormFieldLike", field_b_obj), trial=cast("FormFieldLike", field_b_obj)),
3446
+ }
3447
+ normal_q = None if normal is None else np.repeat(normal[None, :], quad_pts.shape[0], axis=0)
3448
+ ctx = SurfaceMixedFormContext(
3449
+ fields=fields,
3450
+ x_q=x_q,
3451
+ w=quad_w,
3452
+ detJ=np.array([detJ], dtype=float),
3453
+ normal=normal_q,
3454
+ trial_fields={field_a: field_a_obj, field_b: field_b_obj},
3455
+ test_fields={field_a: field_a_obj, field_b: field_b_obj},
3456
+ unknown_fields={field_a: field_a_obj, field_b: field_b_obj},
3457
+ )
3458
+
3459
+ u_elem = {
3460
+ field_a: _gather_u_local(u_a, nodes_a, value_dim_a),
3461
+ field_b: _gather_u_local(u_b, nodes_b, value_dim_b),
3462
+ }
3463
+ u_local = np.concatenate([u_elem[field_a], u_elem[field_b]], axis=0)
3464
+ sizes = (u_elem[field_a].shape[0], u_elem[field_b].shape[0])
3465
+ slices = {
3466
+ field_a: slice(0, sizes[0]),
3467
+ field_b: slice(sizes[0], sizes[0] + sizes[1]),
3468
+ }
3469
+
3470
+         def _res_local(u_vec):  # jax local residual on this supermesh triangle; differentiated with jax.jacrev when backend == "jax"
3471
+             u_dict = {name: u_vec[slices[name]] for name in (field_a, field_b)}  # split the stacked local vector per field
3472
+             fe_q = res_form(ctx, u_dict, params)
3473
+             res_parts = []
3474
+             for name in (field_a, field_b):
3475
+                 fe_field = fe_q[name]
3476
+                 if includes_measure.get(name, False):  # presumably the form output already carries w*detJ (TODO confirm): only sum over q
3477
+                     fe = jnp.sum(jnp.asarray(fe_field), axis=0)
3478
+                 else:
3479
+                     wJ = jnp.asarray(ctx.w) * jnp.asarray(ctx.detJ)  # quadrature weight times detJ
3480
+                     fe = jnp.einsum("qi,q->i", jnp.asarray(fe_field), wJ)  # fe[i] = sum_q fe_field[q, i] * wJ[q]
3481
+                 res_parts.append(fe)
3482
+             return jnp.concatenate(res_parts, axis=0)  # same [field_a, field_b] ordering as u_local
3483
+
3484
+         def _res_local_np(u_vec):  # numpy local residual on this supermesh triangle; used by the finite-difference Jacobian below
3485
+             u_dict = {name: u_vec[slices[name]] for name in (field_a, field_b)}  # split the stacked local vector per field (rows when u_vec is 2-D)
3486
+             fe_q = res_form(ctx, u_dict, params)
3487
+             res_parts = []
3488
+             for name in (field_a, field_b):
3489
+                 fe_field = fe_q[name]
3490
+                 if includes_measure.get(name, False):  # presumably the form output already carries w*detJ (TODO confirm): only sum over q
3491
+                     fe = np.sum(np.asarray(fe_field), axis=0)
3492
+                 else:
3493
+                     wJ = np.asarray(ctx.w) * np.asarray(ctx.detJ)  # quadrature weight times detJ
3494
+                     fe = np.einsum("qi...,q->i...", np.asarray(fe_field), wJ)  # `...` keeps trailing axes, e.g. the column axis of the block-FD u_block; assumes res_form propagates it (TODO confirm)
3495
+                 res_parts.append(np.asarray(fe))
3496
+             return np.concatenate(res_parts, axis=0)  # same [field_a, field_b] ordering as u_local
3497
+
3498
+ t_jac = time.perf_counter()
3499
+ if backend == "jax":
3500
+ J_local = jax.jacrev(_res_local)(jnp.asarray(u_local))
3501
+ J_local_np = np.asarray(J_local)
3502
+ else:
3503
+ n_ldofs = int(u_local.shape[0])
3504
+ J_local_np = np.zeros((n_ldofs, n_ldofs), dtype=float)
3505
+ u_base = np.asarray(u_local, dtype=float)
3506
+ if log_tri:
3507
+ _trace(f"[CONTACT] tri {it} fd_start n_ldofs={n_ldofs} fd_mode={fd_mode}")
3508
+ r0 = _res_local_np(u_base) if fd_mode == "forward" else None
3509
+ if log_tri and fd_mode == "forward":
3510
+ _trace(f"[CONTACT] tri {it} fd_r0_done")
3511
+ block = max(1, int(fd_block_size))
3512
+ if block <= 1:
3513
+ for i in range(n_ldofs):
3514
+ log_fd = log_tri and i < trace_fd_max
3515
+ if log_fd:
3516
+ _trace(f"[CONTACT] tri {it} fd_col {i} start")
3517
+ u_p = u_base.copy()
3518
+ u_p[i] += fd_eps
3519
+ t_rp = time.perf_counter()
3520
+ r_p = _res_local_np(u_p)
3521
+ if log_fd:
3522
+ _trace_time(f"[CONTACT] tri {it} fd_col {i} r_p", t_rp)
3523
+ if fd_mode == "central":
3524
+ u_m = u_base.copy()
3525
+ u_m[i] -= fd_eps
3526
+ t_rm = time.perf_counter()
3527
+ r_m = _res_local_np(u_m)
3528
+ if log_fd:
3529
+ _trace_time(f"[CONTACT] tri {it} fd_col {i} r_m", t_rm)
3530
+ col = (r_p - r_m) / (2.0 * fd_eps)
3531
+ else:
3532
+ col = (r_p - r0) / fd_eps
3533
+ J_local_np[:, i] = np.asarray(col, dtype=float)
3534
+ if log_fd:
3535
+ _trace(f"[CONTACT] tri {it} fd_col {i} done")
3536
+ else:
3537
+ for i0 in range(0, n_ldofs, block):
3538
+ idxs = np.arange(i0, min(i0 + block, n_ldofs))
3539
+ u_block = np.repeat(u_base[:, None], idxs.size, axis=1)
3540
+ for bi, idx in enumerate(idxs):
3541
+ u_block[idx, bi] += fd_eps
3542
+ t_rp = time.perf_counter()
3543
+ r_p = _res_local_np(u_block)
3544
+ if log_tri and i0 < trace_fd_max:
3545
+ _trace_time(f"[CONTACT] tri {it} fd_block r_p", t_rp)
3546
+ if fd_mode == "central":
3547
+ u_block_m = np.repeat(u_base[:, None], idxs.size, axis=1)
3548
+ for bi, idx in enumerate(idxs):
3549
+ u_block_m[idx, bi] -= fd_eps
3550
+ t_rm = time.perf_counter()
3551
+ r_m = _res_local_np(u_block_m)
3552
+ if log_tri and i0 < trace_fd_max:
3553
+ _trace_time(f"[CONTACT] tri {it} fd_block r_m", t_rm)
3554
+ cols = (r_p - r_m) / (2.0 * fd_eps)
3555
+ else:
3556
+ assert r0 is not None
3557
+ cols = (r_p - r0[:, None]) / fd_eps
3558
+ J_local_np[:, idxs] = np.asarray(cols, dtype=float)
3559
+ if log_tri:
3560
+ _trace_time(f"[CONTACT] tri {it} jac_done", t_jac)
3561
+ _tri_check("jac_done")
3562
+
3563
+ dofs_a = _global_dof_indices(nodes_a, value_dim_a, int(offset_a))
3564
+ dofs_b = _global_dof_indices(nodes_b, value_dim_b, int(offset_b))
3565
+ if diag_force:
3566
+ _diag_contact_projection(
3567
+ fa=int(fa),
3568
+ fb=int(fb),
3569
+ quad_pts=quad_pts,
3570
+ quad_w=quad_w,
3571
+ x_q=x_q,
3572
+ Na=Na,
3573
+ Nb=Nb,
3574
+ nodes_a=nodes_a,
3575
+ nodes_b=nodes_b,
3576
+ dofs_a=dofs_a,
3577
+ dofs_b=dofs_b,
3578
+ elem_coords_a=elem_coords_a if dof_source == "volume" else None,
3579
+ elem_coords_b=elem_coords_b if dof_source == "volume" else None,
3580
+ na=na,
3581
+ nb=nb,
3582
+ normal=normal,
3583
+ normal_source=normal_source,
3584
+ normal_sign=normal_sign,
3585
+ detJ=detJ,
3586
+ diag_facet=diag_facet,
3587
+ diag_max_q=diag_max_q,
3588
+ quad_source=quad_source,
3589
+ tol=tol,
3590
+ )
3591
+ t_scatter = time.perf_counter()
3592
+ dofs = np.concatenate([dofs_a, dofs_b], axis=0)
3593
+ if sparse:
3594
+ n_ldofs = int(dofs.shape[0])
3595
+ rows.extend(np.repeat(dofs, n_ldofs).tolist())
3596
+ cols.extend(np.tile(dofs, n_ldofs).tolist())
3597
+ data.extend(J_local_np.reshape(-1).tolist())
3598
+ else:
3599
+ assert K_dense is not None
3600
+ K_dense[np.ix_(dofs, dofs)] += J_local_np
3601
+ if log_tri:
3602
+ _trace_time(f"[CONTACT] tri {it} scatter_done", t_scatter)
3603
+ _tri_check("scatter_done")
3604
+
3605
+
3606
+ if proj_diag:
3607
+ _proj_diag_report()
3608
+ if sparse:
3609
+ return np.asarray(rows, dtype=int), np.asarray(cols, dtype=int), np.asarray(data, dtype=float), n_total
3610
+ assert K_dense is not None
3611
+ return K_dense
3612
+
3613
+
3614
def assemble_onesided_bilinear(
    surface_slave: SurfaceMesh,
    u_hat_fn,
    params: "WeakParams",
    *,
    surface_master: SurfaceMesh | None = None,
    u_master: np.ndarray | None = None,
    value_dim: int = 3,
    elem_conn: np.ndarray | None = None,
    facet_to_elem: np.ndarray | None = None,
    elem_conn_master: np.ndarray | None = None,
    facet_to_elem_master: np.ndarray | None = None,
    grad_source: str = "volume",
    dof_source: str = "volume",
    quad_order: int = 2,
    normal_sign: float = 1.0,
    tol: float = 1e-8,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Assemble one-sided (slave-only) Nitsche matrices without supermesh.

    The master side is treated as prescribed displacement u_hat(x). Provide
    either u_hat_fn(x_q) or u_master with master element mappings to evaluate
    u_hat at slave quadrature points.

    Each slave facet is split into triangles and integrated with a triangle
    rule of order ``quad_order`` (a one-point centroid rule with weight 0.5
    when ``quad_order <= 0``). The integrand is::

        -v . t(u) - t(v) . u + (alpha * inv_h) v . u
        + t(v) . u_hat - (alpha * inv_h) v . u_hat

    where ``t(.) = helpers_wf.traction(., n, params)``,
    ``inv_h = 1 / max(sqrt(facet area estimate), tol)``, and ``n`` is the
    slave facet normal multiplied by ``normal_sign``. If the surface has no
    ``facet_normals`` method, no normal is passed to the form context.

    Note: this implementation currently assumes volume-trace bases for both
    gradients and DOFs. Surface-only bases are not supported here yet.

    Parameters
    ----------
    surface_slave:
        Slave surface. Node ids in ``elem_conn`` index into
        ``surface_slave.coords``.
    u_hat_fn:
        Callable mapping quadrature points ``x_q`` of shape ``(n_q, 3)`` to
        prescribed displacements of shape ``(n_q, value_dim)``. Required when
        ``u_master`` is None.
    params:
        Provides ``lam``, ``mu`` and ``alpha`` (Nitsche penalty factor).
    surface_master, u_master, elem_conn_master, facet_to_elem_master:
        Used instead of ``u_hat_fn``. ``u_hat`` is then interpolated with the
        volume shape functions of the master element. Node ids in
        ``elem_conn_master`` index into ``surface_master.coords``.
    elem_conn, facet_to_elem:
        Volume connectivity and the slave facet -> element map (required).
    grad_source, dof_source:
        Must both be ``"volume"``.
    quad_order, normal_sign, tol:
        Quadrature order, normal orientation factor, and the degeneracy
        tolerance for facet and triangle areas.

    The environment variable ``FLUXFEM_ONESIDE_BLOCK_SIZE`` (default 16) sets
    how many unit probe vectors are allocated at a time when building the
    local Jacobian. It does not change the result.

    Returns
    -------
    K : ndarray, shape (n_s, n_s)
    f : ndarray, shape (n_s,)
        Here ``n_s = len(surface_slave.coords) * value_dim``. ``f`` is the
        local residual at ``u = 0``. ``K`` is built column by column as
        ``R(e_i) - R(0)``, which is exact for a form affine in ``u``. That
        holds here assuming ``helpers_wf.traction`` is linear in ``u``. The
        assembled residual is therefore ``R(u) = K @ u + f``.

    Raises
    ------
    ValueError
        If ``u_master`` is given without ``surface_master`` or the master
        element maps; if neither ``u_hat_fn`` nor ``u_master`` is given; if
        ``grad_source``/``dof_source`` is not ``"volume"``; if ``elem_conn``
        or ``facet_to_elem`` is missing; if a facet maps to a negative
        element id; or if ``u_hat_fn`` returns an array that is not
        ``(n_q, value_dim)``.
    """
    from ..core.forms import FieldPair

    coords_s = np.asarray(surface_slave.coords, dtype=float)
    facets_s = np.asarray(surface_slave.conn, dtype=int)
    coords_m = np.asarray(surface_master.coords, dtype=float) if surface_master is not None else coords_s
    n_s = int(coords_s.shape[0] * value_dim)
    K: np.ndarray = np.zeros((n_s, n_s), dtype=float)
    f: np.ndarray = np.zeros((n_s,), dtype=float)

    normals_s = surface_slave.facet_normals() if hasattr(surface_slave, "facet_normals") else None
    use_master = u_master is not None

    # ---- input validation (fail fast, before any per-facet work) ----
    if use_master:
        if surface_master is None:
            raise ValueError("surface_master is required when u_master is provided")
        if elem_conn_master is None or facet_to_elem_master is None:
            raise ValueError("elem_conn_master and facet_to_elem_master are required when u_master is provided")
    elif u_hat_fn is None:
        raise ValueError("u_hat_fn or u_master must be provided")

    if grad_source != "volume" or dof_source != "volume":
        raise ValueError("one-sided Nitsche currently supports only volume/volume")
    # Volume DOFs are the only supported option, so the element maps are mandatory.
    if elem_conn is None or facet_to_elem is None:
        raise ValueError("dof_source 'volume' requires elem_conn and facet_to_elem")

    # ---- weak form (compiled once, reused for every triangle) ----
    from ..core.weakform import (
        Params,
        compile_mixed_surface_residual_numpy,
        param_ref,
        test_ref,
        unknown_ref,
    )
    import fluxfem.helpers_wf as h_wf

    u = unknown_ref("u")
    v = test_ref("u")
    p = param_ref()
    n = h_wf.normal()
    t_u = h_wf.traction(u, n, p)
    t_v = h_wf.traction(v, n, p)
    sym_term = h_wf.einsum("qia,qi->qa", t_v, u.val)
    sym_term_hat = h_wf.einsum("qia,qi->qa", t_v, p.u_hat)
    expr = (
        -h_wf.dot(v, t_u)
        - sym_term
        + (p.alpha * p.inv_h) * h_wf.dot(v, u.val)
        + sym_term_hat
        - (p.alpha * p.inv_h) * h_wf.dot(v, p.u_hat)
    ) * h_wf.ds()
    res_form = compile_mixed_surface_residual_numpy({"u": expr})
    # True when the compiled form already includes the w*detJ measure.
    measure_included = bool(res_form._includes_measure.get("u", False))

    quad_pts, quad_w = (
        _tri_quadrature(quad_order)
        if quad_order > 0
        else (np.array([[1.0 / 3.0, 1.0 / 3.0]]), np.array([0.5]))
    )
    n_qp = int(quad_pts.shape[0])
    probe_block = max(1, int(os.getenv("FLUXFEM_ONESIDE_BLOCK_SIZE", "16")))

    def _local_residual(ctx, params_local, u_vec: np.ndarray) -> np.ndarray:
        """Integrate the compiled residual for one triangle at local DOFs ``u_vec``."""
        fe_q = np.asarray(res_form(ctx, {"u": u_vec}, params_local)["u"])
        if measure_included:
            return np.sum(fe_q, axis=0)
        wJ = np.asarray(ctx.w) * np.asarray(ctx.detJ)
        return np.einsum("qi,q->i", fe_q, wJ)

    # ---- facet / triangle loop ----
    for f_id, facet in enumerate(facets_s):
        triangles = _facet_triangles(coords_s, facet)
        if not triangles:
            continue
        area_f = _facet_area_estimate(facet, coords_s)
        if area_f <= tol:
            continue
        inv_h = 1.0 / max(np.sqrt(area_f), tol)

        elem_id = int(facet_to_elem[int(f_id)])
        if elem_id < 0:
            raise ValueError("facet_to_elem has invalid mapping")
        elem_nodes = np.asarray(elem_conn[elem_id], dtype=int)
        elem_coords = coords_s[elem_nodes]
        dofs = _global_dof_indices(elem_nodes, value_dim, 0)
        n_ldofs = int(len(elem_nodes) * value_dim)

        # Master element data is per facet, so gather it once rather than per triangle.
        elem_coords_m: np.ndarray | None = None
        u_master_local: np.ndarray | None = None
        if use_master:
            assert u_master is not None
            assert facet_to_elem_master is not None
            assert elem_conn_master is not None
            # NOTE(review): the master element is looked up with the *slave*
            # facet id; this assumes master/slave facet numbering coincides.
            elem_id_m = int(facet_to_elem_master[int(f_id)])
            if elem_id_m < 0:
                raise ValueError("facet_to_elem_master has invalid mapping")
            elem_nodes_m = np.asarray(elem_conn_master[elem_id_m], dtype=int)
            elem_coords_m = coords_m[elem_nodes_m]
            u_master_local = _gather_u_local(u_master, elem_nodes_m, value_dim).reshape(-1, value_dim)

        normal_q = None
        if normals_s is not None:
            normal = normal_sign * normals_s[int(f_id)]
            normal_q = np.repeat(normal[None, :], n_qp, axis=0)

        for a, b, c in triangles:
            area = _tri_area(a, b, c)
            if area <= tol:
                continue
            detJ = 2.0 * area
            x_q = np.array([a + r * (b - a) + s * (c - a) for r, s in quad_pts], dtype=float)

            if use_master:
                assert elem_coords_m is not None and u_master_local is not None
                N_master = _volume_shape_values_at_points(x_q, elem_coords_m, tol=tol)
                u_hat = N_master @ u_master_local
            else:
                u_hat = np.asarray(u_hat_fn(x_q), dtype=float)
                # Check the full shape: a 1-D/scalar return would otherwise fail
                # later with an opaque einsum error inside the compiled form.
                if u_hat.shape != (x_q.shape[0], value_dim):
                    raise ValueError("u_hat_fn must return shape (n_q, value_dim)")

            # Volume trace basis over *all* element nodes (values and gradients).
            N = _volume_shape_values_at_points(x_q, elem_coords, tol=tol)
            gradN = _tet_gradN_at_points(x_q, elem_coords, tol=tol)

            field = SurfaceMixedFormField(
                N=N,
                gradN=gradN,
                value_dim=value_dim,
                basis=_SurfaceBasis(dofs_per_node=value_dim),
            )
            fields = {"u": FieldPair(test=cast("FormFieldLike", field), trial=cast("FormFieldLike", field))}
            ctx = SurfaceMixedFormContext(
                fields=fields,
                x_q=x_q,
                w=quad_w,
                detJ=np.array([detJ], dtype=float),
                normal=normal_q,
                trial_fields={"u": field},
                test_fields={"u": field},
                unknown_fields={"u": field},
            )
            params_local = Params(
                lam=params.lam,
                mu=params.mu,
                alpha=params.alpha,
                inv_h=inv_h,
                u_hat=u_hat,
            )

            # Residual at zero gives the u_hat load; unit probes give the columns of K.
            f_local = _local_residual(ctx, params_local, np.zeros((n_ldofs,), dtype=float))
            k_local: np.ndarray = np.zeros((n_ldofs, n_ldofs), dtype=float)
            for start in range(0, n_ldofs, probe_block):
                idxs = np.arange(start, min(n_ldofs, start + probe_block), dtype=int)
                probes = np.zeros((n_ldofs, idxs.size), dtype=float)
                probes[idxs, np.arange(idxs.size, dtype=int)] = 1.0
                for col, idx in enumerate(idxs):
                    k_local[:, idx] = _local_residual(ctx, params_local, probes[:, col]) - f_local

            f[dofs] += f_local
            K[np.ix_(dofs, dofs)] += k_local

    return K, f
3842
+
3843
+
3844
def assemble_contact_onesided_floor(
    surface_slave: SurfaceMesh,
    u: np.ndarray,
    *,
    n: np.ndarray | None = None,
    c: float,
    k: float,
    beta: float,
    value_dim: int = 3,
    elem_conn: np.ndarray | None = None,
    facet_to_elem: np.ndarray | None = None,
    quad_order: int = 2,
    normal_sign: float = 1.0,
    tol: float = 1e-8,
    return_metrics: bool = False,
) -> tuple[np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray, dict[str, float]]:
    """
    Assemble one-sided contact penalty against a rigid plane g = n·x - c.

    Uses softplus for a smooth contact pressure:
        p(g) = k * softplus(-g; beta)
    with softplus(z; beta) = (1 / beta) * log(1 + exp(beta z)).

    The gap is evaluated at the current quadrature points
    ``x = X + u(X)``. Here ``u`` is interpolated with the volume shape
    functions of the element attached to each slave facet, and ``X`` lies on
    the triangulated reference facet. The nodal force
    ``f_i = sum_q N_i(x_q) p(g_q) n w_q detJ`` and its Jacobian (via
    ``jax.jacrev``) are scattered into dense global arrays.

    Note: the resulting stiffness matrix can be nonsymmetric; avoid CG.

    Parameters
    ----------
    surface_slave:
        Slave surface. Node ids in ``elem_conn`` index into its ``coords``.
    u:
        Global displacement vector, gathered per element with
        ``value_dim`` components per node.
    n:
        Plane normal (3-vector), normalized and multiplied by
        ``normal_sign``. If omitted, each slave facet normal multiplied by
        ``normal_sign`` is used instead.
    c, k, beta:
        Plane offset, penalty stiffness, and softplus sharpness (> 0).
    elem_conn, facet_to_elem:
        Volume connectivity and the slave facet -> element map (required).
    quad_order, tol, return_metrics:
        Triangle quadrature order (centroid rule when <= 0), the degeneracy
        tolerance, and whether to also return diagnostics.

    Returns
    -------
    K, f[, metrics]
        Dense ``(n_s, n_s)`` Jacobian and ``(n_s,)`` force vector, with
        ``n_s = len(surface_slave.coords) * value_dim``. When
        ``return_metrics`` is True, ``metrics`` has:

        - ``"penetration"``: the reference-area integral of
          ``softplus(-g; beta)``, not scaled by ``k``;
        - ``"min_g"``: the smallest gap at any quadrature point, or 0.0 if
          none was evaluated.

    Raises
    ------
    ValueError
        If ``elem_conn``/``facet_to_elem`` is missing, ``beta <= 0``, ``n``
        is not a non-zero 3-vector, no normals are available, or a facet maps
        to a negative element id.
    """
    if elem_conn is None or facet_to_elem is None:
        raise ValueError("elem_conn and facet_to_elem are required")
    if beta <= 0.0:
        raise ValueError("beta must be positive")

    import jax
    import jax.numpy as jnp

    c_val = float(c)
    k_val = float(k)
    beta_val = float(beta)

    coords_s = np.asarray(surface_slave.coords, dtype=float)
    facets_s = np.asarray(surface_slave.conn, dtype=int)
    n_s = int(coords_s.shape[0] * value_dim)
    K: np.ndarray = np.zeros((n_s, n_s), dtype=float)
    f: np.ndarray = np.zeros((n_s,), dtype=float)

    normals_s = surface_slave.facet_normals() if hasattr(surface_slave, "facet_normals") else None
    if n is not None:
        n = np.asarray(n, dtype=float).reshape(-1)
        if n.shape[0] != 3:
            raise ValueError("n must be a 3-vector")
        n_norm = np.linalg.norm(n)
        if n_norm <= tol:
            raise ValueError("n must be non-zero")
        n = (n / n_norm) * float(normal_sign)
    elif normals_s is None:
        raise ValueError("surface normals are required when n is not provided")

    def _contact_residual(u_vec, N_q, x_q, n_q, wJ):
        """Nodal contact force of one triangle for local DOFs ``u_vec`` (flattened)."""
        u_q = jnp.einsum("qi,ia->qa", N_q, u_vec.reshape(-1, value_dim))
        g = jnp.einsum("qa,qa->q", n_q, x_q + u_q) - c_val
        p = k_val * jax.nn.softplus(-beta_val * g) / beta_val
        t = p[:, None] * n_q
        nodal = jnp.einsum("qi,qa,q->ia", N_q, t, wJ)
        return nodal.reshape(-1)

    # Differentiate w.r.t. the first argument (u_vec) only; built once, reused per triangle.
    contact_jacobian = jax.jacrev(_contact_residual)

    penetration = 0.0
    min_g = float("inf")
    quad_pts, quad_w = (
        _tri_quadrature(quad_order)
        if quad_order > 0
        else (np.array([[1.0 / 3.0, 1.0 / 3.0]]), np.array([0.5]))
    )
    n_qp = int(quad_pts.shape[0])

    for f_id, facet in enumerate(facets_s):
        triangles = _facet_triangles(coords_s, facet)
        if not triangles:
            continue
        area_f = _facet_area_estimate(facet, coords_s)
        if area_f <= tol:
            continue

        elem_id = int(facet_to_elem[int(f_id)])
        if elem_id < 0:
            raise ValueError("facet_to_elem has invalid mapping")
        elem_nodes = np.asarray(elem_conn[elem_id], dtype=int)
        elem_coords = coords_s[elem_nodes]
        u_local = _gather_u_local(u, elem_nodes, value_dim).reshape(-1, value_dim)
        u_vec0 = jnp.asarray(np.asarray(u_local.reshape(-1), dtype=float))
        dofs = np.asarray(_global_dof_indices(elem_nodes, value_dim, 0), dtype=int)

        if n is not None:
            normal = n
        else:
            assert normals_s is not None
            normal = normal_sign * normals_s[int(f_id)]
        normal_q = np.repeat(normal[None, :], n_qp, axis=0)
        n_q_j = jnp.asarray(normal_q)

        for a, b, c_tri in triangles:
            area = _tri_area(a, b, c_tri)
            if area <= tol:
                continue
            detJ = 2.0 * area
            x_q_ref = np.array([a + r * (b - a) + s * (c_tri - a) for r, s in quad_pts], dtype=float)
            N = _volume_shape_values_at_points(x_q_ref, elem_coords, tol=tol)

            # Diagnostics in NumPy; exp argument clipped at 30, where softplus(z) ~= z.
            g_np = np.sum(normal_q * (x_q_ref + N @ u_local), axis=1) - c_val
            min_g = min(min_g, float(np.min(g_np)))
            z_np = -beta_val * g_np
            z_clip = np.minimum(z_np, 30.0)
            softplus_np = np.where(z_np > 30.0, z_np, np.log1p(np.exp(z_clip))) / beta_val
            penetration += float(np.sum(softplus_np * quad_w) * detJ)

            N_j = jnp.asarray(N)
            x_j = jnp.asarray(x_q_ref)
            wJ_j = jnp.asarray(quad_w) * float(detJ)
            f_local = np.asarray(_contact_residual(u_vec0, N_j, x_j, n_q_j, wJ_j), dtype=float)
            k_local = np.asarray(contact_jacobian(u_vec0, N_j, x_j, n_q_j, wJ_j), dtype=float)

            # Unbuffered accumulation: identical to the element-wise loop even with repeated dofs.
            np.add.at(f, dofs, f_local)
            np.add.at(K, np.ix_(dofs, dofs), k_local)

    if return_metrics:
        if min_g == float("inf"):
            min_g = 0.0
        metrics = {
            "penetration": float(penetration),
            "min_g": float(min_g),
        }
        return K, f, metrics
    return K, f