fluxfem 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. fluxfem/__init__.py +68 -0
  2. fluxfem/core/__init__.py +115 -10
  3. fluxfem/core/assembly.py +676 -91
  4. fluxfem/core/basis.py +73 -52
  5. fluxfem/core/dtypes.py +9 -1
  6. fluxfem/core/forms.py +10 -0
  7. fluxfem/core/mixed_assembly.py +263 -0
  8. fluxfem/core/mixed_space.py +348 -0
  9. fluxfem/core/mixed_weakform.py +97 -0
  10. fluxfem/core/solver.py +2 -0
  11. fluxfem/core/space.py +262 -17
  12. fluxfem/core/weakform.py +768 -7
  13. fluxfem/helpers_wf.py +49 -0
  14. fluxfem/mesh/__init__.py +54 -2
  15. fluxfem/mesh/base.py +316 -7
  16. fluxfem/mesh/contact.py +825 -0
  17. fluxfem/mesh/dtypes.py +12 -0
  18. fluxfem/mesh/hex.py +17 -16
  19. fluxfem/mesh/io.py +6 -4
  20. fluxfem/mesh/mortar.py +3907 -0
  21. fluxfem/mesh/supermesh.py +316 -0
  22. fluxfem/mesh/surface.py +22 -4
  23. fluxfem/mesh/tet.py +10 -4
  24. fluxfem/physics/diffusion.py +3 -0
  25. fluxfem/physics/elasticity/hyperelastic.py +3 -0
  26. fluxfem/physics/elasticity/linear.py +9 -2
  27. fluxfem/solver/__init__.py +42 -2
  28. fluxfem/solver/bc.py +38 -2
  29. fluxfem/solver/block_matrix.py +132 -0
  30. fluxfem/solver/block_system.py +454 -0
  31. fluxfem/solver/cg.py +115 -33
  32. fluxfem/solver/dirichlet.py +334 -4
  33. fluxfem/solver/newton.py +237 -60
  34. fluxfem/solver/petsc.py +439 -0
  35. fluxfem/solver/preconditioner.py +106 -0
  36. fluxfem/solver/result.py +18 -0
  37. fluxfem/solver/solve_runner.py +168 -1
  38. fluxfem/solver/solver.py +12 -1
  39. fluxfem/solver/sparse.py +124 -9
  40. fluxfem-0.2.0.dist-info/METADATA +303 -0
  41. fluxfem-0.2.0.dist-info/RECORD +59 -0
  42. fluxfem-0.1.4.dist-info/METADATA +0 -127
  43. fluxfem-0.1.4.dist-info/RECORD +0 -48
  44. {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/LICENSE +0 -0
  45. {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/WHEEL +0 -0
fluxfem/mesh/mortar.py ADDED
@@ -0,0 +1,3907 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ import os
5
+ import time
6
+ from typing import Iterable, TYPE_CHECKING
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import numpy as np
11
+
12
+ from .surface import SurfaceMesh
13
+ if TYPE_CHECKING:
14
+ from ..core.forms import FieldPair
15
+ from ..core.weakform import Params as WeakParams
16
+
17
+
18
@dataclass(eq=False)
class _SurfaceBasis:
    """Minimal basis descriptor attached to a surface form field.

    Only carries ``dofs_per_node``; presumably it mirrors the attribute of the
    same name that mixed weak-form code reads from volume bases (TODO confirm).
    ``eq=False`` keeps object identity semantics for ``==`` and hashing.
    """

    # Number of degrees of freedom attached to each surface node.
    dofs_per_node: int
21
+
22
+
23
@dataclass(eq=False)
class SurfaceMixedFormField:
    """Surface form field for mixed weak-form evaluation.

    Bundles the shape data of one field on a surface integration domain.
    Presumably this lets the field stand in for a volume form field when
    mixed weak forms are evaluated on a supermesh (TODO confirm with callers).
    """

    # Shape-function values; layout presumably (n_q, n_nodes) — TODO confirm.
    N: np.ndarray
    # Shape-function gradients, or None when no gradients are provided.
    gradN: np.ndarray | None
    # Number of field components per node (1 for scalar fields).
    value_dim: int
    # Minimal basis descriptor exposing ``dofs_per_node``.
    basis: _SurfaceBasis
30
+
31
+
32
@dataclass(eq=False)
class SurfaceMixedFormContext:
    """Surface mixed context for weak-form evaluation on supermesh.

    Holds the quadrature data of one surface integration cell together with
    the fields evaluated there.
    """

    # Field pairs keyed by field name (``FieldPair`` comes from core.forms).
    fields: dict[str, "FieldPair"]
    # Quadrature point coordinates.
    x_q: np.ndarray
    # Quadrature weights.
    w: np.ndarray
    # Jacobian determinant(s) scaling the weights — presumably per quadrature
    # point; TODO confirm.
    detJ: np.ndarray
    # Optional surface normal(s); None when not supplied.
    normal: np.ndarray | None = None
    # Optional per-role field maps (trial / test / unknown), keyed by name.
    trial_fields: dict[str, SurfaceMixedFormField] | None = None
    test_fields: dict[str, SurfaceMixedFormField] | None = None
    unknown_fields: dict[str, SurfaceMixedFormField] | None = None
43
+
44
+
45
+ _DEBUG_SURFACE_GRADN = os.getenv("FLUXFEM_DEBUG_SURFACE_GRADN")
46
+ _DEBUG_SURFACE_GRADN_MAX = int(os.getenv("FLUXFEM_DEBUG_SURFACE_GRADN_MAX", "8")) if _DEBUG_SURFACE_GRADN else 0
47
+ _DEBUG_SURFACE_GRADN_COUNT = 0
48
+ _DEBUG_SURFACE_SOURCE_ONCE = False
49
+ _DEBUG_CONTACT_MAP_ONCE = False
50
+ _DEBUG_CONTACT_N_ONCE = False
51
+ _DEBUG_PROJECTION_DIAG = os.getenv("FLUXFEM_PROJ_DIAG")
52
+ _DEBUG_PROJECTION_DIAG_MAX = int(os.getenv("FLUXFEM_PROJ_DIAG_MAX", "20")) if _DEBUG_PROJECTION_DIAG else 0
53
+ _DEBUG_CONTACT_PROJ_ONCE = False
54
+ _DEBUG_PROJ_QP_CACHE = None
55
+ _DEBUG_PROJ_QP_SOURCE = None
56
+ _DEBUG_PROJ_QP_DUMPED = False
57
+ _PROJ_DIAG_STATS = None
58
+ _PROJ_DIAG_COUNT = 0
59
+ _PROJ_DIAG_CONTEXT: dict[str, int | str] = {}
60
+
61
+
62
+ def _mortar_dbg_enabled() -> bool:
63
+ return os.getenv("FLUXFEM_MORTAR_DEBUG", "0") not in ("0", "", "false", "False")
64
+
65
+
66
def _mortar_dbg(msg: str) -> None:
    """Print *msg* with an immediate flush, but only when mortar debugging is on."""
    if not _mortar_dbg_enabled():
        return
    print(msg, flush=True)
69
+
70
+
71
+ def _env_flag(name: str, default: bool) -> bool:
72
+ raw = os.getenv(name)
73
+ if raw is None:
74
+ return default
75
+ return raw not in ("0", "", "false", "False")
76
+
77
+
78
@dataclass(eq=False)
class MortarMatrix:
    """COO storage for mortar coupling matrices (can be rectangular).

    Entry ``k`` places ``data[k]`` at ``(rows[k], cols[k])``.  Whether
    duplicate index pairs are summed is decided by the consumer — presumably
    standard COO accumulation (TODO confirm at conversion sites).
    """

    # Row index of each stored entry.
    rows: np.ndarray
    # Column index of each stored entry.
    cols: np.ndarray
    # Value of each stored entry.
    data: np.ndarray
    # (n_rows, n_cols) of the full matrix; rows and columns may differ.
    shape: tuple[int, int]
85
+
86
+
87
+ def _tri_area(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> float:
88
+ return 0.5 * float(np.linalg.norm(np.cross(b - a, c - a)))
89
+
90
+
91
def tri_area(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> float:
    """Return the area of the triangle with vertices *a*, *b*, *c*.

    Public entry point used in contact diagnostics; delegates to ``_tri_area``
    (half the norm of the edge cross product).
    """
    area = _tri_area(a, b, c)
    return area
94
+
95
+
96
def tri_quadrature(order: int) -> tuple[np.ndarray, np.ndarray]:
    """Return ``(points, weights)`` of a rule on the reference triangle.

    The reference triangle is (0,0), (1,0), (0,1); the weights sum to its area
    of 1/2.  Delegates to ``_tri_quadrature``, which raises
    ``NotImplementedError`` for ``order > 5``.
    """
    points, weights = _tri_quadrature(order)
    return points, weights
99
+
100
+
101
def facet_triangles(coords: np.ndarray, facet_nodes: np.ndarray) -> list[tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """Triangulate a facet's corner polygon.

    Returns a list of vertex-coordinate triples.  Delegates to
    ``_facet_triangles``.
    """
    return _facet_triangles(coords=coords, facet_nodes=facet_nodes)
104
+
105
+
106
def facet_shape_values(point: np.ndarray, facet_nodes: np.ndarray, coords: np.ndarray, *, tol: float) -> np.ndarray:
    """Evaluate the nodal shape values of a facet at *point*.

    The facet type is selected by ``len(facet_nodes)``.  Delegates to
    ``_facet_shape_values``.
    """
    return _facet_shape_values(point=point, facet_nodes=facet_nodes, coords=coords, tol=tol)
109
+
110
+
111
def volume_shape_values_at_points(x_q: np.ndarray, elem_coords: np.ndarray, *, tol: float) -> np.ndarray:
    """Evaluate volume-element shape values at the points *x_q*.

    Thin public wrapper around ``_volume_shape_values_at_points``, which is
    defined elsewhere in this module.  Supported element types and the output
    layout are determined there — presumably (n_points, n_nodes); TODO confirm.
    """
    values = _volume_shape_values_at_points(x_q, elem_coords, tol=tol)
    return values
114
+
115
+
116
def quad_shape_and_local(
    point: np.ndarray,
    quad_nodes: np.ndarray,
    corner_coords: np.ndarray,
    *,
    tol: float,
) -> tuple[np.ndarray, float, float]:
    """Return bilinear quad shape values at *point* and its local coordinates.

    Delegates to ``_quad_shape_and_local`` and returns ``(N, xi, eta)``.

    Note: ``corner_coords`` is indexed as ``corner_coords[quad_nodes]``, so it
    must be a node-coordinate array that the 4 indices in ``quad_nodes`` point
    into.  Passing the 4 corner coordinates with ``quad_nodes = [0, 1, 2, 3]``
    also works.
    """
    shape_values, xi, eta = _quad_shape_and_local(point, quad_nodes, corner_coords, tol=tol)
    return shape_values, xi, eta
125
+
126
+
127
def quad9_shape_values(xi: float, eta: float) -> np.ndarray:
    """Return the 9 biquadratic Lagrange shape values at reference point (xi, eta).

    Nodes are ordered row by row over the 3x3 grid {-1, 0, 1}^2, with xi
    varying fastest.  Delegates to ``_quad9_shape_values``.
    """
    return _quad9_shape_values(xi=xi, eta=eta)
130
+
131
+
132
def hex27_gradN(point: np.ndarray, elem_coords: np.ndarray, *, tol: float) -> np.ndarray:
    """Return hex27 shape-function gradients at *point* (diagnostics helper).

    Thin wrapper around ``_hex27_gradN``, which is defined elsewhere in this
    module.  The output layout is presumably (27, 3) — TODO confirm there.
    """
    gradients = _hex27_gradN(point, elem_coords, tol=tol)
    return gradients
135
+
136
+
137
+ def _quad_quadrature(order: int) -> tuple[np.ndarray, np.ndarray]:
138
+ if order <= 1:
139
+ order = 2
140
+ n = int(np.ceil((order + 1.0) / 2.0))
141
+ x1d, w1d = np.polynomial.legendre.leggauss(n)
142
+ X, Y = np.meshgrid(x1d, x1d, indexing="xy")
143
+ W = np.outer(w1d, w1d)
144
+ pts = np.stack([X.ravel(), Y.ravel()], axis=1)
145
+ w = W.ravel()
146
+ return pts, w
147
+
148
+
149
+ def _facet_area_estimate(facet_nodes: np.ndarray, coords: np.ndarray) -> float:
150
+ n = int(len(facet_nodes))
151
+ if n == 3:
152
+ pts = coords[facet_nodes]
153
+ return _tri_area(pts[0], pts[1], pts[2])
154
+ if n == 4:
155
+ pts = coords[facet_nodes]
156
+ return _tri_area(pts[0], pts[1], pts[2]) + _tri_area(pts[0], pts[2], pts[3])
157
+ if n == 8:
158
+ corner_nodes = facet_nodes[:4]
159
+ pts = coords[corner_nodes]
160
+ return _tri_area(pts[0], pts[1], pts[2]) + _tri_area(pts[0], pts[2], pts[3])
161
+ if n == 9:
162
+ corner_nodes = facet_nodes[[0, 2, 8, 6]]
163
+ pts = coords[corner_nodes]
164
+ return _tri_area(pts[0], pts[1], pts[2]) + _tri_area(pts[0], pts[2], pts[3])
165
+ pts = coords[facet_nodes]
166
+ area = 0.0
167
+ p0 = pts[0]
168
+ for i in range(1, len(pts) - 1):
169
+ area += _tri_area(p0, pts[i], pts[i + 1])
170
+ return float(area)
171
+
172
+
173
+ def _facet_triangles(coords: np.ndarray, facet_nodes: np.ndarray) -> list[tuple[np.ndarray, np.ndarray, np.ndarray]]:
174
+ n = int(len(facet_nodes))
175
+ if n in {3, 6}:
176
+ corner = facet_nodes[:3]
177
+ pts = coords[corner]
178
+ return [(pts[0], pts[1], pts[2])]
179
+ if n == 4:
180
+ corner = facet_nodes
181
+ elif n == 8:
182
+ corner = facet_nodes[:4]
183
+ elif n == 9:
184
+ corner = facet_nodes[[0, 2, 8, 6]]
185
+ else:
186
+ corner = facet_nodes
187
+ pts = coords[corner]
188
+ if len(pts) < 3:
189
+ return []
190
+ if len(pts) == 3:
191
+ return [(pts[0], pts[1], pts[2])]
192
+ tris = [(pts[0], pts[1], pts[2])]
193
+ if len(pts) >= 4:
194
+ tris.append((pts[0], pts[2], pts[3]))
195
+ if len(pts) > 4:
196
+ for i in range(2, len(pts) - 1):
197
+ tris.append((pts[0], pts[i], pts[i + 1]))
198
+ return tris
199
+
200
+
201
+
202
+
203
+ def _tri_centroid(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
204
+ return (a + b + c) / 3.0
205
+
206
+
207
+ def _tri_quadrature(order: int) -> tuple[np.ndarray, np.ndarray]:
208
+ """
209
+ Return reference triangle quadrature points (r, s) and weights.
210
+ Reference triangle is (0,0), (1,0), (0,1); weights integrate over area 1/2.
211
+ """
212
+ if order <= 0:
213
+ return np.array([[1.0 / 3.0, 1.0 / 3.0]]), np.array([0.5])
214
+ if order <= 2:
215
+ pts = np.array(
216
+ [
217
+ [1.0 / 6.0, 1.0 / 6.0],
218
+ [2.0 / 3.0, 1.0 / 6.0],
219
+ [1.0 / 6.0, 2.0 / 3.0],
220
+ ],
221
+ dtype=float,
222
+ )
223
+ weights = np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0], dtype=float)
224
+ return pts, weights
225
+ if order <= 3:
226
+ pts = np.array(
227
+ [
228
+ [1.0 / 3.0, 1.0 / 3.0],
229
+ [0.2, 0.2],
230
+ [0.6, 0.2],
231
+ [0.2, 0.6],
232
+ ],
233
+ dtype=float,
234
+ )
235
+ weights = np.array(
236
+ [-27.0 / 96.0, 25.0 / 96.0, 25.0 / 96.0, 25.0 / 96.0],
237
+ dtype=float,
238
+ )
239
+ return pts, weights
240
+ if order <= 4:
241
+ a = 0.445948490915965
242
+ b = 0.108103018168070
243
+ c = 0.091576213509771
244
+ d = 0.816847572980459
245
+ pts = np.array(
246
+ [
247
+ [a, a],
248
+ [a, b],
249
+ [b, a],
250
+ [c, c],
251
+ [c, d],
252
+ [d, c],
253
+ ],
254
+ dtype=float,
255
+ )
256
+ weights = np.array(
257
+ [
258
+ 0.111690794839005,
259
+ 0.111690794839005,
260
+ 0.111690794839005,
261
+ 0.054975871827661,
262
+ 0.054975871827661,
263
+ 0.054975871827661,
264
+ ],
265
+ dtype=float,
266
+ )
267
+ return pts, weights
268
+ if order <= 5:
269
+ a = 0.470142064105115
270
+ b = 0.059715871789770
271
+ c = 0.101286507323456
272
+ d = 0.797426985353087
273
+ pts = np.array(
274
+ [
275
+ [1.0 / 3.0, 1.0 / 3.0],
276
+ [a, a],
277
+ [a, b],
278
+ [b, a],
279
+ [c, c],
280
+ [c, d],
281
+ [d, c],
282
+ ],
283
+ dtype=float,
284
+ )
285
+ weights = np.array(
286
+ [
287
+ 0.225000000000000,
288
+ 0.132394152788506,
289
+ 0.132394152788506,
290
+ 0.132394152788506,
291
+ 0.125939180544827,
292
+ 0.125939180544827,
293
+ 0.125939180544827,
294
+ ],
295
+ dtype=float,
296
+ )
297
+ weights *= 0.5
298
+ return pts, weights
299
+ raise NotImplementedError("triangle quadrature order > 5 is not implemented")
300
+
301
+
302
+ def _proj_diag_enabled() -> bool:
303
+ return os.getenv("FLUXFEM_PROJ_DIAG", "0") == "1"
304
+
305
+
306
+ def _proj_diag_max() -> int:
307
+ return int(os.getenv("FLUXFEM_PROJ_DIAG_MAX", "20"))
308
+
309
+
310
def _proj_diag_reset() -> None:
    """Start a fresh projection-diagnostics session.

    Installs zeroed statistics (attempts, failures, failures per code) and
    resets the printed-failure counter.
    """
    global _PROJ_DIAG_STATS, _PROJ_DIAG_COUNT
    _PROJ_DIAG_COUNT = 0
    _PROJ_DIAG_STATS = dict(total=0, fail=0, by_code={})
318
+
319
+
320
def _proj_diag_set_context(
    *,
    fa: int,
    fb: int,
    face_a: str,
    face_b: str,
    elem_a: int,
    elem_b: int,
) -> None:
    """Replace the facet-pair context that is attached to projection failure logs.

    The shared ``_PROJ_DIAG_CONTEXT`` dict is cleared and refilled in place, so
    existing references to it see the new values.  Integer ids are coerced
    with ``int``.
    """
    _PROJ_DIAG_CONTEXT.clear()
    _PROJ_DIAG_CONTEXT.update(
        dict(
            fa=int(fa),
            fb=int(fb),
            face_a=face_a,
            face_b=face_b,
            elem_a=int(elem_a),
            elem_b=int(elem_b),
        )
    )
340
+
341
+
342
def _proj_diag_attempt() -> None:
    """Count one projection attempt, but only while a diagnostics session is active."""
    stats = _PROJ_DIAG_STATS
    if stats is not None:
        stats["total"] += 1
346
+
347
+
348
def _proj_diag_log(
    code: str,
    *,
    iters: int,
    res_norm: float,
    delta_norm: float | None,
    detJ: float | None,
    point: np.ndarray,
    local: np.ndarray,
    in_ref_domain: bool,
) -> None:
    """Record a projection failure and print it while under the print budget.

    Does nothing unless a diagnostics session is active (stats not None).
    Every call bumps the ``fail`` total and the per-``code`` count.  A line is
    printed only for the first ``_proj_diag_max()`` failures of the session.
    The line includes the current facet-pair context.
    """
    global _PROJ_DIAG_COUNT
    stats = _PROJ_DIAG_STATS
    if stats is None:
        return
    stats["fail"] += 1
    per_code = stats["by_code"]
    per_code[code] = per_code.get(code, 0) + 1
    if _PROJ_DIAG_COUNT >= _proj_diag_max():
        return
    _PROJ_DIAG_COUNT += 1

    if _PROJ_DIAG_CONTEXT:
        ctx = " ".join(f"{key}={val}" for key, val in _PROJ_DIAG_CONTEXT.items())
    else:
        ctx = "ctx=unknown"
    det_str = f"{detJ:.6e}" if detJ is not None else "None"
    delta_str = f"{delta_norm:.6e}" if delta_norm is not None else "None"
    fields = [
        f"code={code}",
        ctx,
        f"iters={iters}",
        f"res={res_norm:.6e}",
        f"delta={delta_str}",
        f"detJ={det_str}",
        f"in_ref={bool(in_ref_domain)}",
        f"point={point.tolist()}",
        f"local={local.tolist()}",
    ]
    print("[fluxfem][proj][fail]", *fields)
383
+
384
+
385
def _proj_diag_report() -> None:
    """Print the session's attempt/failure totals; silent when no session is active."""
    stats = _PROJ_DIAG_STATS
    if stats is None:
        return
    print(
        "[fluxfem][proj][diag] total=",
        stats["total"],
        "fail=",
        stats["fail"],
        "by_code=",
        stats["by_code"],
    )
392
+
393
+
394
+ def _facet_label(facet: np.ndarray) -> str:
395
+ n = int(len(facet))
396
+ if n == 3:
397
+ return "tri3"
398
+ if n == 4:
399
+ return "quad4"
400
+ if n == 6:
401
+ return "tri6"
402
+ if n == 8:
403
+ return "quad8"
404
+ if n == 9:
405
+ return "quad9"
406
+ return f"n{n}"
407
+
408
+
409
def _diag_quad_override(diag_force: bool, mode: str, path: str) -> tuple[np.ndarray, np.ndarray] | None:
    """Return quadrature ``(points, weights)`` loaded from an ``.npz`` file, for diagnostics.

    Active only when ``diag_force`` is truthy, ``mode == "load"`` and ``path``
    is non-empty; otherwise returns None.  The file must contain
    ``quad_pts`` and ``quad_w`` arrays.  It is read once; later calls return
    the cached pair, even if ``path`` changes.  The load also records
    ``_DEBUG_PROJ_QP_SOURCE = "file:<path>"``.
    """
    global _DEBUG_PROJ_QP_CACHE, _DEBUG_PROJ_QP_SOURCE
    if not diag_force or mode != "load" or not path:
        return None
    if _DEBUG_PROJ_QP_CACHE is None:
        # np.load on an .npz returns an NpzFile that keeps the file handle open
        # until closed; the context manager releases it once the arrays are read.
        with np.load(path) as data:
            quad_pts = np.asarray(data["quad_pts"], dtype=float)
            quad_w = np.asarray(data["quad_w"], dtype=float)
        _DEBUG_PROJ_QP_CACHE = (quad_pts, quad_w)
        _DEBUG_PROJ_QP_SOURCE = f"file:{path}"
    return _DEBUG_PROJ_QP_CACHE
418
+
419
+
420
def _diag_quad_dump(diag_force: bool, mode: str, path: str, quad_pts: np.ndarray, quad_w: np.ndarray) -> None:
    """Save quadrature points/weights to *path* (``.npz``) once per process.

    Writes only when ``diag_force`` is truthy, ``mode == "dump"``, ``path`` is
    non-empty and nothing has been dumped yet.  Afterwards it sets
    ``_DEBUG_PROJ_QP_DUMPED``.
    """
    global _DEBUG_PROJ_QP_DUMPED
    should_dump = diag_force and mode == "dump" and path and not _DEBUG_PROJ_QP_DUMPED
    if not should_dump:
        return
    points_arr = np.asarray(quad_pts, dtype=float)
    weights_arr = np.asarray(quad_w, dtype=float)
    np.savez(path, quad_pts=points_arr, quad_w=weights_arr)
    _DEBUG_PROJ_QP_DUMPED = True
426
+
427
+
428
+ def _volume_local_coords(point: np.ndarray, elem_coords: np.ndarray, *, tol: float):
429
+ n_nodes = elem_coords.shape[0]
430
+ if n_nodes in {4, 10}:
431
+ corner_coords = elem_coords[:4]
432
+ M = np.stack([corner_coords[:, 0], corner_coords[:, 1], corner_coords[:, 2], np.ones(4)], axis=1)
433
+ rhs = np.array([point[0], point[1], point[2], 1.0], dtype=float)
434
+ try:
435
+ lam = np.linalg.solve(M.T, rhs)
436
+ except np.linalg.LinAlgError:
437
+ return None
438
+ return lam
439
+ if n_nodes == 8:
440
+ _, xi, eta, zeta = _hex8_shape_and_local(point, elem_coords, tol=tol)
441
+ return np.array([xi, eta, zeta], dtype=float)
442
+ if n_nodes == 20:
443
+ _, xi, eta, zeta = _hex20_shape_and_local(point, elem_coords, tol=tol)
444
+ return np.array([xi, eta, zeta], dtype=float)
445
+ if n_nodes == 27:
446
+ _, xi, eta, zeta = _hex27_shape_and_local(point, elem_coords, tol=tol)
447
+ return np.array([xi, eta, zeta], dtype=float)
448
+ return None
449
+
450
+
451
def _diag_contact_projection(
    *,
    fa: int,
    fb: int,
    quad_pts: np.ndarray,
    quad_w: np.ndarray,
    x_q: np.ndarray,
    Na: np.ndarray,
    Nb: np.ndarray,
    nodes_a: np.ndarray,
    nodes_b: np.ndarray,
    dofs_a: np.ndarray,
    dofs_b: np.ndarray,
    elem_coords_a: np.ndarray | None,
    elem_coords_b: np.ndarray | None,
    na: np.ndarray | None,
    nb: np.ndarray | None,
    normal: np.ndarray | None,
    normal_source: str,
    normal_sign: float,
    detJ: float,
    diag_facet: int,
    diag_max_q: int,
    quad_source: str,
    tol: float,
) -> None:
    """Print a one-time diagnostic dump of the projection data for a contact facet pair.

    Only the first call that passes the facet filter prints; it then sets
    ``_DEBUG_CONTACT_PROJ_ONCE``.  With ``diag_facet >= 0``, only
    ``fa == diag_facet`` is reported.  The dump lists the quadrature rule,
    normals, node/dof ids, and up to ``diag_max_q`` quadrature points.  Each
    point shows its shape-value sums.  When element coordinates are given, it
    also shows the reconstructed point ``N @ elem_coords``, that point's
    distance to ``x_q``, and the local coordinates from
    ``_volume_local_coords``.
    """
    global _DEBUG_CONTACT_PROJ_ONCE
    if _DEBUG_CONTACT_PROJ_ONCE or (diag_facet >= 0 and fa != diag_facet):
        return
    n_samples = min(diag_max_q, int(x_q.shape[0]))

    print("[fluxfem][diag][proj] first facet")
    print(f" fa={fa} fb={fb} quad_source={quad_source}")
    print(f" quad_pts={quad_pts.tolist()} quad_w={quad_w.tolist()}")
    print(f" normal_source={normal_source} normal_sign={normal_sign}")
    print(f" n_master={None if na is None else na.tolist()}")
    print(f" n_slave={None if nb is None else nb.tolist()}")
    print(f" n_used={None if normal is None else normal.tolist()}")
    if normal is not None:
        if na is not None:
            print(f" dot(n_used,n_master)={float(np.dot(normal, na)):.6e}")
        if nb is not None:
            print(f" dot(n_used,n_slave)={float(np.dot(normal, nb)):.6e}")
    print(f" detJ={float(detJ):.6e}")
    print(f" nodes_a={nodes_a.tolist()} nodes_b={nodes_b.tolist()}")
    print(f" dofs_a={dofs_a.tolist()} dofs_b={dofs_b.tolist()}")

    for qi in range(n_samples):
        row_a = Na[qi]
        row_b = Nb[qi]
        xq = x_q[qi]
        parts = [
            f" q{qi} x={xq.tolist()} sum(Na)={float(np.sum(row_a)):.6e} sum(Nb)={float(np.sum(row_b)):.6e}"
        ]
        if elem_coords_a is not None:
            xa = row_a @ elem_coords_a
            parts.append(f" x_a={xa.tolist()} |x_a-x_q|={float(np.linalg.norm(xa - xq)):.6e}")
            local_a = _volume_local_coords(xq, elem_coords_a, tol=tol)
            if local_a is not None:
                parts.append(f" xi_a={local_a.tolist()}")
        if elem_coords_b is not None:
            xb = row_b @ elem_coords_b
            parts.append(f" x_b={xb.tolist()} |x_b-x_q|={float(np.linalg.norm(xb - xq)):.6e}")
            local_b = _volume_local_coords(xq, elem_coords_b, tol=tol)
            if local_b is not None:
                parts.append(f" xi_b={local_b.tolist()}")
        print("".join(parts))
    _DEBUG_CONTACT_PROJ_ONCE = True
516
+
517
+
518
+ def _barycentric(p: np.ndarray, a: np.ndarray, b: np.ndarray, c: np.ndarray):
519
+ v0 = b - a
520
+ v1 = c - a
521
+ v2 = p - a
522
+ d00 = float(np.dot(v0, v0))
523
+ d01 = float(np.dot(v0, v1))
524
+ d11 = float(np.dot(v1, v1))
525
+ d20 = float(np.dot(v2, v0))
526
+ d21 = float(np.dot(v2, v1))
527
+ denom = d00 * d11 - d01 * d01
528
+ if abs(denom) < 1e-14:
529
+ return None
530
+ v = (d11 * d20 - d01 * d21) / denom
531
+ w = (d00 * d21 - d01 * d20) / denom
532
+ u = 1.0 - v - w
533
+ return np.array([u, v, w], dtype=float)
534
+
535
+
536
+ def _point_in_tri(lam: np.ndarray, *, tol: float) -> bool:
537
+ return np.all(lam >= -tol) and np.all(lam <= 1.0 + tol)
538
+
539
+
540
+ def _plane_basis(pts: np.ndarray, *, tol: float):
541
+ v1 = pts[1] - pts[0]
542
+ v2 = pts[3] - pts[0] if pts.shape[0] > 3 else pts[2] - pts[0]
543
+ n = np.cross(v1, v2)
544
+ n_norm = np.linalg.norm(n)
545
+ if n_norm < tol:
546
+ return None, None
547
+ n = n / n_norm
548
+ t1 = v1 / np.linalg.norm(v1)
549
+ v2_proj = v2 - np.dot(v2, t1) * t1
550
+ v2_norm = np.linalg.norm(v2_proj)
551
+ if v2_norm < tol:
552
+ return None, None
553
+ t2 = v2_proj / v2_norm
554
+ return t1, t2
555
+
556
+
557
def _quad_shape_and_local(
    point: np.ndarray,
    facet_nodes: np.ndarray,
    coords: np.ndarray,
    *,
    tol: float,
) -> tuple[np.ndarray, float, float]:
    """Invert the bilinear quad4 map for *point* and return its shape values.

    The corners ``coords[facet_nodes]`` and *point* are expressed in the 2-D
    frame returned by ``_plane_basis``; the out-of-plane component of *point*
    is dropped.  Newton's method then solves ``x(xi, eta) = point``.  It
    starts from the element centre and runs at most 12 iterations.

    Args:
        point: physical point.
        facet_nodes: indices of the 4 corner nodes, ordered to match the
            reference corners (-1,-1), (1,-1), (1,1), (-1,1).
        coords: node coordinate array indexed by ``facet_nodes``.
        tol: tolerance for the plane-basis degeneracy test, the Newton
            residual (L1 norm), the Jacobian determinant and the
            reference-domain check.

    Returns:
        ``(N, xi, eta)``, where ``N`` holds the 4 bilinear shape values at
        ``(xi, eta)``.  ``N`` is all zeros on failure:

        * degenerate plane basis: local coordinates ``(0.0, 0.0)``;
        * singular Jacobian: the current ``(xi, eta)``;
        * non-finite iterate: ``(0.0, 0.0)``.

        Points outside the reference square are not rejected.  ``N`` is still
        evaluated there, and a diagnostic is logged when diagnostics are on.
        If Newton does not converge within 12 iterations, ``N`` and the
        residual are re-evaluated at the final iterate.  Previously they came
        from the iterate before the last update.
    """
    if _proj_diag_enabled():
        _proj_diag_attempt()
    pts = coords[facet_nodes]
    basis = _plane_basis(pts, tol=tol)
    if basis[0] is None:
        return np.zeros((4,), dtype=float), 0.0, 0.0
    t1, t2 = basis
    origin = pts[0]
    frame = np.stack([t1, t2], axis=1)
    local = (pts - origin) @ frame
    p_local = (point - origin) @ frame
    x = local[:, 0]
    y = local[:, 1]
    xp = float(p_local[0])
    yp = float(p_local[1])

    def _shape_and_residual(xi_, eta_):
        # Bilinear shape values and the in-plane mapping residual x(xi, eta) - p.
        n1 = 0.25 * (1.0 - xi_) * (1.0 - eta_)
        n2 = 0.25 * (1.0 + xi_) * (1.0 - eta_)
        n3 = 0.25 * (1.0 + xi_) * (1.0 + eta_)
        n4 = 0.25 * (1.0 - xi_) * (1.0 + eta_)
        rx_ = n1 * x[0] + n2 * x[1] + n3 * x[2] + n4 * x[3] - xp
        ry_ = n1 * y[0] + n2 * y[1] + n3 * y[2] + n4 * y[3] - yp
        return (n1, n2, n3, n4), rx_, ry_

    xi = 0.0
    eta = 0.0
    res_norm = 0.0
    detJ = None
    iters = 0
    converged = False
    for _ in range(12):
        iters += 1
        shape, rx, ry = _shape_and_residual(xi, eta)
        res_norm = float(np.hypot(rx, ry))
        if abs(rx) + abs(ry) < tol:
            converged = True
            break
        # Derivatives of the bilinear shape functions w.r.t. xi and eta.
        dndxi = np.array(
            [
                -0.25 * (1.0 - eta),
                0.25 * (1.0 - eta),
                0.25 * (1.0 + eta),
                -0.25 * (1.0 + eta),
            ],
            dtype=float,
        )
        dndeta = np.array(
            [
                -0.25 * (1.0 - xi),
                -0.25 * (1.0 + xi),
                0.25 * (1.0 + xi),
                0.25 * (1.0 - xi),
            ],
            dtype=float,
        )
        j11 = float(np.dot(dndxi, x))  # dx/dxi
        j12 = float(np.dot(dndeta, x))  # dx/deta
        j21 = float(np.dot(dndxi, y))  # dy/dxi
        j22 = float(np.dot(dndeta, y))  # dy/deta
        det = j11 * j22 - j12 * j21
        detJ = float(det)
        if abs(det) < tol:
            if _proj_diag_enabled():
                _proj_diag_log(
                    "SINGULAR_H",
                    iters=iters,
                    res_norm=res_norm,
                    delta_norm=None,
                    detJ=detJ,
                    point=point,
                    local=np.array([xi, eta], dtype=float),
                    in_ref_domain=False,
                )
            return np.zeros((4,), dtype=float), xi, eta
        # Newton step: solve J [dxi, deta]^T = -[rx, ry]^T with the explicit 2x2 inverse.
        dxi = (-j22 * rx + j12 * ry) / det
        deta = (j21 * rx - j11 * ry) / det
        xi += dxi
        eta += deta
        if not np.isfinite(xi) or not np.isfinite(eta):
            if _proj_diag_enabled():
                _proj_diag_log(
                    "NAN_INF",
                    iters=iters,
                    res_norm=res_norm,
                    delta_norm=float(np.hypot(dxi, deta)),
                    detJ=detJ,
                    point=point,
                    local=np.array([xi, eta], dtype=float),
                    in_ref_domain=False,
                )
            return np.zeros((4,), dtype=float), 0.0, 0.0

    if not converged:
        # The last Newton update moved (xi, eta) after the shape values and the
        # residual were evaluated; re-evaluate so that N, (xi, eta) and the
        # reported residual all refer to the same point.
        shape, rx, ry = _shape_and_residual(xi, eta)
        res_norm = float(np.hypot(rx, ry))

    in_ref = max(abs(xi), abs(eta)) <= 1.0 + tol
    if _proj_diag_enabled() and (not in_ref or res_norm > tol):
        code = "OUTSIDE_DOMAIN" if not in_ref else "NEWTON_NO_CONVERGE"
        _proj_diag_log(
            code,
            iters=iters,
            res_norm=res_norm,
            delta_norm=None,
            detJ=detJ,
            point=point,
            local=np.array([xi, eta], dtype=float),
            in_ref_domain=in_ref,
        )

    return np.array(shape, dtype=float), xi, eta
667
+
668
+
669
def _quad_shape_values(
    point: np.ndarray,
    facet_nodes: np.ndarray,
    coords: np.ndarray,
    *,
    tol: float,
) -> np.ndarray:
    """Return the 4 bilinear quad shape values at *point*.

    The local coordinates computed by ``_quad_shape_and_local`` are discarded.
    """
    return _quad_shape_and_local(point, facet_nodes, coords, tol=tol)[0]
678
+
679
+
680
+ def _quad8_shape_values(xi: float, eta: float) -> np.ndarray:
681
+ n1 = -0.25 * (1.0 - xi) * (1.0 - eta) * (1.0 + xi + eta)
682
+ n2 = -0.25 * (1.0 + xi) * (1.0 - eta) * (1.0 - xi + eta)
683
+ n3 = -0.25 * (1.0 + xi) * (1.0 + eta) * (1.0 - xi - eta)
684
+ n4 = -0.25 * (1.0 - xi) * (1.0 + eta) * (1.0 + xi - eta)
685
+ n5 = 0.5 * (1.0 - xi * xi) * (1.0 - eta)
686
+ n6 = 0.5 * (1.0 + xi) * (1.0 - eta * eta)
687
+ n7 = 0.5 * (1.0 - xi * xi) * (1.0 + eta)
688
+ n8 = 0.5 * (1.0 - xi) * (1.0 - eta * eta)
689
+ return np.array([n1, n2, n3, n4, n5, n6, n7, n8], dtype=float)
690
+
691
+
692
+ def _quad9_shape_values(xi: float, eta: float) -> np.ndarray:
693
+ def q1(t):
694
+ return 0.5 * t * (t - 1.0)
695
+
696
+ def q2(t):
697
+ return 1.0 - t * t
698
+
699
+ def q3(t):
700
+ return 0.5 * t * (t + 1.0)
701
+
702
+ Nx = [q1(xi), q2(xi), q3(xi)]
703
+ Ny = [q1(eta), q2(eta), q3(eta)]
704
+ out = []
705
+ for j in range(3):
706
+ for i in range(3):
707
+ out.append(Nx[i] * Ny[j])
708
+ return np.array(out, dtype=float)
709
+
710
+
711
+ def _quad9_shape_grad_ref(xi: float, eta: float) -> np.ndarray:
712
+ def q1(t):
713
+ return 0.5 * t * (t - 1.0)
714
+
715
+ def q2(t):
716
+ return 1.0 - t * t
717
+
718
+ def q3(t):
719
+ return 0.5 * t * (t + 1.0)
720
+
721
+ def dq1(t):
722
+ return t - 0.5
723
+
724
+ def dq2(t):
725
+ return -2.0 * t
726
+
727
+ def dq3(t):
728
+ return t + 0.5
729
+
730
+ Nx = [q1(xi), q2(xi), q3(xi)]
731
+ Ny = [q1(eta), q2(eta), q3(eta)]
732
+ dNx = [dq1(xi), dq2(xi), dq3(xi)]
733
+ dNy = [dq1(eta), dq2(eta), dq3(eta)]
734
+ out = []
735
+ for j in range(3):
736
+ for i in range(3):
737
+ out.append([dNx[i] * Ny[j], Nx[i] * dNy[j]])
738
+ return np.array(out, dtype=float)
739
+
740
+
741
def _quad9_map_and_jacobian(pts: np.ndarray, xi: float, eta: float) -> tuple[np.ndarray, np.ndarray]:
    """Evaluate the quad9 facet map and its tangent Jacobian at (xi, eta).

    Returns ``x = N @ pts`` and ``J = (dN^T @ pts)^T``.  For ``pts`` of shape
    (9, 3), ``J`` is (3, 2), with columns ``dx/dxi`` and ``dx/deta``.
    """
    shape = _quad9_shape_values(xi, eta)
    grad_ref = _quad9_shape_grad_ref(xi, eta)
    position = shape @ pts
    jacobian = np.transpose(grad_ref.T @ pts)
    return position, jacobian
747
+
748
+
749
def _project_point_to_quad9(
    point: np.ndarray,
    pts: np.ndarray,
    *,
    tol: float,
    max_iter: int = 15,
) -> tuple[float, float, bool, np.ndarray, np.ndarray, dict]:
    """Closest-point projection of *point* onto a quad9 facet (Gauss-Newton).

    Minimises ``|x(xi, eta) - point|^2`` starting from the reference centre
    (0, 0).  Each step solves ``(J^T J) delta = -J^T (x - point)``.  If any
    component of ``delta`` exceeds 1 in magnitude, the step is rescaled so its
    largest component is 1.

    Args:
        point: physical point to project.
        pts: facet node coordinates in the row-major 3x3 order of
            ``_quad9_shape_values`` (assumed (9, 3) — TODO confirm at callers).
        tol: absolute tolerance for the ``det(J^T J)`` singularity test, the
            step/gradient convergence test and the reference-domain check.
        max_iter: maximum number of Gauss-Newton iterations.

    Returns:
        ``(xi, eta, ok, x, J, info)``:

        * ``x`` and ``J`` are the map and Jacobian at the returned (xi, eta).
        * ``ok`` is True only when the status is ``"OK"`` and the point lies
          in ``[-1-tol, 1+tol]^2``.
        * ``info`` is the ``_projection_info`` dict.  Its status is one of
          ``"OK"``, ``"SINGULAR_H"``, ``"NAN_INF"``, ``"OUTSIDE_DOMAIN"``,
          ``"NEWTON_NO_CONVERGE"``.
    """
    # Start from the centre of the reference square.
    xi0 = 0.0
    eta0 = 0.0
    xi = xi0
    eta = eta0
    # NaN placeholders so _projection_info reports NaN norms on a first-iteration failure.
    last_delta = np.array([np.nan, np.nan], dtype=float)
    last_r = np.array([np.nan, np.nan], dtype=float)
    last_det = np.nan
    status = "OK"
    for _ in range(max_iter):
        x, J = _quad9_map_and_jacobian(pts, xi, eta)
        # Gauss-Newton normal matrix (surface curvature terms are ignored).
        JTJ = J.T @ J
        det = float(np.linalg.det(JTJ))
        last_det = det
        # NOTE(review): absolute threshold on det(J^T J), which scales like
        # (element size)^4 — small elements may be flagged singular; confirm tol.
        if abs(det) < tol:
            status = "SINGULAR_H"
            return xi, eta, False, x, J, _projection_info(status, xi0, eta0, xi, eta, last_r, last_delta, det, JTJ)
        # Gradient of 0.5 * |x(xi, eta) - point|^2 with respect to (xi, eta).
        r = J.T @ (x - point)
        last_r = r
        try:
            delta = -np.linalg.solve(JTJ, r)
        except np.linalg.LinAlgError:
            status = "SINGULAR_H"
            return xi, eta, False, x, J, _projection_info(status, xi0, eta0, xi, eta, last_r, last_delta, det, JTJ)
        if not np.all(np.isfinite(delta)):
            status = "NAN_INF"
            return xi, eta, False, x, J, _projection_info(status, xi0, eta0, xi, eta, last_r, last_delta, det, JTJ)
        # Keep the unscaled step for the post-loop convergence check.
        last_delta = delta
        # Damping: cap the largest step component at 1.
        step = float(np.max(np.abs(delta)))
        if step > 1.0:
            delta = delta / step
        xi += float(delta[0])
        eta += float(delta[1])
        # Converged when the (possibly scaled) step and the pre-step gradient are both below tol.
        if float(np.linalg.norm(delta)) < tol and float(np.linalg.norm(r)) < tol:
            break
    # Map and Jacobian at the final iterate, returned to the caller.
    x, J = _quad9_map_and_jacobian(pts, xi, eta)
    ok = abs(xi) <= 1.0 + tol and abs(eta) <= 1.0 + tol
    if not ok:
        status = "OUTSIDE_DOMAIN"
    # last_delta / last_r come from the last iteration (unscaled step, gradient
    # before that step); they are not re-evaluated at the final iterate.
    if status == "OK" and (float(np.linalg.norm(last_delta)) >= tol or float(np.linalg.norm(last_r)) >= tol):
        status = "NEWTON_NO_CONVERGE"
    return xi, eta, ok and status == "OK", x, J, _projection_info(status, xi0, eta0, xi, eta, last_r, last_delta, last_det, J.T @ J)
797
+
798
+
799
+ def _tri6_shape_values(xi: float, eta: float) -> np.ndarray:
800
+ L1 = 1.0 - xi - eta
801
+ L2 = xi
802
+ L3 = eta
803
+ return np.array(
804
+ [
805
+ L1 * (2.0 * L1 - 1.0),
806
+ L2 * (2.0 * L2 - 1.0),
807
+ L3 * (2.0 * L3 - 1.0),
808
+ 4.0 * L1 * L2,
809
+ 4.0 * L2 * L3,
810
+ 4.0 * L1 * L3,
811
+ ],
812
+ dtype=float,
813
+ )
814
+
815
+
816
+ def _tri6_shape_grad_ref(xi: float, eta: float) -> np.ndarray:
817
+ L1 = 1.0 - xi - eta
818
+ L2 = xi
819
+ L3 = eta
820
+ dN1 = np.array([-(4.0 * L1 - 1.0), -(4.0 * L1 - 1.0)], dtype=float)
821
+ dN2 = np.array([4.0 * L2 - 1.0, 0.0], dtype=float)
822
+ dN3 = np.array([0.0, 4.0 * L3 - 1.0], dtype=float)
823
+ dN4 = np.array([4.0 * (L1 - L2), -4.0 * L2], dtype=float)
824
+ dN5 = np.array([4.0 * L3, 4.0 * L2], dtype=float)
825
+ dN6 = np.array([-4.0 * L3, 4.0 * (L1 - L3)], dtype=float)
826
+ return np.array([dN1, dN2, dN3, dN4, dN5, dN6], dtype=float)
827
+
828
+
829
def _tri6_map_and_jacobian(pts: np.ndarray, xi: float, eta: float) -> tuple[np.ndarray, np.ndarray]:
    """Evaluate the tri6 facet map and its tangent Jacobian at (xi, eta).

    Returns ``x = N @ pts`` and ``J = (dN^T @ pts)^T``.  For ``pts`` of shape
    (6, 3), ``J`` is (3, 2), with columns ``dx/dxi`` and ``dx/deta``.
    """
    shape = _tri6_shape_values(xi, eta)
    grad_ref = _tri6_shape_grad_ref(xi, eta)
    position = shape @ pts
    jacobian = np.transpose(grad_ref.T @ pts)
    return position, jacobian
835
+
836
+
837
+ def _projection_info(
838
+ status: str,
839
+ xi0: float,
840
+ eta0: float,
841
+ xi: float,
842
+ eta: float,
843
+ r: np.ndarray,
844
+ delta: np.ndarray,
845
+ det: float,
846
+ JTJ: np.ndarray,
847
+ ) -> dict:
848
+ r_norm = float(np.linalg.norm(r)) if r.size else float("nan")
849
+ d_norm = float(np.linalg.norm(delta)) if delta.size else float("nan")
850
+ cond = float(np.linalg.cond(JTJ)) if JTJ.size and np.isfinite(JTJ).all() else float("nan")
851
+ return {
852
+ "status": status,
853
+ "xi0": float(xi0),
854
+ "eta0": float(eta0),
855
+ "xi": float(xi),
856
+ "eta": float(eta),
857
+ "r_norm": r_norm,
858
+ "d_norm": d_norm,
859
+ "det": float(det),
860
+ "cond": cond,
861
+ }
862
+
863
+
864
def _project_point_to_tri6(
    point: np.ndarray,
    pts: np.ndarray,
    *,
    tol: float,
    max_iter: int = 15,
) -> tuple[float, float, bool, np.ndarray, np.ndarray, dict]:
    """Closest-point projection of *point* onto a tri6 facet (Gauss-Newton).

    Minimises ``|x(xi, eta) - point|^2`` starting from the reference centroid
    (1/3, 1/3).  Each step solves ``(J^T J) delta = -J^T (x - point)``.  If any
    component of ``delta`` exceeds 1 in magnitude, the step is rescaled so its
    largest component is 1.

    Args:
        point: physical point to project.
        pts: facet node coordinates in the node order of
            ``_tri6_shape_values`` (assumed (6, 3) — TODO confirm at callers).
        tol: absolute tolerance for the ``det(J^T J)`` singularity test, the
            step/gradient convergence test and the reference-domain check.
        max_iter: maximum number of Gauss-Newton iterations.

    Returns:
        ``(xi, eta, ok, x, J, info)``:

        * ``x`` and ``J`` are the map and Jacobian at the returned (xi, eta).
        * ``ok`` is True only when the status is ``"OK"`` and the point
          satisfies ``xi >= -tol``, ``eta >= -tol`` and ``xi + eta <= 1 + tol``.
        * ``info`` is the ``_projection_info`` dict.  Its status is one of
          ``"OK"``, ``"SINGULAR_H"``, ``"NAN_INF"``, ``"OUTSIDE_DOMAIN"``,
          ``"NEWTON_NO_CONVERGE"``.
    """
    # Start from the centroid of the reference triangle.
    xi0 = 1.0 / 3.0
    eta0 = 1.0 / 3.0
    xi = xi0
    eta = eta0
    # NaN placeholders so _projection_info reports NaN norms on a first-iteration failure.
    last_delta = np.array([np.nan, np.nan], dtype=float)
    last_r = np.array([np.nan, np.nan], dtype=float)
    last_det = np.nan
    status = "OK"
    for _ in range(max_iter):
        x, J = _tri6_map_and_jacobian(pts, xi, eta)
        # Gauss-Newton normal matrix (surface curvature terms are ignored).
        JTJ = J.T @ J
        det = float(np.linalg.det(JTJ))
        last_det = det
        # NOTE(review): absolute threshold on det(J^T J), which scales like
        # (element size)^4 — small elements may be flagged singular; confirm tol.
        if abs(det) < tol:
            status = "SINGULAR_H"
            return xi, eta, False, x, J, _projection_info(status, xi0, eta0, xi, eta, last_r, last_delta, det, JTJ)
        # Gradient of 0.5 * |x(xi, eta) - point|^2 with respect to (xi, eta).
        r = J.T @ (x - point)
        last_r = r
        try:
            delta = -np.linalg.solve(JTJ, r)
        except np.linalg.LinAlgError:
            status = "SINGULAR_H"
            return xi, eta, False, x, J, _projection_info(status, xi0, eta0, xi, eta, last_r, last_delta, det, JTJ)
        if not np.all(np.isfinite(delta)):
            status = "NAN_INF"
            return xi, eta, False, x, J, _projection_info(status, xi0, eta0, xi, eta, last_r, last_delta, det, JTJ)
        # Keep the unscaled step for the post-loop convergence check.
        last_delta = delta
        # Damping: cap the largest step component at 1.
        step = float(np.max(np.abs(delta)))
        if step > 1.0:
            delta = delta / step
        xi += float(delta[0])
        eta += float(delta[1])
        # Converged when the (possibly scaled) step and the pre-step gradient are both below tol.
        if float(np.linalg.norm(delta)) < tol and float(np.linalg.norm(r)) < tol:
            break
    # Map and Jacobian at the final iterate, returned to the caller.
    x, J = _tri6_map_and_jacobian(pts, xi, eta)
    ok = xi >= -tol and eta >= -tol and (xi + eta) <= 1.0 + tol
    if not ok:
        status = "OUTSIDE_DOMAIN"
    # last_delta / last_r come from the last iteration (unscaled step, gradient
    # before that step); they are not re-evaluated at the final iterate.
    if status == "OK" and (float(np.linalg.norm(last_delta)) >= tol or float(np.linalg.norm(last_r)) >= tol):
        status = "NEWTON_NO_CONVERGE"
    return xi, eta, ok and status == "OK", x, J, _projection_info(status, xi0, eta0, xi, eta, last_r, last_delta, last_det, J.T @ J)
912
+
913
+
914
def _facet_shape_values(
    point: np.ndarray,
    facet_nodes: np.ndarray,
    coords: np.ndarray,
    *,
    tol: float,
) -> np.ndarray:
    """
    Evaluate nodal shape values on a facet at a point.

    The facet type is selected by ``len(facet_nodes)``:

    * tri3 (3): barycentric coordinates of *point*, or of its projection onto
      the triangle plane.  No inside/outside test is applied; a degenerate
      triangle yields zeros.
    * tri6 (6): quadratic Lagrange values built from the corner
      barycentrics.  Zeros when the triangle is degenerate or any barycentric
      coordinate is below ``-tol`` (point outside).
    * quad4 (4): bilinear values after Newton inversion of the quad map
      (``_quad_shape_values``).
    * quad8 (8) / quad9 (9): local ``(xi, eta)`` from the bilinear inversion
      on the corner nodes (quad8: the first 4 nodes; quad9: nodes 0, 2, 8, 6).
      The serendipity / biquadratic values are then evaluated there.  Zeros if
      the inversion failed.

    Raises:
        ValueError: for any other node count.
    """
    pts = coords[facet_nodes]
    n = len(facet_nodes)
    if n == 3:
        lam = _barycentric(point, pts[0], pts[1], pts[2])
        if lam is None:
            return np.zeros((3,), dtype=float)
        return lam
    if n == 6:
        lam = _barycentric(point, pts[0], pts[1], pts[2])
        if lam is None or np.any(lam < -tol):
            return np.zeros((6,), dtype=float)
        # Reuse the tri6 basis: (xi, eta) = (L2, L3), and it rebuilds L1 = 1 - xi - eta,
        # which is exactly how _barycentric computed lam[0].
        return _tri6_shape_values(lam[1], lam[2])
    if n == 4:
        return _quad_shape_values(point, facet_nodes, coords, tol=tol)
    if n == 8:
        corner_nodes = facet_nodes[:4]
        values, xi, eta = _quad_shape_and_local(point, corner_nodes, coords, tol=tol)
        # All-zero bilinear values signal a failed inversion.
        if np.allclose(values, 0.0):
            return np.zeros((8,), dtype=float)
        return _quad8_shape_values(xi, eta)
    if n == 9:
        corner_nodes = facet_nodes[[0, 2, 8, 6]]
        values, xi, eta = _quad_shape_and_local(point, corner_nodes, coords, tol=tol)
        if np.allclose(values, 0.0):
            return np.zeros((9,), dtype=float)
        return _quad9_shape_values(xi, eta)
    raise ValueError("facet must be a triangle or quad")
961
+
962
+
963
def _gather_u_local(u_field: np.ndarray, nodes: np.ndarray, value_dim: int) -> np.ndarray:
    """
    Gather the entries of a node-interleaved field that belong to ``nodes``.

    ``u_field`` stores component ``c`` of node ``i`` at index
    ``i * value_dim + c``. The result is node-major: all components of
    ``nodes[0]``, then all components of ``nodes[1]``, and so on. For
    ``value_dim == 1`` this is simply ``u_field[nodes]``.

    ``nodes`` may be any integer sequence. It is converted to an int array
    first, so a Python list is not replicated by ``list * int`` and an empty
    list does not produce float indices.
    """
    nodes = np.asarray(nodes, dtype=int)
    if value_dim == 1:
        return u_field[nodes]
    flat = nodes.reshape(-1)
    # (n_nodes, value_dim) table of flat DOF ids, flattened row-major.
    idx = (flat[:, None] * value_dim + np.arange(value_dim)[None, :]).reshape(-1)
    return u_field[idx]
968
+
969
+
970
def _global_dof_indices(nodes: np.ndarray, value_dim: int, offset: int) -> np.ndarray:
    """
    Global DOF ids of ``nodes`` for a node-interleaved field starting at ``offset``.

    Component ``c`` of node ``i`` maps to ``offset + i * value_dim + c``. The
    result is node-major, matching the layout produced by ``_gather_u_local``.

    ``nodes`` may be any integer sequence. It is converted to an int array so
    that list input is neither replicated (``list * int``) nor rejected
    (``int + list``).
    """
    nodes = np.asarray(nodes, dtype=int)
    if value_dim == 1:
        return offset + nodes
    flat = nodes.reshape(-1)
    idx = (flat[:, None] * value_dim + np.arange(value_dim)[None, :]).reshape(-1)
    return offset + idx
975
+
976
+
977
def map_surface_facets_to_tet_elements(surface: SurfaceMesh, tet_conn: np.ndarray) -> np.ndarray:
    """
    Find, for each surface facet, a tet element that owns that face.

    Supports tet4 and tet10 connectivity. Facets are matched by their sorted
    node ids:

    - 3-node facets against element corner faces;
    - 6-node facets against full tet10 faces (3 corners + 3 mid-edge nodes).

    When several elements share a face, the lowest element index wins.

    Returns an int array of length ``surface.conn.shape[0]``; unmatched
    facets are ``-1``.

    Raises NotImplementedError for element node counts other than 4 or 10.
    """
    conn = np.asarray(tet_conn, dtype=int)
    if conn.shape[1] not in {4, 10}:
        raise NotImplementedError("Only tet4 and tet10 are supported.")

    corner_faces = ((0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3))
    # Each tet10 face: its three corners followed by its three mid-edge nodes.
    quadratic_faces = (
        (0, 1, 2, 4, 5, 6),
        (0, 1, 3, 4, 8, 7),
        (0, 2, 3, 6, 9, 7),
        (1, 2, 3, 5, 9, 8),
    )
    full_faces = quadratic_faces if conn.shape[1] == 10 else ()

    by_corners: dict[tuple[int, ...], int] = {}
    by_full: dict[tuple[int, ...], int] = {}
    for elem_id, elem in enumerate(conn.tolist()):
        for face in corner_faces:
            by_corners.setdefault(tuple(sorted(elem[i] for i in face)), elem_id)
        for face in full_faces:
            by_full.setdefault(tuple(sorted(elem[i] for i in face)), elem_id)

    owners = np.full((surface.conn.shape[0],), -1, dtype=int)
    for facet_id, facet in enumerate(np.asarray(surface.conn, dtype=int).tolist()):
        key = tuple(sorted(facet))
        # Corner keys have 3 entries and full keys have 6, so at most one
        # table can contain a given key.
        owner = by_corners.get(key, by_full.get(key))
        if owner is not None:
            owners[facet_id] = owner
    return owners
1016
+
1017
+
1018
def map_surface_facets_to_hex_elements(surface: SurfaceMesh, hex_conn: np.ndarray) -> np.ndarray:
    """
    Find, for each surface facet, a hex element that owns that face.

    Supports hex8/hex20/hex27 connectivity. Facets are matched by their sorted
    node ids:

    - 4-node facets against element corner faces (for hex27 the corners are
      the tensor-ordered nodes 0, 2, 8, 6, 18, 20, 26, 24);
    - 8-node facets against full hex20 faces (4 corners + 4 mid-edge nodes);
    - 9-node facets against full hex27 faces.

    When several elements share a face, the lowest element index wins.

    Returns an int array of length ``surface.conn.shape[0]``; unmatched
    facets are ``-1``.

    Raises NotImplementedError for other element node counts.
    """
    conn = np.asarray(hex_conn, dtype=int)
    if conn.shape[1] not in {8, 20, 27}:
        raise NotImplementedError("Only hex8/hex20/hex27 are supported.")

    lagrange_corner_faces = (
        (0, 1, 2, 3),
        (4, 5, 6, 7),
        (0, 1, 5, 4),
        (1, 2, 6, 5),
        (2, 3, 7, 6),
        (3, 0, 4, 7),
    )
    tensor_corner_faces = (
        (0, 2, 8, 6),
        (18, 20, 26, 24),
        (0, 2, 20, 18),
        (6, 8, 26, 24),
        (0, 6, 24, 18),
        (2, 8, 26, 20),
    )
    serendipity_faces = (
        (0, 1, 2, 3, 8, 9, 10, 11),
        (4, 5, 6, 7, 12, 13, 14, 15),
        (0, 1, 5, 4, 8, 17, 12, 16),
        (1, 2, 6, 5, 9, 18, 13, 17),
        (2, 3, 7, 6, 10, 19, 14, 18),
        (3, 0, 4, 7, 11, 16, 15, 19),
    )
    tensor_faces = (
        (0, 1, 2, 3, 4, 5, 6, 7, 8),
        (18, 19, 20, 21, 22, 23, 24, 25, 26),
        (0, 1, 2, 9, 10, 11, 18, 19, 20),
        (6, 7, 8, 15, 16, 17, 24, 25, 26),
        (0, 3, 6, 9, 12, 15, 18, 21, 24),
        (2, 5, 8, 11, 14, 17, 20, 23, 26),
    )
    width = conn.shape[1]
    if width == 27:
        corner_faces, full_faces = tensor_corner_faces, tensor_faces
    elif width == 20:
        corner_faces, full_faces = lagrange_corner_faces, serendipity_faces
    else:
        corner_faces, full_faces = lagrange_corner_faces, ()

    by_corners: dict[tuple[int, ...], int] = {}
    by_full: dict[tuple[int, ...], int] = {}
    for elem_id, elem in enumerate(conn.tolist()):
        for face in corner_faces:
            by_corners.setdefault(tuple(sorted(elem[i] for i in face)), elem_id)
        for face in full_faces:
            by_full.setdefault(tuple(sorted(elem[i] for i in face)), elem_id)

    owners = np.full((surface.conn.shape[0],), -1, dtype=int)
    for facet_id, facet in enumerate(np.asarray(surface.conn, dtype=int).tolist()):
        key = tuple(sorted(facet))
        # Corner keys have 4 entries and full keys 8 or 9, so the two tables
        # never contain the same key.
        owner = by_corners.get(key, by_full.get(key))
        if owner is not None:
            owners[facet_id] = owner
    return owners
1087
+
1088
+
1089
def _tet_shape_values(point: np.ndarray, elem_coords: np.ndarray, *, tol: float) -> np.ndarray:
    """
    Shape-function values of a tet4/tet10 element at ``point``.

    Barycentric coordinates are computed from the first four (corner) nodes.
    Tet10 node order: the 4 corners, then mid-edge nodes on edges
    0-1, 1-2, 0-2, 0-3, 1-3, 2-3.

    Returns zeros (length ``elem_coords.shape[0]``) when the corner system is
    singular or any barycentric coordinate is below ``-tol`` (point outside).

    Raises NotImplementedError if the point is inside but the element has
    neither 4 nor 10 nodes.
    """
    n_nodes = elem_coords.shape[0]
    # Rows x, y, z, 1: solves sum_i lam_i * [x_i, y_i, z_i, 1] = [p, 1].
    system = np.vstack([elem_coords[:4].T, np.ones((1, 4))])
    rhs = np.array([point[0], point[1], point[2], 1.0], dtype=float)
    try:
        lam = np.linalg.solve(system, rhs)
    except np.linalg.LinAlgError:
        return np.zeros((n_nodes,), dtype=float)
    if (lam < -tol).any():
        return np.zeros((n_nodes,), dtype=float)
    if n_nodes == 4:
        return lam
    if n_nodes != 10:
        raise NotImplementedError("tet shape evaluation supports tet4/tet10 only")
    vertex_vals = lam * (2.0 * lam - 1.0)
    edge_pairs = ((0, 1), (1, 2), (0, 2), (0, 3), (1, 3), (2, 3))
    edge_vals = np.array([4.0 * lam[a] * lam[b] for a, b in edge_pairs], dtype=float)
    return np.concatenate([vertex_vals, edge_vals])
1115
+
1116
+
1117
+ def _tet_gradN(elem_coords: np.ndarray, *, point: np.ndarray | None = None, tol: float) -> np.ndarray:
1118
+ corner_coords = elem_coords[:4]
1119
+ M = np.stack([corner_coords[:, 0], corner_coords[:, 1], corner_coords[:, 2], np.ones(4)], axis=1)
1120
+ try:
1121
+ invM = np.linalg.inv(M)
1122
+ except np.linalg.LinAlgError:
1123
+ return np.zeros((elem_coords.shape[0], 3), dtype=float)
1124
+ dL = invM[:3, :].T
1125
+ if elem_coords.shape[0] == 4:
1126
+ return dL
1127
+ if elem_coords.shape[0] != 10:
1128
+ raise NotImplementedError("tet grad evaluation supports tet4/tet10 only")
1129
+ if point is None:
1130
+ raise ValueError("tet10 grad evaluation requires point")
1131
+ rhs = np.array([point[0], point[1], point[2], 1.0], dtype=float)
1132
+ try:
1133
+ lam = np.linalg.solve(M.T, rhs)
1134
+ except np.linalg.LinAlgError:
1135
+ return np.zeros((10, 3), dtype=float)
1136
+ if np.any(lam < -tol):
1137
+ return np.zeros((10, 3), dtype=float)
1138
+ L1, L2, L3, L4 = lam
1139
+ dL1, dL2, dL3, dL4 = dL
1140
+ dN1 = (4.0 * L1 - 1.0) * dL1
1141
+ dN2 = (4.0 * L2 - 1.0) * dL2
1142
+ dN3 = (4.0 * L3 - 1.0) * dL3
1143
+ dN4 = (4.0 * L4 - 1.0) * dL4
1144
+ dN5 = 4.0 * (L2 * dL1 + L1 * dL2)
1145
+ dN6 = 4.0 * (L3 * dL2 + L2 * dL3)
1146
+ dN7 = 4.0 * (L3 * dL1 + L1 * dL3)
1147
+ dN8 = 4.0 * (L4 * dL1 + L1 * dL4)
1148
+ dN9 = 4.0 * (L4 * dL2 + L2 * dL4)
1149
+ dN10 = 4.0 * (L4 * dL3 + L3 * dL4)
1150
+ return np.vstack([dN1, dN2, dN3, dN4, dN5, dN6, dN7, dN8, dN9, dN10])
1151
+
1152
+
1153
def _tet_gradN_at_points(
    points: np.ndarray,
    elem_coords: np.ndarray,
    *,
    local: np.ndarray | None = None,
    tol: float,
) -> np.ndarray:
    """
    Physical shape-function gradients of a volume element at several points.

    Despite the name, hex elements are handled too; the element type is
    inferred from the node count: 4 (tet4), 10 (tet10), 8 (hex8), 20 (hex20),
    27 (hex27).

    Returns an array of shape ``(n_points, n_nodes, 3)``, or
    ``(n_points, len(local), 3)`` when ``local`` selects a subset of element
    nodes. The 3-D shape is kept for an empty point set too.

    Raises NotImplementedError for other node counts.
    """
    n_nodes = elem_coords.shape[0]
    n_points = len(points)
    if n_nodes == 4:
        # Linear tet: the gradient is constant, evaluate once and repeat.
        grad = _tet_gradN(elem_coords, tol=tol)
        grad_q = np.repeat(grad[None, :, :], n_points, axis=0)
    elif n_nodes == 10:
        grad_q = np.array([_tet_gradN(elem_coords, point=pt, tol=tol) for pt in points], dtype=float)
    elif n_nodes == 8:
        grad_q = np.array([_hex8_gradN(pt, elem_coords, tol=tol) for pt in points], dtype=float)
    elif n_nodes == 20:
        grad_q = np.array([_hex20_gradN(pt, elem_coords, tol=tol) for pt in points], dtype=float)
    elif n_nodes == 27:
        grad_q = np.array([_hex27_gradN(pt, elem_coords, tol=tol) for pt in points], dtype=float)
    else:
        raise NotImplementedError("volume grad evaluation supports tet4/tet10/hex8/hex20/hex27 only")
    # np.array([]) collapses to shape (0,) for an empty point set; restore the
    # (n_points, n_nodes, 3) layout so the `local` selection below is valid.
    grad_q = grad_q.reshape(n_points, n_nodes, 3)
    if local is not None:
        grad_q = grad_q[:, local, :]
    return grad_q
1177
+
1178
+
1179
def _hex8_shape_and_local(
    point: np.ndarray,
    elem_coords: np.ndarray,
    *,
    tol: float,
) -> tuple[np.ndarray, float, float, float]:
    """
    Locate ``point`` in a trilinear hex8 element and return its shape values.

    Newton's method is applied to ``x(xi, eta, zeta) - point = 0``. It starts
    at the element centre, runs at most 12 iterations, and stops early once
    the residual norm drops below ``tol``.

    Returns ``(N, xi, eta, zeta)``:

    - singular Jacobian or non-finite iterate: ``N`` is zeros and the local
      coordinates are ``(0.0, 0.0, 0.0)``;
    - final iterate outside ``[-1 - tol, 1 + tol]^3``: ``N`` is zeros and the
      final local coordinates are returned;
    - otherwise ``N`` holds the 8 trilinear shape values at the final iterate.
      This includes runs that hit the iteration cap without meeting ``tol``.

    When ``_proj_diag_enabled()`` is true, the attempt is counted via
    ``_proj_diag_attempt`` and each non-clean outcome is reported via
    ``_proj_diag_log``.
    """
    if _proj_diag_enabled():
        _proj_diag_attempt()
    # Corner k of the reference cube sits at (sx[k], sy[k], sz[k]).
    sx = np.array([-1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0])
    sy = np.array([-1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0])
    sz = np.array([-1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0])

    xi = eta = zeta = 0.0
    res_norm = 0.0
    detJ = None
    iters = 0

    def report(code: str, delta_norm, in_ref_domain: bool) -> None:
        # Reads the current Newton state from the enclosing scope.
        _proj_diag_log(
            code,
            iters=iters,
            res_norm=res_norm,
            delta_norm=delta_norm,
            detJ=detJ,
            point=point,
            local=np.array([xi, eta, zeta], dtype=float),
            in_ref_domain=in_ref_domain,
        )

    while iters < 12:
        iters += 1
        fx = 1.0 + xi * sx
        fy = 1.0 + eta * sy
        fz = 1.0 + zeta * sz
        residual = (0.125 * fx * fy * fz) @ elem_coords - point
        res_norm = float(np.linalg.norm(residual))
        if res_norm < tol:
            break
        d_xi = 0.125 * sx * fy * fz
        d_eta = 0.125 * sy * fx * fz
        d_zeta = 0.125 * sz * fx * fy
        # Column j = d x / d (xi, eta, zeta)_j.
        jacobian = np.stack([d_xi @ elem_coords, d_eta @ elem_coords, d_zeta @ elem_coords], axis=1)
        detJ = float(np.linalg.det(jacobian))
        try:
            step = np.linalg.solve(jacobian, residual)
        except np.linalg.LinAlgError:
            if _proj_diag_enabled():
                report("SINGULAR_H", None, False)
            return np.zeros((8,), dtype=float), 0.0, 0.0, 0.0
        step_norm = float(np.linalg.norm(step))
        xi -= float(step[0])
        eta -= float(step[1])
        zeta -= float(step[2])
        if not (np.isfinite(xi) and np.isfinite(eta) and np.isfinite(zeta)):
            if _proj_diag_enabled():
                report("NAN_INF", step_norm, False)
            return np.zeros((8,), dtype=float), 0.0, 0.0, 0.0

    if max(abs(xi), abs(eta), abs(zeta)) > 1.0 + tol:
        if _proj_diag_enabled():
            report("OUTSIDE_DOMAIN", None, False)
        return np.zeros((8,), dtype=float), xi, eta, zeta
    if _proj_diag_enabled() and res_norm > tol:
        report("NEWTON_NO_CONVERGE", None, True)
    shape_vals = 0.125 * (1.0 + xi * sx) * (1.0 + eta * sy) * (1.0 + zeta * sz)
    return shape_vals, xi, eta, zeta
1284
+
1285
+
1286
def _hex8_shape_values(point: np.ndarray, elem_coords: np.ndarray, *, tol: float) -> np.ndarray:
    """Return the 8 trilinear shape values at ``point`` (all zeros if it cannot be located)."""
    return _hex8_shape_and_local(point, elem_coords, tol=tol)[0]
1289
+
1290
+
1291
def _hex8_gradN(point: np.ndarray, elem_coords: np.ndarray, *, tol: float) -> np.ndarray:
    """
    Physical gradients of the 8 trilinear shape functions at ``point``.

    Returns an ``(8, 3)`` array whose row ``a`` is ``grad_x N_a``. It is all
    zeros when the point cannot be located or the Jacobian is singular.
    """
    shape_vals, xi, eta, zeta = _hex8_shape_and_local(point, elem_coords, tol=tol)
    if np.allclose(shape_vals, 0.0):
        return np.zeros((8, 3), dtype=float)
    sx = np.array([-1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0])
    sy = np.array([-1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0])
    sz = np.array([-1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0])
    fx = 1.0 + xi * sx
    fy = 1.0 + eta * sy
    fz = 1.0 + zeta * sz
    d_xi = 0.125 * sx * fy * fz
    d_eta = 0.125 * sy * fx * fz
    d_zeta = 0.125 * sz * fx * fy
    # jacobian[i, j] = d x_i / d xi_j, so grad_x N = dN_ref @ inv(jacobian).
    jacobian = np.stack([d_xi @ elem_coords, d_eta @ elem_coords, d_zeta @ elem_coords], axis=1)
    try:
        jacobian_inv = np.linalg.inv(jacobian)
    except np.linalg.LinAlgError:
        return np.zeros((8, 3), dtype=float)
    return np.column_stack([d_xi, d_eta, d_zeta]) @ jacobian_inv
1325
+
1326
+
1327
def _hex20_shape_ref(xi: float, eta: float, zeta: float) -> np.ndarray:
    """
    Serendipity hex20 shape functions at the reference point ``(xi, eta, zeta)``.

    Node order:

    - 8 corners with the hex8 sign pattern;
    - 12 mid-edge nodes: bottom edges 0-1, 1-2, 2-3, 3-0; top edges 4-5,
      5-6, 6-7, 7-4; vertical edges 0-4, 1-5, 2-6, 3-7.
    """
    sx = np.array([-1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0])
    sy = np.array([-1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0])
    sz = np.array([-1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0])
    term = xi * sx + eta * sy + zeta * sz - 2.0
    corner_vals = 0.125 * (1.0 + sx * xi) * (1.0 + sy * eta) * (1.0 + sz * zeta) * term

    t = (xi, eta, zeta)
    # (axis of the quadratic bubble, sign on the 1st other axis, sign on the 2nd other axis)
    edge_table = (
        (0, -1, -1), (1, 1, -1), (0, 1, -1), (1, -1, -1),
        (0, -1, 1), (1, 1, 1), (0, 1, 1), (1, -1, 1),
        (2, -1, -1), (2, 1, -1), (2, 1, 1), (2, -1, 1),
    )
    edge_vals = []
    for axis, s_first, s_second in edge_table:
        u, v = (t[k] for k in range(3) if k != axis)
        ta = t[axis]
        edge_vals.append(0.25 * (1.0 - ta * ta) * (1.0 + s_first * u) * (1.0 + s_second * v))
    return np.concatenate([corner_vals, np.array(edge_vals, dtype=float)], axis=0)
1370
+
1371
+
1372
def _hex20_grad_ref(xi: float, eta: float, zeta: float) -> np.ndarray:
    """
    Reference derivatives of the hex20 serendipity shape functions.

    Returns a ``(20, 3)`` array whose row ``a`` is
    ``(dN_a/dxi, dN_a/deta, dN_a/dzeta)``. The node order is the same as in
    ``_hex20_shape_ref``: 8 corners, then 12 mid-edge nodes.
    """
    sx = np.array([-1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0])
    sy = np.array([-1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0])
    sz = np.array([-1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0])
    term = xi * sx + eta * sy + zeta * sz - 2.0
    fx = 1.0 + sx * xi
    fy = 1.0 + sy * eta
    fz = 1.0 + sz * zeta
    corner_grads = np.stack(
        [
            (sx / 8.0) * fy * fz * (term + fx),
            (sy / 8.0) * fx * fz * (term + fy),
            (sz / 8.0) * fx * fy * (term + fz),
        ],
        axis=1,
    )

    t = (xi, eta, zeta)
    # (axis of the quadratic bubble, sign on the 1st other axis, sign on the 2nd other axis)
    edge_table = (
        (0, -1, -1), (1, 1, -1), (0, 1, -1), (1, -1, -1),
        (0, -1, 1), (1, 1, 1), (0, 1, 1), (1, -1, 1),
        (2, -1, -1), (2, 1, -1), (2, 1, 1), (2, -1, 1),
    )
    edge_rows = []
    for axis, s_b, s_c in edge_table:
        b, c = (k for k in range(3) if k != axis)
        ta, tb, tc = t[axis], t[b], t[c]
        row = [0.0, 0.0, 0.0]
        row[axis] = -0.5 * ta * (1.0 + s_b * tb) * (1.0 + s_c * tc)
        row[b] = 0.25 * (1.0 - ta * ta) * s_b * (1.0 + s_c * tc)
        row[c] = 0.25 * (1.0 - ta * ta) * (1.0 + s_b * tb) * s_c
        edge_rows.append(row)
    return np.concatenate([corner_grads, np.array(edge_rows, dtype=float)], axis=0)
1429
+
1430
+
1431
def _hex27_shape_ref(xi: float, eta: float, zeta: float) -> np.ndarray:
    """
    Triquadratic (hex27) Lagrange shape functions at ``(xi, eta, zeta)``.

    Nodes are tensor-ordered with ``xi`` fastest: index ``i + 3*j + 9*k``,
    where ``i, j, k`` in {0, 1, 2} correspond to reference coordinates
    -1, 0, +1.
    """

    def lagrange_1d(t):
        return np.array([0.5 * t * (t - 1.0), 1.0 - t * t, 0.5 * t * (t + 1.0)], dtype=float)

    lx, ly, lz = lagrange_1d(xi), lagrange_1d(eta), lagrange_1d(zeta)
    # tensor[k, j, i] = (lx[i] * ly[j]) * lz[k]; C-order flattening gives xi fastest.
    tensor = (lx[None, None, :] * ly[None, :, None]) * lz[:, None, None]
    return tensor.reshape(27)
1450
+
1451
+
1452
def _hex27_grad_ref(xi: float, eta: float, zeta: float) -> np.ndarray:
    """
    Reference derivatives of the hex27 shape functions.

    Returns a ``(27, 3)`` array whose row ``a`` is
    ``(dN_a/dxi, dN_a/deta, dN_a/dzeta)``. Nodes are tensor-ordered with
    ``xi`` fastest (``i + 3*j + 9*k``), as in ``_hex27_shape_ref``.
    """

    def lagrange_1d(t):
        return np.array([0.5 * t * (t - 1.0), 1.0 - t * t, 0.5 * t * (t + 1.0)], dtype=float)

    def lagrange_1d_deriv(t):
        return np.array([t - 0.5, -2.0 * t, t + 0.5], dtype=float)

    lx, ly, lz = lagrange_1d(xi), lagrange_1d(eta), lagrange_1d(zeta)
    dlx, dly, dlz = lagrange_1d_deriv(xi), lagrange_1d_deriv(eta), lagrange_1d_deriv(zeta)
    # Each tensor is indexed [k, j, i]; flattening in C order gives node i + 3j + 9k.
    grad_xi = (dlx[None, None, :] * ly[None, :, None]) * lz[:, None, None]
    grad_eta = (lx[None, None, :] * dly[None, :, None]) * lz[:, None, None]
    grad_zeta = (lx[None, None, :] * ly[None, :, None]) * dlz[:, None, None]
    return np.stack([grad_xi.reshape(27), grad_eta.reshape(27), grad_zeta.reshape(27)], axis=1)
1486
+
1487
+
1488
def _hex20_shape_and_local(
    point: np.ndarray,
    elem_coords: np.ndarray,
    *,
    tol: float,
) -> tuple[np.ndarray, float, float, float]:
    """
    Hex20 shape values at ``point`` together with its reference coordinates.

    The reference coordinates come from inverting the trilinear map of the
    first 8 (corner) nodes with ``_hex8_shape_and_local``. The 20 serendipity
    values are then evaluated at those coordinates.

    Returns ``(N, xi, eta, zeta)``. ``N`` is all zeros in two cases:

    - the corner inversion fails; the local coordinates are then ``0.0``;
    - the located point lies outside ``[-1 - tol, 1 + tol]^3``; the local
      coordinates are then returned as found.
    """
    corner_vals, xi, eta, zeta = _hex8_shape_and_local(point, elem_coords[:8], tol=tol)
    if np.allclose(corner_vals, 0.0):
        return np.zeros((20,), dtype=float), 0.0, 0.0, 0.0
    outside = max(abs(xi), abs(eta), abs(zeta)) > 1.0 + tol
    values = np.zeros((20,), dtype=float) if outside else _hex20_shape_ref(xi, eta, zeta)
    return values, xi, eta, zeta
1501
+
1502
+
1503
def _hex20_shape_values(point: np.ndarray, elem_coords: np.ndarray, *, tol: float) -> np.ndarray:
    """Return the 20 hex20 shape values at ``point`` (all zeros if it cannot be located)."""
    return _hex20_shape_and_local(point, elem_coords, tol=tol)[0]
1506
+
1507
+
1508
def _hex20_gradN(point: np.ndarray, elem_coords: np.ndarray, *, tol: float) -> np.ndarray:
    """
    Physical gradients of the 20 hex20 shape functions at ``point``.

    Returns a ``(20, 3)`` array whose row ``a`` is
    ``(dN_a/dx, dN_a/dy, dN_a/dz)``. It is all zeros when the point cannot be
    located in the element or the Jacobian is singular.
    """
    n, xi, eta, zeta = _hex20_shape_and_local(point, elem_coords, tol=tol)
    if np.allclose(n, 0.0):
        return np.zeros((20, 3), dtype=float)
    dN = _hex20_grad_ref(xi, eta, zeta)  # (20, 3) reference derivatives
    # J[j, i] = d x_i / d xi_j (row j = reference direction j).
    J = dN.T @ elem_coords
    try:
        invJ = np.linalg.inv(J)
    except np.linalg.LinAlgError:
        return np.zeros((20, 3), dtype=float)
    # Chain rule: dN_a/dx_k = sum_j dN_a/dxi_j * dxi_j/dx_k. With this row
    # layout, dxi_j/dx_k = (J^{-1})[k, j], i.e. grad_x N = dN @ J^{-T}.
    # Multiplying by J^{-1} instead is only correct when J is symmetric.
    return dN @ invJ.T
1519
+
1520
+
1521
def _hex27_shape_and_local(
    point: np.ndarray,
    elem_coords: np.ndarray,
    *,
    tol: float,
) -> tuple[np.ndarray, float, float, float]:
    """
    Hex27 shape values at ``point`` together with its reference coordinates.

    The reference coordinates come from inverting the trilinear map through
    the eight tensor-ordered corner nodes (0, 2, 8, 6, 18, 20, 26, 24) with
    ``_hex8_shape_and_local``. The 27 triquadratic values are then evaluated
    at those coordinates.

    Returns ``(N, xi, eta, zeta)``. ``N`` is all zeros in two cases:

    - the corner inversion fails; the local coordinates are then ``0.0``;
    - the located point lies outside ``[-1 - tol, 1 + tol]^3``; the local
      coordinates are then returned as found.
    """
    # hex8 corner order expressed as hex27 tensor indices.
    hex8_order = np.array([0, 2, 8, 6, 18, 20, 26, 24], dtype=int)
    trilinear_vals, xi, eta, zeta = _hex8_shape_and_local(point, elem_coords[hex8_order], tol=tol)
    if np.allclose(trilinear_vals, 0.0):
        return np.zeros((27,), dtype=float), 0.0, 0.0, 0.0
    outside = max(abs(xi), abs(eta), abs(zeta)) > 1.0 + tol
    values = np.zeros((27,), dtype=float) if outside else _hex27_shape_ref(xi, eta, zeta)
    return values, xi, eta, zeta
1536
+
1537
+
1538
def _hex27_shape_values(point: np.ndarray, elem_coords: np.ndarray, *, tol: float) -> np.ndarray:
    """Return the 27 hex27 shape values at ``point`` (all zeros if it cannot be located)."""
    return _hex27_shape_and_local(point, elem_coords, tol=tol)[0]
1541
+
1542
+
1543
def _hex27_gradN(point: np.ndarray, elem_coords: np.ndarray, *, tol: float) -> np.ndarray:
    """
    Physical gradients of the 27 triquadratic shape functions at ``point``.

    Returns a ``(27, 3)`` array whose row ``a`` is
    ``(dN_a/dx, dN_a/dy, dN_a/dz)``. It is all zeros when the point cannot be
    located in the element or the Jacobian is singular.
    """
    n, xi, eta, zeta = _hex27_shape_and_local(point, elem_coords, tol=tol)
    if np.allclose(n, 0.0):
        return np.zeros((27, 3), dtype=float)
    dN = _hex27_grad_ref(xi, eta, zeta)  # (27, 3) reference derivatives
    # J[j, i] = d x_i / d xi_j (row j = reference direction j).
    J = dN.T @ elem_coords
    try:
        invJ = np.linalg.inv(J)
    except np.linalg.LinAlgError:
        return np.zeros((27, 3), dtype=float)
    # Chain rule: dN_a/dx_k = sum_j dN_a/dxi_j * dxi_j/dx_k. With this row
    # layout, dxi_j/dx_k = (J^{-1})[k, j], i.e. grad_x N = dN @ J^{-T}.
    # Multiplying by J^{-1} instead is only correct when J is symmetric.
    return dN @ invJ.T
1554
+
1555
+
1556
def _volume_shape_values_at_points(
    points: np.ndarray,
    elem_coords: np.ndarray,
    *,
    tol: float,
) -> np.ndarray:
    """
    Evaluate the shape values of one volume element at each of ``points``.

    The element type is inferred from the node count: 4/10 (tet), 8 (hex8),
    20 (hex20), 27 (hex27). Returns one row of shape values per point.

    Raises NotImplementedError for other node counts.
    """
    n_nodes = elem_coords.shape[0]
    if n_nodes == 4 or n_nodes == 10:
        evaluate = _tet_shape_values
    elif n_nodes == 8:
        evaluate = _hex8_shape_values
    elif n_nodes == 20:
        evaluate = _hex20_shape_values
    elif n_nodes == 27:
        evaluate = _hex27_shape_values
    else:
        raise NotImplementedError("volume shape evaluation supports tet4/tet10/hex8/hex20/hex27 only")
    return np.array([evaluate(pt, elem_coords, tol=tol) for pt in points], dtype=float)
1572
+
1573
+
1574
def _local_indices(elem_nodes: np.ndarray, facet_nodes: np.ndarray) -> np.ndarray:
    """
    Positions of ``facet_nodes`` within ``elem_nodes``.

    If a node id appears more than once in ``elem_nodes``, its last position
    is used.

    Raises ValueError if any facet node is not in the element connectivity.
    """
    position_of = dict(zip((int(node) for node in elem_nodes), range(len(elem_nodes))))
    try:
        positions = [position_of[int(node)] for node in facet_nodes]
    except KeyError as missing:
        raise ValueError("facet nodes are not part of the element connectivity") from missing
    return np.array(positions, dtype=int)
1580
+
1581
+
1582
def _surface_gradN(
    point: np.ndarray,
    facet_nodes: np.ndarray,
    coords: np.ndarray,
    *,
    tol: float,
) -> np.ndarray:
    """
    Surface (tangential) gradients of the facet shape functions at ``point``.

    Returns an ``(n, 3)`` array; row ``a`` is the gradient of nodal shape
    function ``a`` in the facet's tangent plane. It is computed as
    ``J (J^T J)^{-1} dN_ref^T``, where ``J`` is the ``(3, 2)`` matrix of
    tangent vectors and ``dN_ref`` holds the reference derivatives.

    Geometry (tangent vectors):

    - tri3/tri6: the linear map through the three corner nodes;
    - quad4/quad8: the bilinear map through nodes 0-3;
    - quad9: the bilinear map through the tensor-ordered corners 0, 2, 8, 6.

    Tri6 gradients are assembled from the corner barycentric gradients.

    Zeros are returned in these cases:

    - the point cannot be located (tri6 barycentrics ``None``, or all-zero
      quad8/quad9 corner values);
    - ``|det(J^T J)| < tol``.

    When ``_DEBUG_SURFACE_GRADN`` is set, quad4/quad8/quad9 evaluations print
    a diagnostic line, at most ``_DEBUG_SURFACE_GRADN_MAX`` in total.

    Raises ValueError for unsupported node counts.
    """
    global _DEBUG_SURFACE_GRADN_COUNT
    # Int array so that the quad9 corner selection (fancy indexing) also
    # works for list/tuple input.
    facet_nodes = np.asarray(facet_nodes, dtype=int)
    pts = coords[facet_nodes]
    n = len(facet_nodes)
    debug = bool(_DEBUG_SURFACE_GRADN) and _DEBUG_SURFACE_GRADN_COUNT < _DEBUG_SURFACE_GRADN_MAX

    def bilinear_derivs(s: float, t: float) -> tuple[np.ndarray, np.ndarray]:
        # d/dxi and d/deta of the 4 bilinear corner functions at (s, t).
        d_s = np.array(
            [-0.25 * (1.0 - t), 0.25 * (1.0 - t), 0.25 * (1.0 + t), -0.25 * (1.0 + t)],
            dtype=float,
        )
        d_t = np.array(
            [-0.25 * (1.0 - s), -0.25 * (1.0 + s), 0.25 * (1.0 + s), 0.25 * (1.0 - s)],
            dtype=float,
        )
        return d_s, d_t

    def debug_dump(tag: str, shape_vals: np.ndarray, sum_dxi: float, sum_deta: float) -> None:
        # Reads xi, eta, dX_dxi, dX_deta of the calling branch.
        x_phys = shape_vals @ pts
        j_surf = float(np.linalg.norm(np.cross(dX_dxi, dX_deta)))
        print(
            tag,
            f"pt={np.array2string(point, precision=6)}",
            f"xi={xi:.6f}",
            f"eta={eta:.6f}",
            f"N_sum={float(shape_vals.sum()):.6e}",
            f"dN_dxi_sum={sum_dxi:.6e}",
            f"dN_deta_sum={sum_deta:.6e}",
            f"x_phys={np.array2string(x_phys, precision=6)}",
            f"t1={np.array2string(dX_dxi, precision=6)}",
            f"t2={np.array2string(dX_deta, precision=6)}",
            f"J_surf={j_surf:.6e}",
        )

    lam = None
    if n == 3:
        dN = np.array([[-1.0, -1.0], [1.0, 0.0], [0.0, 1.0]], dtype=float)
        dX_dxi = dN[:, 0] @ pts
        dX_deta = dN[:, 1] @ pts
    elif n == 4:
        values, xi, eta = _quad_shape_and_local(point, facet_nodes, coords, tol=tol)
        dN_dxi, dN_deta = bilinear_derivs(xi, eta)
        dX_dxi = dN_dxi @ pts
        dX_deta = dN_deta @ pts
        dN = np.stack([dN_dxi, dN_deta], axis=1)
        if debug:
            debug_dump("[fluxfem][surface_gradN][quad4]", values, float(dN_dxi.sum()), float(dN_deta.sum()))
            _DEBUG_SURFACE_GRADN_COUNT += 1
    elif n == 6:
        lam = _barycentric(point, pts[0], pts[1], pts[2])
        if lam is None:
            return np.zeros((6, 3), dtype=float)
        # Linear corner derivatives; the quadratic terms are assembled from
        # the barycentric gradients after the projection below.
        dN = np.array([[-1.0, -1.0], [1.0, 0.0], [0.0, 1.0]], dtype=float)
        dX_dxi = dN[:, 0] @ pts[:3]
        dX_deta = dN[:, 1] @ pts[:3]
    elif n == 8:
        values, xi, eta = _quad_shape_and_local(point, facet_nodes[:4], coords, tol=tol)
        if np.allclose(values, 0.0):
            return np.zeros((8, 3), dtype=float)
        d_xi_corner, d_eta_corner = bilinear_derivs(xi, eta)
        dX_dxi = d_xi_corner @ pts[:4]
        dX_deta = d_eta_corner @ pts[:4]
        # Serendipity quad8: corners (-1,-1),(1,-1),(1,1),(-1,1), then mid-edge
        # nodes (0,-1),(1,0),(0,1),(-1,0).
        dN = np.array(
            [
                [-0.25 * (1.0 - eta) * ((1.0 - xi) - (1.0 + xi + eta)),
                 -0.25 * (1.0 - xi) * ((1.0 - eta) - (1.0 + xi + eta))],
                [0.25 * (1.0 - eta) * ((1.0 + xi) - (1.0 - xi + eta)),
                 -0.25 * (1.0 + xi) * ((1.0 - eta) - (1.0 - xi + eta))],
                [0.25 * (1.0 + eta) * ((1.0 + xi) - (1.0 - xi - eta)),
                 0.25 * (1.0 + xi) * ((1.0 + eta) - (1.0 - xi - eta))],
                [-0.25 * (1.0 + eta) * ((1.0 - xi) - (1.0 + xi - eta)),
                 0.25 * (1.0 - xi) * ((1.0 + eta) - (1.0 + xi - eta))],
                [-xi * (1.0 - eta), -0.5 * (1.0 - xi * xi)],
                [0.5 * (1.0 - eta * eta), -(1.0 + xi) * eta],
                [-xi * (1.0 + eta), 0.5 * (1.0 - xi * xi)],
                [-0.5 * (1.0 - eta * eta), -(1.0 - xi) * eta],
            ],
            dtype=float,
        )
        if debug:
            debug_dump(
                "[fluxfem][surface_gradN][quad8]",
                _quad8_shape_values(xi, eta),
                float(dN[:, 0].sum()),
                float(dN[:, 1].sum()),
            )
            _DEBUG_SURFACE_GRADN_COUNT += 1
    elif n == 9:
        # Tensor-product ordering (xi fastest): the geometric corners are 0, 2, 8, 6.
        corner_ids = [0, 2, 8, 6]
        values, xi, eta = _quad_shape_and_local(point, facet_nodes[corner_ids], coords, tol=tol)
        if np.allclose(values, 0.0):
            return np.zeros((9, 3), dtype=float)
        d_xi_corner, d_eta_corner = bilinear_derivs(xi, eta)
        # The tangents must come from the same corners that defined (xi, eta);
        # pts[:4] would mix corner and mid-edge nodes (0, 1, 2, 3).
        corner_pts = pts[corner_ids]
        dX_dxi = d_xi_corner @ corner_pts
        dX_deta = d_eta_corner @ corner_pts
        lx = [0.5 * xi * (xi - 1.0), 1.0 - xi * xi, 0.5 * xi * (xi + 1.0)]
        ly = [0.5 * eta * (eta - 1.0), 1.0 - eta * eta, 0.5 * eta * (eta + 1.0)]
        dlx = [xi - 0.5, -2.0 * xi, xi + 0.5]
        dly = [eta - 0.5, -2.0 * eta, eta + 0.5]
        dN = np.array(
            [[dlx[i] * ly[j], lx[i] * dly[j]] for j in range(3) for i in range(3)],
            dtype=float,
        )
        if debug:
            debug_dump(
                "[fluxfem][surface_gradN][quad9]",
                _quad9_shape_values(xi, eta),
                float(dN[:, 0].sum()),
                float(dN[:, 1].sum()),
            )
            _DEBUG_SURFACE_GRADN_COUNT += 1
    else:
        raise ValueError("facet must be a triangle or quad")

    J = np.stack([dX_dxi, dX_deta], axis=1)  # (3, 2) tangent vectors
    JTJ = J.T @ J
    if abs(np.linalg.det(JTJ)) < tol:
        return np.zeros((n, 3), dtype=float)
    M = J @ np.linalg.inv(JTJ)  # (3, 2)
    gradN = (M @ dN.T).T  # (rows of dN, 3)
    if n == 6:
        # gradN currently holds grad(L1), grad(L2), grad(L3); build tri6 gradients.
        L1, L2, L3 = lam
        g1, g2, g3 = gradN[:3]
        gradN = np.array(
            [
                (4.0 * L1 - 1.0) * g1,
                (4.0 * L2 - 1.0) * g2,
                (4.0 * L3 - 1.0) * g3,
                4.0 * (L1 * g2 + L2 * g1),
                4.0 * (L2 * g3 + L3 * g2),
                4.0 * (L1 * g3 + L3 * g1),
            ],
            dtype=float,
        )
    return gradN
1836
+
1837
+
1838
def _iter_supermesh_tris(coords: np.ndarray, conn: np.ndarray):
    """
    Yield ``(tri, a, b, c)`` for every row ``tri`` of ``conn``.

    ``a``, ``b`` and ``c`` are the coordinate rows of the row's three nodes.
    A row with a different node count raises ValueError on unpacking.
    """
    for row in conn:
        vertex_a, vertex_b, vertex_c = coords[row]
        yield (row, vertex_a, vertex_b, vertex_c)
1842
+
1843
+
1844
def _projection_surface_batches(
    source_facets_a: Iterable[int],
    source_facets_b: Iterable[int],
    surface_a: SurfaceMesh,
    surface_b: SurfaceMesh,
    *,
    elem_conn_a: np.ndarray | None,
    elem_conn_b: np.ndarray | None,
    facet_to_elem_a: np.ndarray | None,
    facet_to_elem_b: np.ndarray | None,
    quad_order: int,
    grad_source: str,
    dof_source: str,
    normal_source: str,
    normal_sign: float,
    tol: float,
):
    """
    Build per-facet-pair quadrature batches for projection-based mortar coupling.

    Quadrature points are placed on each facet of ``surface_a`` (the master side).
    They are projected onto the paired facet of ``surface_b``. Volume shape values
    and gradients of the owning elements are then evaluated at the master point
    (side a) and at the projected point (side b).

    Only applicable when ``dof_source == grad_source == "volume"``, both surfaces
    use the same quadratic face type (6-node tri6 or 9-node quad9), and all four
    element mappings are given. Otherwise ``(None, False)`` is returned.

    Parameters
    ----------
    source_facets_a, source_facets_b
        Paired facet indices; duplicate pairs are processed once, in first-seen order.
    normal_source
        "b"/"slave" uses the slave facet normal, "avg" the normalized sum of master
        and slave normals (master if the sum vanishes), anything else the master.
    normal_sign
        Multiplies every stored normal.
    quad_order
        Face quadrature order; ``<= 0`` selects 2 (quad9) or 1 (tri6).

    Returns
    -------
    (batches, fallback)
        ``batches`` is ``None`` when projection does not apply or a facet has no
        owning element (then ``fallback`` is True). Otherwise it is a list of dicts
        with keys ``x_q, w, detJ, Na, Nb, gradNa, gradNb, nodes_a, nodes_b, normal``.
        ``w`` already includes the surface Jacobian, so ``detJ`` is all ones.
        ``fallback`` is True when any quadrature point failed (degenerate
        master/slave Jacobian or failed projection). The assemblers in this module
        then fall back to supermesh integration.
    """
    if dof_source != "volume" or grad_source != "volume":
        return None, False

    facets_a = np.asarray(surface_a.conn, dtype=int)
    facets_b = np.asarray(surface_b.conn, dtype=int)
    coords_a = np.asarray(surface_a.coords, dtype=float)
    coords_b = np.asarray(surface_b.coords, dtype=float)

    # Empty (or otherwise non-2D) connectivity has no face type; reading
    # shape[1] on a 1-D array would raise IndexError.
    if facets_a.ndim != 2 or facets_b.ndim != 2:
        return None, False
    if facets_a.shape[1] != facets_b.shape[1] or facets_a.shape[1] not in {6, 9}:
        return None, False
    if elem_conn_a is None or elem_conn_b is None or facet_to_elem_a is None or facet_to_elem_b is None:
        return None, False

    diag = bool(_DEBUG_PROJECTION_DIAG)
    diag_max = _DEBUG_PROJECTION_DIAG_MAX if diag else 0
    total_points = 0
    fail_points = 0
    fail_by_code: dict[str, int] = {}
    fail_samples: list[dict] = []

    def _record_failure(code: str, info: dict | None, *, face_type: str, fa: int, fb: int, elem_id_a: int, elem_id_b: int, xm):
        # Count every failure; keep at most `diag_max` detailed samples.
        nonlocal fail_points
        fail_points += 1
        fail_by_code[code] = fail_by_code.get(code, 0) + 1
        if not diag or len(fail_samples) >= diag_max:
            return
        sample = {
            "code": code,
            "face_type": face_type,
            "fa": int(fa),
            "fb": int(fb),
            "elem_a": int(elem_id_a),
            "elem_b": int(elem_id_b),
            "xm": None if xm is None else np.array(xm, dtype=float),
        }
        if info:
            sample.update(info)
        fail_samples.append(sample)

    def _slave_unit_normal(Js, *, face_type: str, fa: int, fb: int, elem_id_a: int, elem_id_b: int, xm):
        # Unit normal of the slave facet from its tangent columns, or None
        # (with a DEGENERATE_SLAVE failure recorded) when it degenerates.
        n_raw_b = np.cross(Js[:, 0], Js[:, 1])
        n_norm_b = float(np.linalg.norm(n_raw_b))
        if n_norm_b <= tol:
            _record_failure(
                "DEGENERATE_SLAVE",
                None,
                face_type=face_type,
                fa=fa,
                fb=fb,
                elem_id_a=elem_id_a,
                elem_id_b=elem_id_b,
                xm=xm,
            )
            return None
        return n_raw_b / n_norm_b

    # Deduplicate pairs while preserving input order, so batch order (and the
    # order of diagnostic samples) is reproducible and follows the supermesh.
    pairs = list(
        dict.fromkeys((int(fa), int(fb)) for fa, fb in zip(source_facets_a, source_facets_b))
    )
    is_quad9 = facets_a.shape[1] == 9
    if is_quad9:
        quad_pts, quad_w = _quad_quadrature(quad_order if quad_order > 0 else 2)
        face_type = "quad9"
    else:
        quad_pts, quad_w = _tri_quadrature(quad_order if quad_order > 0 else 1)
        face_type = "tri6"
    batches = []
    fallback = False

    for fa, fb in pairs:
        facet_a = facets_a[fa]
        facet_b = facets_b[fb]
        pts_a = coords_a[facet_a]
        pts_b = coords_b[facet_b]

        elem_id_a = int(facet_to_elem_a[fa])
        elem_id_b = int(facet_to_elem_b[fb])
        if elem_id_a < 0 or elem_id_b < 0:
            return None, True
        elem_nodes_a = np.asarray(elem_conn_a[elem_id_a], dtype=int)
        elem_nodes_b = np.asarray(elem_conn_b[elem_id_b], dtype=int)
        elem_coords_a = coords_a[elem_nodes_a]
        elem_coords_b = coords_b[elem_nodes_b]

        x_m_list = []
        x_s_list = []
        detJ_list = []
        normal_list = []
        for (xi, eta), w in zip(quad_pts, quad_w):
            if is_quad9:
                x_m, Jm = _quad9_map_and_jacobian(pts_a, xi, eta)
                xi_s, eta_s, ok, x_s, Js, info = _project_point_to_quad9(x_m, pts_b, tol=tol)
            else:
                x_m, Jm = _tri6_map_and_jacobian(pts_a, xi, eta)
                xi_s, eta_s, ok, x_s, Js, info = _project_point_to_tri6(x_m, pts_b, tol=tol)
            total_points += 1
            n_raw = np.cross(Jm[:, 0], Jm[:, 1])
            j_surf = float(np.linalg.norm(n_raw))
            if j_surf <= tol:
                fallback = True
                _record_failure(
                    "DEGENERATE_MASTER",
                    None,
                    face_type=face_type,
                    fa=fa,
                    fb=fb,
                    elem_id_a=elem_id_a,
                    elem_id_b=elem_id_b,
                    xm=x_m,
                )
                continue
            if not ok:
                fallback = True
                _record_failure(
                    info.get("status", "PROJECTION_FAIL"),
                    info,
                    face_type=face_type,
                    fa=fa,
                    fb=fb,
                    elem_id_a=elem_id_a,
                    elem_id_b=elem_id_b,
                    xm=x_m,
                )
                continue
            n_m = n_raw / j_surf
            n_use = n_m
            if normal_source in {"b", "slave", "avg"}:
                n_b = _slave_unit_normal(
                    Js,
                    face_type=face_type,
                    fa=fa,
                    fb=fb,
                    elem_id_a=elem_id_a,
                    elem_id_b=elem_id_b,
                    xm=x_m,
                )
                if n_b is None:
                    fallback = True
                    continue
                if normal_source == "avg":
                    avg = n_m + n_b
                    avg_norm = float(np.linalg.norm(avg))
                    n_use = avg / avg_norm if avg_norm > tol else n_m
                else:
                    n_use = n_b
            x_m_list.append(x_m)
            x_s_list.append(x_s)
            # Surface measure is folded into the weight (detJ stays 1).
            detJ_list.append(float(w * j_surf))
            normal_list.append(n_use)

        if not x_m_list:
            continue
        x_m = np.array(x_m_list, dtype=float)
        x_s = np.array(x_s_list, dtype=float)
        weights = np.array(detJ_list, dtype=float)
        normals = normal_sign * np.array(normal_list, dtype=float)

        Na = _volume_shape_values_at_points(x_m, elem_coords_a, tol=tol)
        Nb = _volume_shape_values_at_points(x_s, elem_coords_b, tol=tol)
        gradNa = _tet_gradN_at_points(x_m, elem_coords_a, tol=tol)
        gradNb = _tet_gradN_at_points(x_s, elem_coords_b, tol=tol)

        batches.append(
            dict(
                x_q=x_m,
                w=weights,
                detJ=np.ones_like(weights),
                Na=Na,
                Nb=Nb,
                gradNa=gradNa,
                gradNb=gradNb,
                nodes_a=elem_nodes_a,
                nodes_b=elem_nodes_b,
                normal=normals,
            )
        )

    if diag and fail_points:
        print(
            "[fluxfem][proj][diag]",
            f"total={total_points}",
            f"fail={fail_points}",
            f"fallback={fallback}",
            f"face_type={face_type}",
            f"fail_by_code={fail_by_code}",
        )
        for i, sample in enumerate(fail_samples):
            xm = sample.get("xm")
            xm_str = np.array2string(xm, precision=6) if xm is not None else "None"
            print(
                "[fluxfem][proj][diag]",
                f"sample={i}",
                f"code={sample.get('code')}",
                f"face={sample.get('face_type')}",
                f"fa={sample.get('fa')}",
                f"fb={sample.get('fb')}",
                f"elem_a={sample.get('elem_a')}",
                f"elem_b={sample.get('elem_b')}",
                f"xm={xm_str}",
                f"xi0={sample.get('xi0', float('nan')):.6f}",
                f"eta0={sample.get('eta0', float('nan')):.6f}",
                f"xi={sample.get('xi', float('nan')):.6f}",
                f"eta={sample.get('eta', float('nan')):.6f}",
                f"r={sample.get('r_norm', float('nan')):.3e}",
                f"d={sample.get('d_norm', float('nan')):.3e}",
                f"det={sample.get('det', float('nan')):.3e}",
                f"cond={sample.get('cond', float('nan')):.3e}",
            )

    return batches, fallback
2069
+
2070
+
2071
def assemble_mortar_matrices(
    supermesh_coords: np.ndarray,
    supermesh_conn: np.ndarray,
    source_facets_a: Iterable[int],
    source_facets_b: Iterable[int],
    surface_a: SurfaceMesh,
    surface_b: SurfaceMesh,
    *,
    tol: float = 1e-8,
) -> tuple[MortarMatrix, MortarMatrix]:
    """
    Assemble the mortar coupling matrices ``M_aa`` and ``M_ab`` in COO form.

    Each supermesh triangle is integrated with one centroid point weighted by the
    triangle area. The facet shape functions of the source facets on surfaces a
    and b are evaluated at that centroid. Triangles whose area is ``<= tol`` are
    skipped. Duplicate (row, col) entries are left for the consumer to sum.

    Returns
    -------
    (M_aa, M_ab)
        ``M_aa`` has shape ``(n_nodes_a, n_nodes_a)``; ``M_ab`` has shape
        ``(n_nodes_a, n_nodes_b)``.
    """
    pts_a = np.asarray(surface_a.coords, dtype=float)
    pts_b = np.asarray(surface_b.coords, dtype=float)
    conn_a = np.asarray(surface_a.conn, dtype=int)
    conn_b = np.asarray(surface_b.conn, dtype=int)

    aa_rows: list[int] = []
    aa_cols: list[int] = []
    aa_vals: list[float] = []
    ab_rows: list[int] = []
    ab_cols: list[int] = []
    ab_vals: list[float] = []

    tri_stream = _iter_supermesh_tris(supermesh_coords, supermesh_conn)
    for (_tri, p0, p1, p2), src_a, src_b in zip(tri_stream, source_facets_a, source_facets_b):
        x_c = _tri_centroid(p0, p1, p2)
        area = _tri_area(p0, p1, p2)
        if area <= tol:
            continue

        nodes_a = conn_a[int(src_a)]
        nodes_b = conn_b[int(src_b)]
        phi_a = _facet_shape_values(x_c, nodes_a, pts_a, tol=tol)
        phi_b = _facet_shape_values(x_c, nodes_b, pts_b, tol=tol)

        ids_a = [int(node) for node in nodes_a]
        ids_b = [int(node) for node in nodes_b]
        vals_a = [float(phi_a[k]) for k in range(len(ids_a))]
        vals_b = [float(phi_b[k]) for k in range(len(ids_b))]

        # Row-major (test node outer, trial node inner) local blocks.
        for row, va in zip(ids_a, vals_a):
            aa_rows.extend([row] * len(ids_a))
            aa_cols.extend(ids_a)
            aa_vals.extend(area * va * vb for vb in vals_a)

            ab_rows.extend([row] * len(ids_b))
            ab_cols.extend(ids_b)
            ab_vals.extend(area * va * vb for vb in vals_b)

    n_a = int(pts_a.shape[0])
    n_b = int(pts_b.shape[0])
    M_aa = MortarMatrix(
        rows=np.asarray(aa_rows, dtype=int),
        cols=np.asarray(aa_cols, dtype=int),
        data=np.asarray(aa_vals, dtype=float),
        shape=(n_a, n_a),
    )
    M_ab = MortarMatrix(
        rows=np.asarray(ab_rows, dtype=int),
        cols=np.asarray(ab_cols, dtype=int),
        data=np.asarray(ab_vals, dtype=float),
        shape=(n_a, n_b),
    )
    return M_aa, M_ab
2139
+
2140
+
2141
def assemble_mixed_surface_residual(
    supermesh_coords: np.ndarray,
    supermesh_conn: np.ndarray,
    source_facets_a: Iterable[int],
    source_facets_b: Iterable[int],
    surface_a: SurfaceMesh,
    surface_b: SurfaceMesh,
    res_form,
    u_a: np.ndarray,
    u_b: np.ndarray,
    params,
    *,
    value_dim_a: int = 1,
    value_dim_b: int = 1,
    offset_a: int = 0,
    offset_b: int | None = None,
    field_a: str = "a",
    field_b: str = "b",
    elem_conn_a: np.ndarray | None = None,
    elem_conn_b: np.ndarray | None = None,
    facet_to_elem_a: np.ndarray | None = None,
    facet_to_elem_b: np.ndarray | None = None,
    normal_source: str = "master",
    normal_from: str | None = None,
    master_field: str | None = None,
    normal_sign: float = 1.0,
    grad_source: str = "volume",
    dof_source: str = "surface",
    quad_order: int = 0,
    tol: float = 1e-8,
) -> np.ndarray:
    """
    Assemble a mixed surface residual over a supermesh.

    Each supermesh triangle is integrated with a single centroid point when
    ``quad_order <= 0``, or with ``_tri_quadrature(quad_order)`` otherwise.
    When the environment variable ``FLUXFEM_MORTAR_MODE`` is "projection", the
    projection-based batches are tried first. If they are unavailable or any
    projection failure was recorded, the supermesh path is used instead.

    normal_source can be "master", "slave", "a", "b", or "avg"; use master_field
    to pick which field acts as the master when normal_source is "master"/"slave".
    dof_source="volume" assembles into element nodes (requires elem_conn_* mappings).

    Returns
    -------
    np.ndarray
        Global residual of length ``offset_b + n_nodes_b * value_dim_b``.
        ``offset_b`` defaults to ``offset_a + n_nodes_a * value_dim_a``.

    Raises
    ------
    ValueError
        On invalid ``grad_source`` / ``dof_source`` / ``normal_source`` /
        ``normal_from`` / ``master_field``, on a negative facet-to-element
        mapping, or when ``res_form`` returns a field that is not ``(n_q, n_ldofs)``.
    NotImplementedError
        When an owning element's node count is not 4, 8, 10, 20 or 27.
    """
    from ..core.forms import FieldPair

    # Same tracing switch as assemble_mixed_surface_jacobian; `trace` and
    # `_trace_time` were previously used here without a local definition.
    trace = os.getenv("FLUXFEM_MORTAR_TRACE", "0") not in ("0", "", "false", "False")

    def _trace_time(msg: str, t0: float) -> None:
        if trace:
            print(f"{msg}: {time.perf_counter() - t0:.6f}s", flush=True)

    def _build_context(Na, gradNa, Nb, gradNb, *, x_q, w, detJ, normal):
        # One field object per side, used as test, trial and unknown alike.
        field_a_obj = SurfaceMixedFormField(
            N=Na,
            gradN=gradNa,
            value_dim=value_dim_a,
            basis=_SurfaceBasis(dofs_per_node=value_dim_a),
        )
        field_b_obj = SurfaceMixedFormField(
            N=Nb,
            gradN=gradNb,
            value_dim=value_dim_b,
            basis=_SurfaceBasis(dofs_per_node=value_dim_b),
        )
        fields = {
            field_a: FieldPair(test=field_a_obj, trial=field_a_obj),
            field_b: FieldPair(test=field_b_obj, trial=field_b_obj),
        }
        return SurfaceMixedFormContext(
            fields=fields,
            x_q=x_q,
            w=w,
            detJ=detJ,
            normal=normal,
            trial_fields={field_a: field_a_obj, field_b: field_b_obj},
            test_fields={field_a: field_a_obj, field_b: field_b_obj},
            unknown_fields={field_a: field_a_obj, field_b: field_b_obj},
        )

    def _evaluate_and_scatter(ctx, nodes_a, nodes_b, shape_msg):
        # Evaluate res_form on the local unknowns and add each field's
        # quadrature-summed contribution into the global vector R.
        u_elem = {
            field_a: _gather_u_local(u_a, nodes_a, value_dim_a),
            field_b: _gather_u_local(u_b, nodes_b, value_dim_b),
        }
        fe_q = res_form(ctx, u_elem, params)
        for name, nodes, value_dim, offset in (
            (field_a, nodes_a, value_dim_a, offset_a),
            (field_b, nodes_b, value_dim_b, offset_b),
        ):
            fe_field = fe_q[name]
            if fe_field.ndim != 2 or fe_field.shape[0] != ctx.x_q.shape[0]:
                raise ValueError(shape_msg)
            if includes_measure.get(name, False):
                # The form already multiplied by the measure; just sum points.
                fe = np.sum(np.asarray(fe_field), axis=0)
            else:
                # NumPy (not jnp) keeps float64 even when JAX x64 is disabled.
                wJ = np.asarray(ctx.w) * np.asarray(ctx.detJ)
                fe = np.einsum("qi,q->i", np.asarray(fe_field), wJ)
            dofs = _global_dof_indices(nodes, value_dim, int(offset))
            R[dofs] += fe

    coords_a = np.asarray(surface_a.coords, dtype=float)
    coords_b = np.asarray(surface_b.coords, dtype=float)
    facets_a = np.asarray(surface_a.conn, dtype=int)
    facets_b = np.asarray(surface_b.conn, dtype=int)
    n_a = int(coords_a.shape[0] * value_dim_a)
    n_b = int(coords_b.shape[0] * value_dim_b)
    if offset_b is None:
        offset_b = offset_a + n_a
    n_total = int(offset_b + n_b)
    R = np.zeros((n_total,), dtype=float)

    t_norm = time.perf_counter()
    normals_a = None
    normals_b = None
    if hasattr(surface_a, "facet_normals"):
        normals_a = surface_a.facet_normals()
    if hasattr(surface_b, "facet_normals"):
        normals_b = surface_b.facet_normals()
    if trace:
        _trace_time("[CONTACT] normals_done", t_norm)

    # Optional skipping of supermesh slivers relative to the source facet size.
    area_scale = float(os.getenv("FLUXFEM_SMALL_TRI_EPS_SCALE", "0.0"))
    skip_small_tri = os.getenv("FLUXFEM_SKIP_SMALL_TRI", "0") == "1" and area_scale > 0.0
    facet_area_a = None
    facet_area_b = None
    if area_scale > 0.0:
        t_area = time.perf_counter()
        facet_area_a = np.array([_facet_area_estimate(fa, coords_a) for fa in facets_a], dtype=float)
        facet_area_b = np.array([_facet_area_estimate(fb, coords_b) for fb in facets_b], dtype=float)
        if trace:
            _trace_time("[CONTACT] facet_area_done", t_area)

    includes_measure = getattr(res_form, "_includes_measure", {})

    use_elem_a = elem_conn_a is not None and facet_to_elem_a is not None
    use_elem_b = elem_conn_b is not None and facet_to_elem_b is not None

    if grad_source not in {"volume", "surface"}:
        raise ValueError("grad_source must be 'volume' or 'surface'")
    if dof_source not in {"surface", "volume"}:
        raise ValueError("dof_source must be 'surface' or 'volume'")
    if dof_source == "volume" and grad_source == "surface":
        raise ValueError("dof_source 'volume' requires grad_source 'volume'")
    global _DEBUG_SURFACE_SOURCE_ONCE
    if grad_source == "surface" and not _DEBUG_SURFACE_SOURCE_ONCE:
        print("[fluxfem] using surface gradN in mortar")
        _DEBUG_SURFACE_SOURCE_ONCE = True
    proj_diag = _proj_diag_enabled()
    if proj_diag:
        _proj_diag_reset()
    diag_force = os.getenv("FLUXFEM_PROJ_DIAG_FORCE", "0") == "1"
    diag_qp_mode = os.getenv("FLUXFEM_PROJ_DIAG_QP_MODE", "").strip().lower()
    diag_qp_path = os.getenv("FLUXFEM_PROJ_DIAG_QP_PATH", "").strip()
    diag_normal = os.getenv("FLUXFEM_PROJ_DIAG_NORMAL", "").strip().lower()
    diag_facet = int(os.getenv("FLUXFEM_PROJ_DIAG_FACET", "-1"))
    diag_max_q = int(os.getenv("FLUXFEM_PROJ_DIAG_MAX_Q", "3"))
    diag_abs_detj = os.getenv("FLUXFEM_PROJ_DIAG_ABS_DETJ", "1") == "1"

    # Resolve master/slave naming to a concrete side ("a", "b" or "avg").
    if normal_from is not None:
        if normal_from not in {"master", "slave"}:
            raise ValueError("normal_from must be 'master' or 'slave'")
        master_name = field_a if master_field is None else master_field
        if master_name not in {field_a, field_b}:
            raise ValueError("master_field must match field_a or field_b")
        if normal_from == "master":
            normal_source = "a" if master_name == field_a else "b"
        else:
            normal_source = "b" if master_name == field_a else "a"
    if diag_force and diag_normal:
        normal_source = diag_normal
    if normal_source not in {"a", "b", "avg", "master", "slave"}:
        raise ValueError("normal_source must be 'a', 'b', 'avg', 'master', or 'slave'")
    if normal_source == "master":
        normal_source = "a" if (master_field is None or master_field == field_a) else "b"
    if normal_source == "slave":
        normal_source = "b" if (master_field is None or master_field == field_a) else "a"

    mortar_mode = os.getenv("FLUXFEM_MORTAR_MODE", "supermesh").lower()
    if mortar_mode == "projection":
        batches, fallback = _projection_surface_batches(
            source_facets_a,
            source_facets_b,
            surface_a,
            surface_b,
            elem_conn_a=elem_conn_a,
            elem_conn_b=elem_conn_b,
            facet_to_elem_a=facet_to_elem_a,
            facet_to_elem_b=facet_to_elem_b,
            quad_order=quad_order,
            grad_source=grad_source,
            dof_source=dof_source,
            normal_source=normal_source,
            normal_sign=normal_sign,
            tol=tol,
        )
        if batches is not None and not fallback:
            for batch in batches:
                ctx = _build_context(
                    batch["Na"],
                    batch["gradNa"],
                    batch["Nb"],
                    batch["gradNb"],
                    x_q=batch["x_q"],
                    w=batch["w"],
                    detJ=batch["detJ"],
                    normal=batch["normal"],
                )
                _evaluate_and_scatter(
                    ctx,
                    batch["nodes_a"],
                    batch["nodes_b"],
                    "mixed surface residual must return (n_q, n_ldofs)",
                )
            return R

    for (_tri, a, b, c), fa, fb in zip(
        _iter_supermesh_tris(supermesh_coords, supermesh_conn),
        source_facets_a,
        source_facets_b,
    ):
        area = _tri_area(a, b, c)
        if area <= tol:
            continue
        if skip_small_tri and facet_area_a is not None and facet_area_b is not None:
            area_ref = max(float(facet_area_a[int(fa)]), float(facet_area_b[int(fb)]))
            if area_ref > 0.0 and area < area_scale * area_ref:
                continue
        # Affine map from the reference triangle (area 1/2): detJ = 2 * area.
        detJ = 2.0 * area
        if diag_force and diag_abs_detj:
            detJ = abs(detJ)
        if quad_order <= 0:
            quad_pts = np.array([[1.0 / 3.0, 1.0 / 3.0]], dtype=float)
            quad_w = np.array([0.5], dtype=float)
        else:
            quad_pts, quad_w = _tri_quadrature(quad_order)
        quad_source = "fluxfem"
        quad_override = _diag_quad_override(diag_force, diag_qp_mode, diag_qp_path)
        if quad_override is not None:
            quad_pts, quad_w = quad_override
            quad_source = _DEBUG_PROJ_QP_SOURCE or "override"
        _diag_quad_dump(diag_force, diag_qp_mode, diag_qp_path, quad_pts, quad_w)

        facet_a = facets_a[int(fa)]
        facet_b = facets_b[int(fb)]
        x_q = np.array([a + r * (b - a) + s * (c - a) for r, s in quad_pts], dtype=float)

        gradNa = None
        gradNb = None
        nodes_a = facet_a
        nodes_b = facet_b

        Na = None
        Nb = None

        elem_id_a = -1
        elem_nodes_a = None
        elem_coords_a = None
        if use_elem_a:
            elem_id_a = int(facet_to_elem_a[int(fa)])
            if elem_id_a < 0:
                raise ValueError("facet_to_elem_a has invalid mapping")
            elem_nodes_a = np.asarray(elem_conn_a[elem_id_a], dtype=int)
            elem_coords_a = coords_a[elem_nodes_a]
            if elem_coords_a.shape[0] not in {4, 8, 10, 20, 27}:
                raise NotImplementedError("surface sym_grad is implemented for tet4/tet10/hex8/hex20/hex27 only")

        elem_id_b = -1
        elem_nodes_b = None
        elem_coords_b = None
        if use_elem_b:
            elem_id_b = int(facet_to_elem_b[int(fb)])
            if elem_id_b < 0:
                raise ValueError("facet_to_elem_b has invalid mapping")
            elem_nodes_b = np.asarray(elem_conn_b[elem_id_b], dtype=int)
            elem_coords_b = coords_b[elem_nodes_b]
            if elem_coords_b.shape[0] not in {4, 8, 10, 20, 27}:
                raise NotImplementedError("surface sym_grad is implemented for tet4/tet10/hex8/hex20/hex27 only")
        if proj_diag:
            _proj_diag_set_context(
                fa=int(fa),
                fb=int(fb),
                face_a=_facet_label(facet_a),
                face_b=_facet_label(facet_b),
                elem_a=elem_id_a,
                elem_b=elem_id_b,
            )

        if grad_source == "surface":
            gradNa = np.array(
                [_surface_gradN(pt, facet_a, coords_a, tol=tol) for pt in x_q],
                dtype=float,
            )
            gradNb = np.array(
                [_surface_gradN(pt, facet_b, coords_b, tol=tol) for pt in x_q],
                dtype=float,
            )
        if use_elem_a and grad_source == "volume":
            local = _local_indices(elem_nodes_a, facet_a)
            gradNa = _tet_gradN_at_points(x_q, elem_coords_a, local=local, tol=tol)

        if use_elem_b and grad_source == "volume":
            local = _local_indices(elem_nodes_b, facet_b)
            gradNb = _tet_gradN_at_points(x_q, elem_coords_b, local=local, tol=tol)

        if dof_source == "volume":
            if not use_elem_a or elem_nodes_a is None or elem_coords_a is None:
                raise ValueError("dof_source 'volume' requires elem_conn_a and facet_to_elem_a")
            if not use_elem_b or elem_nodes_b is None or elem_coords_b is None:
                raise ValueError("dof_source 'volume' requires elem_conn_b and facet_to_elem_b")
            nodes_a = elem_nodes_a
            nodes_b = elem_nodes_b
            Na = _volume_shape_values_at_points(x_q, elem_coords_a, tol=tol)
            Nb = _volume_shape_values_at_points(x_q, elem_coords_b, tol=tol)
            if grad_source == "volume":
                gradNa = _tet_gradN_at_points(x_q, elem_coords_a, tol=tol)
                gradNb = _tet_gradN_at_points(x_q, elem_coords_b, tol=tol)
        else:
            Na = np.array([_facet_shape_values(pt, facet_a, coords_a, tol=tol) for pt in x_q], dtype=float)
            Nb = np.array([_facet_shape_values(pt, facet_b, coords_b, tol=tol) for pt in x_q], dtype=float)

        # One facet normal per supermesh triangle, chosen by normal_source.
        normal = None
        na = normals_a[int(fa)] if normals_a is not None else None
        nb = normals_b[int(fb)] if normals_b is not None else None
        if normal_source == "a":
            normal = na
        elif normal_source == "b":
            normal = nb
        else:
            if na is not None and nb is not None:
                avg = na + nb
                norm = np.linalg.norm(avg)
                normal = avg / norm if norm > tol else na
            else:
                normal = na if na is not None else nb
        if normal is not None:
            normal = normal_sign * normal
        if diag_force:
            dofs_a = _global_dof_indices(nodes_a, value_dim_a, int(offset_a))
            dofs_b = _global_dof_indices(nodes_b, value_dim_b, int(offset_b))
            _diag_contact_projection(
                fa=int(fa),
                fb=int(fb),
                quad_pts=quad_pts,
                quad_w=quad_w,
                x_q=x_q,
                Na=Na,
                Nb=Nb,
                nodes_a=nodes_a,
                nodes_b=nodes_b,
                dofs_a=dofs_a,
                dofs_b=dofs_b,
                elem_coords_a=elem_coords_a if dof_source == "volume" else None,
                elem_coords_b=elem_coords_b if dof_source == "volume" else None,
                na=na,
                nb=nb,
                normal=normal,
                normal_source=normal_source,
                normal_sign=normal_sign,
                detJ=detJ,
                diag_facet=diag_facet,
                diag_max_q=diag_max_q,
                quad_source=quad_source,
                tol=tol,
            )

        normal_q = None if normal is None else np.repeat(normal[None, :], quad_pts.shape[0], axis=0)
        ctx = _build_context(
            Na,
            gradNa,
            Nb,
            gradNb,
            x_q=x_q,
            w=quad_w,
            detJ=np.array([detJ], dtype=float),
            normal=normal_q,
        )
        _evaluate_and_scatter(
            ctx,
            nodes_a,
            nodes_b,
            "surface residual must return shape (n_q, n_ldofs) per field",
        )
    if proj_diag:
        _proj_diag_report()
    return R
2532
+
2533
+
2534
+ def assemble_mixed_surface_jacobian(
2535
+ supermesh_coords: np.ndarray,
2536
+ supermesh_conn: np.ndarray,
2537
+ source_facets_a: Iterable[int],
2538
+ source_facets_b: Iterable[int],
2539
+ surface_a: SurfaceMesh,
2540
+ surface_b: SurfaceMesh,
2541
+ res_form,
2542
+ u_a: np.ndarray,
2543
+ u_b: np.ndarray,
2544
+ params,
2545
+ *,
2546
+ value_dim_a: int = 1,
2547
+ value_dim_b: int = 1,
2548
+ offset_a: int = 0,
2549
+ offset_b: int | None = None,
2550
+ field_a: str = "a",
2551
+ field_b: str = "b",
2552
+ elem_conn_a: np.ndarray | None = None,
2553
+ elem_conn_b: np.ndarray | None = None,
2554
+ facet_to_elem_a: np.ndarray | None = None,
2555
+ facet_to_elem_b: np.ndarray | None = None,
2556
+ normal_source: str = "master",
2557
+ normal_from: str | None = None,
2558
+ master_field: str | None = None,
2559
+ normal_sign: float = 1.0,
2560
+ grad_source: str = "volume",
2561
+ dof_source: str = "surface",
2562
+ quad_order: int = 0,
2563
+ tol: float = 1e-8,
2564
+ sparse: bool = False,
2565
+ backend: str = "jax",
2566
+ batch_jac: bool | None = None,
2567
+ fd_eps: float = 1e-6,
2568
+ fd_mode: str = "central",
2569
+ fd_block_size: int = 1,
2570
+ ):
2571
+ """
2572
+ Assemble mixed surface Jacobian over a supermesh (centroid quadrature).
2573
+
2574
+ normal_source can be "master", "slave", "a", "b", or "avg"; use master_field
2575
+ to pick which field acts as the master when normal_source is "master"/"slave".
2576
+ dof_source="volume" assembles into element nodes (requires elem_conn_* mappings).
2577
+ """
2578
+ from ..core.forms import FieldPair
2579
+ _mortar_dbg(
2580
+ f"[mortar] enter assemble_mixed_surface_jacobian quad_order={quad_order} backend={backend}"
2581
+ )
2582
+ trace = os.getenv("FLUXFEM_MORTAR_TRACE", "0") not in ("0", "", "false", "False")
2583
+ trace_max = int(os.getenv("FLUXFEM_MORTAR_TRACE_MAX", "5"))
2584
+ trace_every = int(os.getenv("FLUXFEM_MORTAR_TRACE_EVERY", "50"))
2585
+ trace_fd_max = int(os.getenv("FLUXFEM_MORTAR_TRACE_FD_MAX", "5"))
2586
+ def _trace(msg: str) -> None:
2587
+ if trace:
2588
+ print(msg, flush=True)
2589
+ def _trace_time(msg: str, t0: float) -> None:
2590
+ if trace:
2591
+ print(f"{msg}: {time.perf_counter() - t0:.6f}s", flush=True)
2592
+ t_prep = time.perf_counter()
2593
+ coords_a = np.asarray(surface_a.coords, dtype=float)
2594
+ coords_b = np.asarray(surface_b.coords, dtype=float)
2595
+ facets_a = np.asarray(surface_a.conn, dtype=int)
2596
+ facets_b = np.asarray(surface_b.conn, dtype=int)
2597
+ n_a = int(coords_a.shape[0] * value_dim_a)
2598
+ n_b = int(coords_b.shape[0] * value_dim_b)
2599
+ if offset_b is None:
2600
+ offset_b = offset_a + n_a
2601
+ n_total = int(offset_b + n_b)
2602
+ if trace:
2603
+ _trace("[CONTACT] assemble_mixed_surface_jacobian ENTER")
2604
+ _trace(f"[CONTACT] shapes: coords_a={coords_a.shape} coords_b={coords_b.shape} supermesh={supermesh_conn.shape}")
2605
+ _trace(f"[CONTACT] dtypes: coords_a={coords_a.dtype} coords_b={coords_b.dtype} supermesh={supermesh_conn.dtype}")
2606
+ _trace(f"[CONTACT] finite: coords_a={np.isfinite(coords_a).all()} coords_b={np.isfinite(coords_b).all()}")
2607
+ _trace_time("[CONTACT] prep_done", t_prep)
2608
+
2609
+ guard = os.getenv("FLUXFEM_CONTACT_GUARD", "0") == "1"
2610
+ detj_eps = float(os.getenv("FLUXFEM_CONTACT_DETJ_EPS", "0.0"))
2611
+ tri_timeout = float(os.getenv("FLUXFEM_CONTACT_TRI_TIMEOUT_S", "0.0"))
2612
+ skip_nonfinite = os.getenv("FLUXFEM_CONTACT_SKIP_NONFINITE", "1") == "1"
2613
+ if guard:
2614
+ if not (np.isfinite(coords_a).all() and np.isfinite(coords_b).all()):
2615
+ raise RuntimeError("[CONTACT] non-finite coords in contact surfaces")
2616
+ if not np.isfinite(supermesh_coords).all():
2617
+ raise RuntimeError("[CONTACT] non-finite supermesh coords")
2618
+ if supermesh_conn.size:
2619
+ min_idx = int(supermesh_conn.min())
2620
+ max_idx = int(supermesh_conn.max())
2621
+ if min_idx < 0 or max_idx >= supermesh_coords.shape[0]:
2622
+ raise RuntimeError(
2623
+ f"[CONTACT] supermesh_conn index out of range: min={min_idx} max={max_idx} n={supermesh_coords.shape[0]}"
2624
+ )
2625
+ if len(supermesh_conn) != len(source_facets_a) or len(supermesh_conn) != len(source_facets_b):
2626
+ raise RuntimeError(
2627
+ "[CONTACT] supermesh_conn and source_facets lengths mismatch "
2628
+ f"conn={len(supermesh_conn)} fa={len(source_facets_a)} fb={len(source_facets_b)}"
2629
+ )
2630
+
2631
+ normals_a = None
2632
+ normals_b = None
2633
+ if hasattr(surface_a, "facet_normals"):
2634
+ normals_a = surface_a.facet_normals()
2635
+ if hasattr(surface_b, "facet_normals"):
2636
+ normals_b = surface_b.facet_normals()
2637
+
2638
+ area_scale = float(os.getenv("FLUXFEM_SMALL_TRI_EPS_SCALE", "0.0"))
2639
+ skip_small_tri = os.getenv("FLUXFEM_SKIP_SMALL_TRI", "0") == "1" and area_scale > 0.0
2640
+ facet_area_a = None
2641
+ facet_area_b = None
2642
+ if area_scale > 0.0:
2643
+ facet_area_a = np.array([_facet_area_estimate(fa, coords_a) for fa in facets_a], dtype=float)
2644
+ facet_area_b = np.array([_facet_area_estimate(fb, coords_b) for fb in facets_b], dtype=float)
2645
+
2646
+ includes_measure = getattr(res_form, "_includes_measure", {})
2647
+
2648
+ rows: list[int] = []
2649
+ cols: list[int] = []
2650
+ data: list[float] = []
2651
+ K_dense = np.zeros((n_total, n_total), dtype=float) if not sparse else None
2652
+
2653
+ use_elem_a = elem_conn_a is not None and facet_to_elem_a is not None
2654
+ use_elem_b = elem_conn_b is not None and facet_to_elem_b is not None
2655
+
2656
+ if grad_source not in {"volume", "surface"}:
2657
+ raise ValueError("grad_source must be 'volume' or 'surface'")
2658
+ if dof_source not in {"surface", "volume"}:
2659
+ raise ValueError("dof_source must be 'surface' or 'volume'")
2660
+ if dof_source == "volume" and grad_source == "surface":
2661
+ raise ValueError("dof_source 'volume' requires grad_source 'volume'")
2662
+ global _DEBUG_SURFACE_SOURCE_ONCE
2663
+ if grad_source == "surface" and not _DEBUG_SURFACE_SOURCE_ONCE:
2664
+ print("[fluxfem] using surface gradN in mortar")
2665
+ _DEBUG_SURFACE_SOURCE_ONCE = True
2666
+ diag_map = os.getenv("FLUXFEM_DIAG_CONTACT_MAP", "0") == "1"
2667
+ diag_n = os.getenv("FLUXFEM_DIAG_CONTACT_N", "0") == "1"
2668
+ proj_diag = _proj_diag_enabled()
2669
+ if proj_diag:
2670
+ _proj_diag_reset()
2671
+ diag_force = os.getenv("FLUXFEM_PROJ_DIAG_FORCE", "0") == "1"
2672
+ diag_qp_mode = os.getenv("FLUXFEM_PROJ_DIAG_QP_MODE", "").strip().lower()
2673
+ diag_qp_path = os.getenv("FLUXFEM_PROJ_DIAG_QP_PATH", "").strip()
2674
+ diag_normal = os.getenv("FLUXFEM_PROJ_DIAG_NORMAL", "").strip().lower()
2675
+ diag_facet = int(os.getenv("FLUXFEM_PROJ_DIAG_FACET", "-1"))
2676
+ diag_max_q = int(os.getenv("FLUXFEM_PROJ_DIAG_MAX_Q", "3"))
2677
+ diag_abs_detj = os.getenv("FLUXFEM_PROJ_DIAG_ABS_DETJ", "1") == "1"
2678
+ if backend not in {"jax", "numpy"}:
2679
+ raise ValueError("backend must be 'jax' or 'numpy'")
2680
+ if backend == "numpy":
2681
+ if fd_eps <= 0.0:
2682
+ raise ValueError("fd_eps must be positive for numpy backend")
2683
+ if fd_mode not in {"central", "forward"}:
2684
+ raise ValueError("fd_mode must be 'central' or 'forward' for numpy backend")
2685
+ if batch_jac is None:
2686
+ batch_jac = _env_flag("FLUXFEM_MORTAR_BATCH_JAC", True)
2687
+
2688
+ if normal_from is not None:
2689
+ if normal_from not in {"master", "slave"}:
2690
+ raise ValueError("normal_from must be 'master' or 'slave'")
2691
+ master_name = field_a if master_field is None else master_field
2692
+ if master_name not in {field_a, field_b}:
2693
+ raise ValueError("master_field must match field_a or field_b")
2694
+ if normal_from == "master":
2695
+ normal_source = "a" if master_name == field_a else "b"
2696
+ else:
2697
+ normal_source = "b" if master_name == field_a else "a"
2698
+ if diag_force and diag_normal:
2699
+ normal_source = diag_normal
2700
+ if normal_source not in {"a", "b", "avg", "master", "slave"}:
2701
+ raise ValueError("normal_source must be 'a', 'b', 'avg', 'master', or 'slave'")
2702
+ if normal_source == "master":
2703
+ normal_source = "a" if (master_field is None or master_field == field_a) else "b"
2704
+ if normal_source == "slave":
2705
+ normal_source = "b" if (master_field is None or master_field == field_a) else "a"
2706
+
2707
+ mortar_mode = os.getenv("FLUXFEM_MORTAR_MODE", "supermesh").lower()
2708
+ _mortar_dbg(f"[mortar] mode={mortar_mode}")
2709
+ if mortar_mode == "projection":
2710
+ batches, fallback = _projection_surface_batches(
2711
+ source_facets_a,
2712
+ source_facets_b,
2713
+ surface_a,
2714
+ surface_b,
2715
+ elem_conn_a=elem_conn_a,
2716
+ elem_conn_b=elem_conn_b,
2717
+ facet_to_elem_a=facet_to_elem_a,
2718
+ facet_to_elem_b=facet_to_elem_b,
2719
+ quad_order=quad_order,
2720
+ grad_source=grad_source,
2721
+ dof_source=dof_source,
2722
+ normal_source=normal_source,
2723
+ normal_sign=normal_sign,
2724
+ tol=tol,
2725
+ )
2726
+ if batches is not None and not fallback:
2727
+ for batch in batches:
2728
+ Na = batch["Na"]
2729
+ Nb = batch["Nb"]
2730
+ gradNa = batch["gradNa"]
2731
+ gradNb = batch["gradNb"]
2732
+ nodes_a = batch["nodes_a"]
2733
+ nodes_b = batch["nodes_b"]
2734
+ normal_q = batch["normal"]
2735
+
2736
+ field_a_obj = SurfaceMixedFormField(
2737
+ N=Na,
2738
+ gradN=gradNa,
2739
+ value_dim=value_dim_a,
2740
+ basis=_SurfaceBasis(dofs_per_node=value_dim_a),
2741
+ )
2742
+ field_b_obj = SurfaceMixedFormField(
2743
+ N=Nb,
2744
+ gradN=gradNb,
2745
+ value_dim=value_dim_b,
2746
+ basis=_SurfaceBasis(dofs_per_node=value_dim_b),
2747
+ )
2748
+ fields = {
2749
+ field_a: FieldPair(test=field_a_obj, trial=field_a_obj),
2750
+ field_b: FieldPair(test=field_b_obj, trial=field_b_obj),
2751
+ }
2752
+ ctx = SurfaceMixedFormContext(
2753
+ fields=fields,
2754
+ x_q=batch["x_q"],
2755
+ w=batch["w"],
2756
+ detJ=batch["detJ"],
2757
+ normal=normal_q,
2758
+ trial_fields={field_a: field_a_obj, field_b: field_b_obj},
2759
+ test_fields={field_a: field_a_obj, field_b: field_b_obj},
2760
+ unknown_fields={field_a: field_a_obj, field_b: field_b_obj},
2761
+ )
2762
+
2763
+ u_elem = {
2764
+ field_a: _gather_u_local(u_a, nodes_a, value_dim_a),
2765
+ field_b: _gather_u_local(u_b, nodes_b, value_dim_b),
2766
+ }
2767
+ u_local = np.concatenate([u_elem[field_a], u_elem[field_b]], axis=0)
2768
+ sizes = (u_elem[field_a].shape[0], u_elem[field_b].shape[0])
2769
+ slices = {
2770
+ field_a: slice(0, sizes[0]),
2771
+ field_b: slice(sizes[0], sizes[0] + sizes[1]),
2772
+ }
2773
+
2774
+ def _res_local_np(u_vec):
2775
+ u_dict = {name: u_vec[slices[name]] for name in (field_a, field_b)}
2776
+ fe_q = res_form(ctx, u_dict, params)
2777
+ res_parts = []
2778
+ for name in (field_a, field_b):
2779
+ fe_field = fe_q[name]
2780
+ if includes_measure.get(name, False):
2781
+ fe = np.sum(np.asarray(fe_field), axis=0)
2782
+ else:
2783
+ wJ = np.asarray(ctx.w) * np.asarray(ctx.detJ)
2784
+ fe = np.einsum("qi,q->i", np.asarray(fe_field), wJ)
2785
+ res_parts.append(np.asarray(fe))
2786
+ return np.concatenate(res_parts, axis=0)
2787
+
2788
+ if backend == "jax":
2789
+ def _res_local(u_vec):
2790
+ u_dict = {name: u_vec[slices[name]] for name in (field_a, field_b)}
2791
+ fe_q = res_form(ctx, u_dict, params)
2792
+ res_parts = []
2793
+ for name in (field_a, field_b):
2794
+ fe_field = fe_q[name]
2795
+ if includes_measure.get(name, False):
2796
+ fe = jnp.sum(jnp.asarray(fe_field), axis=0)
2797
+ else:
2798
+ wJ = jnp.asarray(ctx.w) * jnp.asarray(ctx.detJ)
2799
+ fe = jnp.einsum("qi,q->i", jnp.asarray(fe_field), wJ)
2800
+ res_parts.append(fe)
2801
+ return jnp.concatenate(res_parts, axis=0)
2802
+
2803
+ J_local = jax.jacrev(_res_local)(jnp.asarray(u_local))
2804
+ J_local_np = np.asarray(J_local)
2805
+ else:
2806
+ n_ldofs = int(u_local.shape[0])
2807
+ J_local_np = np.zeros((n_ldofs, n_ldofs), dtype=float)
2808
+ u_base = np.asarray(u_local, dtype=float)
2809
+ r0 = _res_local_np(u_base) if fd_mode == "forward" else None
2810
+ for i in range(n_ldofs):
2811
+ u_p = u_base.copy()
2812
+ u_p[i] += fd_eps
2813
+ r_p = _res_local_np(u_p)
2814
+ if fd_mode == "central":
2815
+ u_m = u_base.copy()
2816
+ u_m[i] -= fd_eps
2817
+ r_m = _res_local_np(u_m)
2818
+ col = (r_p - r_m) / (2.0 * fd_eps)
2819
+ else:
2820
+ col = (r_p - r0) / fd_eps
2821
+ J_local_np[:, i] = np.asarray(col, dtype=float)
2822
+
2823
+ dofs_a = _global_dof_indices(nodes_a, value_dim_a, int(offset_a))
2824
+ dofs_b = _global_dof_indices(nodes_b, value_dim_b, int(offset_b))
2825
+ dofs = np.concatenate([dofs_a, dofs_b], axis=0)
2826
+ for i, gi in enumerate(dofs):
2827
+ for j, gj in enumerate(dofs):
2828
+ val = float(J_local_np[i, j])
2829
+ if sparse:
2830
+ rows.append(int(gi))
2831
+ cols.append(int(gj))
2832
+ data.append(val)
2833
+ else:
2834
+ K_dense[int(gi), int(gj)] += val
2835
+ if sparse:
2836
+ return np.asarray(rows, dtype=int), np.asarray(cols, dtype=int), np.asarray(data, dtype=float), n_total
2837
+ assert K_dense is not None
2838
+ return K_dense
2839
+
2840
+ if (
2841
+ batch_jac
2842
+ and backend == "jax"
2843
+ and dof_source == "volume"
2844
+ and grad_source == "volume"
2845
+ and use_elem_a
2846
+ and use_elem_b
2847
+ and not proj_diag
2848
+ and not diag_force
2849
+ ):
2850
+ if trace:
2851
+ _trace("[CONTACT] batch_jac_enter")
2852
+ batch_items = []
2853
+ dofs_batch = []
2854
+ u_local_batch = []
2855
+ batch_rows: list[np.ndarray] = []
2856
+ batch_cols: list[np.ndarray] = []
2857
+ batch_data: list[np.ndarray] = []
2858
+ batch_size = int(os.getenv("FLUXFEM_MORTAR_BATCH_SIZE", "128"))
2859
+ if batch_size <= 0:
2860
+ batch_size = 0
2861
+ n_q = None
2862
+ n_nodes_a = None
2863
+ n_nodes_b = None
2864
+ n_a_local_const = None
2865
+ n_b_local_const = None
2866
+ batch_failed = False
2867
+ jit_batch = _env_flag("FLUXFEM_MORTAR_BATCH_JIT", False)
2868
+
2869
+ def _make_jac_fun(n_a_local: int, n_b_local: int):
2870
+ def _res_local_batch(u_vec, Na, Nb, gradNa, gradNb, x_q, w, detJ, normal):
2871
+ field_a_obj = SurfaceMixedFormField(
2872
+ N=Na,
2873
+ gradN=gradNa,
2874
+ value_dim=value_dim_a,
2875
+ basis=_SurfaceBasis(dofs_per_node=value_dim_a),
2876
+ )
2877
+ field_b_obj = SurfaceMixedFormField(
2878
+ N=Nb,
2879
+ gradN=gradNb,
2880
+ value_dim=value_dim_b,
2881
+ basis=_SurfaceBasis(dofs_per_node=value_dim_b),
2882
+ )
2883
+ fields = {
2884
+ field_a: FieldPair(test=field_a_obj, trial=field_a_obj),
2885
+ field_b: FieldPair(test=field_b_obj, trial=field_b_obj),
2886
+ }
2887
+ normal_q = jnp.repeat(normal[None, :], x_q.shape[0], axis=0)
2888
+ ctx = SurfaceMixedFormContext(
2889
+ fields=fields,
2890
+ x_q=x_q,
2891
+ w=w,
2892
+ detJ=detJ,
2893
+ normal=normal_q,
2894
+ trial_fields={field_a: field_a_obj, field_b: field_b_obj},
2895
+ test_fields={field_a: field_a_obj, field_b: field_b_obj},
2896
+ unknown_fields={field_a: field_a_obj, field_b: field_b_obj},
2897
+ )
2898
+ u_dict = {
2899
+ field_a: u_vec[:n_a_local],
2900
+ field_b: u_vec[n_a_local:],
2901
+ }
2902
+ fe_q = res_form(ctx, u_dict, params)
2903
+ res_parts = []
2904
+ for name in (field_a, field_b):
2905
+ fe_field = fe_q[name]
2906
+ if includes_measure.get(name, False):
2907
+ fe = jnp.sum(jnp.asarray(fe_field), axis=0)
2908
+ else:
2909
+ wJ = jnp.asarray(ctx.w) * jnp.asarray(ctx.detJ)
2910
+ fe = jnp.einsum("qi,q->i", jnp.asarray(fe_field), wJ)
2911
+ res_parts.append(fe)
2912
+ return jnp.concatenate(res_parts, axis=0)
2913
+
2914
+ if trace:
2915
+ _trace(f"[CONTACT] batch_jac_build n_a={n_a_local} n_b={n_b_local} jit={jit_batch}")
2916
+ jac_fun = jax.vmap(jax.jacrev(_res_local_batch))
2917
+ return jax.jit(jac_fun) if jit_batch else jac_fun
2918
+
2919
+ jac_fun_cache: dict[tuple[int, int], object] = {}
2920
+
2921
+ def _emit_batch(
2922
+ Na_b,
2923
+ Nb_b,
2924
+ gradNa_b,
2925
+ gradNb_b,
2926
+ x_q_b,
2927
+ w_b,
2928
+ detJ_b,
2929
+ normal_b,
2930
+ u_local_b,
2931
+ dofs_batch_np,
2932
+ n_a_local,
2933
+ n_b_local,
2934
+ batch_n,
2935
+ ) -> None:
2936
+ if trace:
2937
+ _trace(f"[CONTACT] batch_emit start n={int(Na_b.shape[0])}")
2938
+ if batch_size and batch_n < batch_size:
2939
+ pad = int(batch_size - batch_n)
2940
+ if trace:
2941
+ _trace(f"[CONTACT] batch_pad n={batch_n} target={batch_size}")
2942
+
2943
+ def _pad_batch(x, pad_value: float = 0.0):
2944
+ pad_width = [(0, pad)] + [(0, 0)] * (x.ndim - 1)
2945
+ return jnp.pad(jnp.asarray(x), pad_width, mode="constant", constant_values=pad_value)
2946
+
2947
+ Na_b = _pad_batch(Na_b)
2948
+ Nb_b = _pad_batch(Nb_b)
2949
+ gradNa_b = _pad_batch(gradNa_b)
2950
+ gradNb_b = _pad_batch(gradNb_b)
2951
+ x_q_b = _pad_batch(x_q_b)
2952
+ w_b = _pad_batch(w_b)
2953
+ detJ_b = _pad_batch(detJ_b)
2954
+ normal_b = _pad_batch(normal_b)
2955
+ u_local_b = _pad_batch(u_local_b)
2956
+ key = (n_a_local, n_b_local)
2957
+ jac_fun = jac_fun_cache.get(key)
2958
+ if jac_fun is None:
2959
+ jac_fun = _make_jac_fun(n_a_local, n_b_local)
2960
+ jac_fun_cache[key] = jac_fun
2961
+ t_batch = time.perf_counter()
2962
+ J_b = jac_fun(u_local_b, Na_b, Nb_b, gradNa_b, gradNb_b, x_q_b, w_b, detJ_b, normal_b)
2963
+ J_b_np = np.asarray(J_b)[:batch_n]
2964
+ if trace:
2965
+ _trace_time("[CONTACT] batch_emit jac_done", t_batch)
2966
+ n_ldofs = dofs_batch_np.shape[1]
2967
+ rows = np.repeat(dofs_batch_np, n_ldofs, axis=1).reshape(-1)
2968
+ cols = np.tile(dofs_batch_np, (1, n_ldofs)).reshape(-1)
2969
+ data = J_b_np.reshape(-1)
2970
+ if sparse:
2971
+ batch_rows.append(rows)
2972
+ batch_cols.append(cols)
2973
+ batch_data.append(data)
2974
+ else:
2975
+ assert K_dense is not None
2976
+ K_dense[rows, cols] += data
2977
+ for (tri, a, b, c), fa, fb in zip(
2978
+ _iter_supermesh_tris(supermesh_coords, supermesh_conn),
2979
+ source_facets_a,
2980
+ source_facets_b,
2981
+ ):
2982
+ area = _tri_area(a, b, c)
2983
+ if area <= tol:
2984
+ continue
2985
+ if skip_small_tri and facet_area_a is not None and facet_area_b is not None:
2986
+ area_ref = max(float(facet_area_a[int(fa)]), float(facet_area_b[int(fb)]))
2987
+ if area_ref > 0.0 and area < area_scale * area_ref:
2988
+ continue
2989
+ detJ = 2.0 * area
2990
+ if diag_force and diag_abs_detj:
2991
+ detJ = abs(detJ)
2992
+ if quad_order <= 0:
2993
+ quad_pts = np.array([[1.0 / 3.0, 1.0 / 3.0]], dtype=float)
2994
+ quad_w = np.array([0.5], dtype=float)
2995
+ else:
2996
+ quad_pts, quad_w = _tri_quadrature(quad_order)
2997
+
2998
+ facet_a = facets_a[int(fa)]
2999
+ facet_b = facets_b[int(fb)]
3000
+ x_q = np.array([a + r * (b - a) + s * (c - a) for r, s in quad_pts], dtype=float)
3001
+
3002
+ elem_id_a = int(facet_to_elem_a[int(fa)])
3003
+ elem_nodes_a = np.asarray(elem_conn_a[elem_id_a], dtype=int)
3004
+ elem_coords_a = coords_a[elem_nodes_a]
3005
+ elem_id_b = int(facet_to_elem_b[int(fb)])
3006
+ elem_nodes_b = np.asarray(elem_conn_b[elem_id_b], dtype=int)
3007
+ elem_coords_b = coords_b[elem_nodes_b]
3008
+
3009
+ Na = _volume_shape_values_at_points(x_q, elem_coords_a, tol=tol)
3010
+ Nb = _volume_shape_values_at_points(x_q, elem_coords_b, tol=tol)
3011
+ gradNa = _tet_gradN_at_points(x_q, elem_coords_a, tol=tol)
3012
+ gradNb = _tet_gradN_at_points(x_q, elem_coords_b, tol=tol)
3013
+
3014
+ na = normals_a[int(fa)] if normals_a is not None else None
3015
+ nb = normals_b[int(fb)] if normals_b is not None else None
3016
+ if normal_source == "a":
3017
+ normal = na
3018
+ elif normal_source == "b":
3019
+ normal = nb
3020
+ else:
3021
+ if na is not None and nb is not None:
3022
+ avg = na + nb
3023
+ norm = np.linalg.norm(avg)
3024
+ normal = avg / norm if norm > tol else na
3025
+ else:
3026
+ normal = na if na is not None else nb
3027
+ if normal is not None:
3028
+ normal = normal_sign * normal
3029
+ if normal is None:
3030
+ batch_failed = True
3031
+ break
3032
+
3033
+ u_elem = {
3034
+ field_a: _gather_u_local(u_a, elem_nodes_a, value_dim_a),
3035
+ field_b: _gather_u_local(u_b, elem_nodes_b, value_dim_b),
3036
+ }
3037
+ u_local = np.concatenate([u_elem[field_a], u_elem[field_b]], axis=0)
3038
+
3039
+ dofs_a = _global_dof_indices(elem_nodes_a, value_dim_a, int(offset_a))
3040
+ dofs_b = _global_dof_indices(elem_nodes_b, value_dim_b, int(offset_b))
3041
+ dofs = np.concatenate([dofs_a, dofs_b], axis=0)
3042
+
3043
+ batch_items.append((Na, Nb, gradNa, gradNb, x_q, quad_w, detJ, normal))
3044
+ dofs_batch.append(dofs)
3045
+ u_local_batch.append(u_local)
3046
+
3047
+ if n_q is None:
3048
+ n_q = Na.shape[0]
3049
+ n_nodes_a = Na.shape[1]
3050
+ n_nodes_b = Nb.shape[1]
3051
+ n_a_local_const = dofs_a.shape[0]
3052
+ n_b_local_const = dofs_b.shape[0]
3053
+ else:
3054
+ shape_mismatch = (
3055
+ Na.shape[0] != n_q
3056
+ or Nb.shape[0] != n_q
3057
+ or Na.shape[1] != n_nodes_a
3058
+ or Nb.shape[1] != n_nodes_b
3059
+ or dofs_a.shape[0] != n_a_local_const
3060
+ or dofs_b.shape[0] != n_b_local_const
3061
+ )
3062
+ if shape_mismatch:
3063
+ if batch_items:
3064
+ Na_b, Nb_b, gradNa_b, gradNb_b, x_q_b, w_b, detJ_b, normal_b = zip(*batch_items)
3065
+ Na_b = jnp.asarray(np.stack(Na_b, axis=0))
3066
+ Nb_b = jnp.asarray(np.stack(Nb_b, axis=0))
3067
+ gradNa_b = jnp.asarray(np.stack(gradNa_b, axis=0))
3068
+ gradNb_b = jnp.asarray(np.stack(gradNb_b, axis=0))
3069
+ x_q_b = jnp.asarray(np.stack(x_q_b, axis=0))
3070
+ w_b = jnp.asarray(np.stack(w_b, axis=0))
3071
+ detJ_b = jnp.asarray(np.array(detJ_b, dtype=float)).reshape(-1, 1)
3072
+ normal_b = jnp.asarray(np.stack(normal_b, axis=0))
3073
+ u_local_b = jnp.asarray(np.stack(u_local_batch, axis=0))
3074
+ dofs_batch_np = np.asarray(dofs_batch, dtype=int)
3075
+ _emit_batch(
3076
+ Na_b,
3077
+ Nb_b,
3078
+ gradNa_b,
3079
+ gradNb_b,
3080
+ x_q_b,
3081
+ w_b,
3082
+ detJ_b,
3083
+ normal_b,
3084
+ u_local_b,
3085
+ dofs_batch_np,
3086
+ int(n_a_local_const),
3087
+ int(n_b_local_const),
3088
+ int(Na_b.shape[0]),
3089
+ )
3090
+ batch_items = [(Na, Nb, gradNa, gradNb, x_q, quad_w, detJ, normal)]
3091
+ dofs_batch = [dofs]
3092
+ u_local_batch = [u_local]
3093
+ n_q = Na.shape[0]
3094
+ n_nodes_a = Na.shape[1]
3095
+ n_nodes_b = Nb.shape[1]
3096
+ n_a_local_const = dofs_a.shape[0]
3097
+ n_b_local_const = dofs_b.shape[0]
3098
+
3099
+ if batch_size and len(batch_items) >= batch_size:
3100
+ Na_b, Nb_b, gradNa_b, gradNb_b, x_q_b, w_b, detJ_b, normal_b = zip(*batch_items)
3101
+ Na_b = jnp.asarray(np.stack(Na_b, axis=0))
3102
+ Nb_b = jnp.asarray(np.stack(Nb_b, axis=0))
3103
+ gradNa_b = jnp.asarray(np.stack(gradNa_b, axis=0))
3104
+ gradNb_b = jnp.asarray(np.stack(gradNb_b, axis=0))
3105
+ x_q_b = jnp.asarray(np.stack(x_q_b, axis=0))
3106
+ w_b = jnp.asarray(np.stack(w_b, axis=0))
3107
+ detJ_b = jnp.asarray(np.array(detJ_b, dtype=float)).reshape(-1, 1)
3108
+ normal_b = jnp.asarray(np.stack(normal_b, axis=0))
3109
+ u_local_b = jnp.asarray(np.stack(u_local_batch, axis=0))
3110
+ dofs_batch_np = np.asarray(dofs_batch, dtype=int)
3111
+ _emit_batch(
3112
+ Na_b,
3113
+ Nb_b,
3114
+ gradNa_b,
3115
+ gradNb_b,
3116
+ x_q_b,
3117
+ w_b,
3118
+ detJ_b,
3119
+ normal_b,
3120
+ u_local_b,
3121
+ dofs_batch_np,
3122
+ int(n_a_local_const),
3123
+ int(n_b_local_const),
3124
+ int(Na_b.shape[0]),
3125
+ )
3126
+ batch_items = []
3127
+ dofs_batch = []
3128
+ u_local_batch = []
3129
+
3130
+ if not batch_failed and batch_items:
3131
+ Na_b, Nb_b, gradNa_b, gradNb_b, x_q_b, w_b, detJ_b, normal_b = zip(*batch_items)
3132
+ Na_b = jnp.asarray(np.stack(Na_b, axis=0))
3133
+ Nb_b = jnp.asarray(np.stack(Nb_b, axis=0))
3134
+ gradNa_b = jnp.asarray(np.stack(gradNa_b, axis=0))
3135
+ gradNb_b = jnp.asarray(np.stack(gradNb_b, axis=0))
3136
+ x_q_b = jnp.asarray(np.stack(x_q_b, axis=0))
3137
+ w_b = jnp.asarray(np.stack(w_b, axis=0))
3138
+ detJ_b = jnp.asarray(np.array(detJ_b, dtype=float)).reshape(-1, 1)
3139
+ normal_b = jnp.asarray(np.stack(normal_b, axis=0))
3140
+ u_local_b = jnp.asarray(np.stack(u_local_batch, axis=0))
3141
+ dofs_batch_np = np.asarray(dofs_batch, dtype=int)
3142
+ _emit_batch(
3143
+ Na_b,
3144
+ Nb_b,
3145
+ gradNa_b,
3146
+ gradNb_b,
3147
+ x_q_b,
3148
+ w_b,
3149
+ detJ_b,
3150
+ normal_b,
3151
+ u_local_b,
3152
+ dofs_batch_np,
3153
+ int(n_a_local_const),
3154
+ int(n_b_local_const),
3155
+ int(Na_b.shape[0]),
3156
+ )
3157
+
3158
+ if not batch_failed and (batch_rows or (not sparse and K_dense is not None)):
3159
+ if sparse:
3160
+ if batch_rows:
3161
+ rows = np.concatenate(batch_rows)
3162
+ cols = np.concatenate(batch_cols)
3163
+ data = np.concatenate(batch_data)
3164
+ else:
3165
+ rows = np.zeros((0,), dtype=int)
3166
+ cols = np.zeros((0,), dtype=int)
3167
+ data = np.zeros((0,), dtype=float)
3168
+ return rows, cols, data, n_total
3169
+ assert K_dense is not None
3170
+ return K_dense
3171
+
3172
+ if trace:
3173
+ _trace("[CONTACT] batch_jac_fallback")
3174
+
3175
+ if trace:
3176
+ _trace("[CONTACT] supermesh_loop_enter")
3177
+ _mortar_dbg("[mortar] step: supermesh loop START")
3178
+ t_loop = time.perf_counter()
3179
+ for it, ((tri, a, b, c), fa, fb) in enumerate(
3180
+ zip(
3181
+ _iter_supermesh_tris(supermesh_coords, supermesh_conn),
3182
+ source_facets_a,
3183
+ source_facets_b,
3184
+ )
3185
+ ):
3186
+ log_tri = trace and (it < trace_max or it % trace_every == 0)
3187
+ t_tri0 = time.perf_counter()
3188
+ def _tri_check(stage: str) -> None:
3189
+ if tri_timeout > 0.0 and (time.perf_counter() - t_tri0) > tri_timeout:
3190
+ raise RuntimeError(f"[CONTACT] tri {it} timeout at {stage}")
3191
+ if log_tri:
3192
+ _trace(f"[CONTACT] tri {it} start fa={int(fa)} fb={int(fb)}")
3193
+ t_geom = time.perf_counter()
3194
+ area = _tri_area(a, b, c)
3195
+ if area <= tol:
3196
+ continue
3197
+ if skip_small_tri and facet_area_a is not None and facet_area_b is not None:
3198
+ area_ref = max(float(facet_area_a[int(fa)]), float(facet_area_b[int(fb)]))
3199
+ if area_ref > 0.0 and area < area_scale * area_ref:
3200
+ continue
3201
+ detJ = 2.0 * area
3202
+ if diag_force and diag_abs_detj:
3203
+ detJ = abs(detJ)
3204
+ if guard:
3205
+ if not np.isfinite(detJ):
3206
+ if log_tri:
3207
+ _trace(f"[CONTACT] tri {it} detJ non-finite; skip")
3208
+ if skip_nonfinite:
3209
+ continue
3210
+ raise RuntimeError(f"[CONTACT] tri {it} detJ non-finite")
3211
+ if detj_eps > 0.0 and abs(detJ) < detj_eps:
3212
+ if log_tri:
3213
+ _trace(f"[CONTACT] tri {it} detJ too small {detJ:.3e}; skip")
3214
+ continue
3215
+ if quad_order <= 0:
3216
+ quad_pts = np.array([[1.0 / 3.0, 1.0 / 3.0]], dtype=float)
3217
+ quad_w = np.array([0.5], dtype=float)
3218
+ else:
3219
+ quad_pts, quad_w = _tri_quadrature(quad_order)
3220
+ quad_source = "fluxfem"
3221
+ quad_override = _diag_quad_override(diag_force, diag_qp_mode, diag_qp_path)
3222
+ if quad_override is not None:
3223
+ quad_pts, quad_w = quad_override
3224
+ quad_source = _DEBUG_PROJ_QP_SOURCE or "override"
3225
+ _diag_quad_dump(diag_force, diag_qp_mode, diag_qp_path, quad_pts, quad_w)
3226
+
3227
+ facet_a = facets_a[int(fa)]
3228
+ facet_b = facets_b[int(fb)]
3229
+ x_q = np.array([a + r * (b - a) + s * (c - a) for r, s in quad_pts], dtype=float)
3230
+ if guard and not np.isfinite(x_q).all():
3231
+ if log_tri:
3232
+ _trace(f"[CONTACT] tri {it} x_q non-finite; skip")
3233
+ if skip_nonfinite:
3234
+ continue
3235
+ raise RuntimeError(f"[CONTACT] tri {it} x_q non-finite")
3236
+ if log_tri:
3237
+ _trace_time(f"[CONTACT] tri {it} geom_done", t_geom)
3238
+ _tri_check("geom_done")
3239
+
3240
+ gradNa = None
3241
+ gradNb = None
3242
+ nodes_a = facet_a
3243
+ nodes_b = facet_b
3244
+
3245
+ Na = None
3246
+ Nb = None
3247
+
3248
+ elem_id_a = -1
3249
+ elem_nodes_a = None
3250
+ elem_coords_a = None
3251
+ local_a = None
3252
+ if use_elem_a:
3253
+ elem_id_a = int(facet_to_elem_a[int(fa)])
3254
+ if elem_id_a < 0:
3255
+ raise ValueError("facet_to_elem_a has invalid mapping")
3256
+ elem_nodes_a = np.asarray(elem_conn_a[elem_id_a], dtype=int)
3257
+ elem_coords_a = coords_a[elem_nodes_a]
3258
+ if elem_coords_a.shape[0] not in {4, 8, 10, 20, 27}:
3259
+ raise NotImplementedError("surface sym_grad is implemented for tet4/tet10/hex8/hex20/hex27 only")
3260
+
3261
+ elem_id_b = -1
3262
+ elem_nodes_b = None
3263
+ elem_coords_b = None
3264
+ local_b = None
3265
+ if use_elem_b:
3266
+ elem_id_b = int(facet_to_elem_b[int(fb)])
3267
+ if elem_id_b < 0:
3268
+ raise ValueError("facet_to_elem_b has invalid mapping")
3269
+ elem_nodes_b = np.asarray(elem_conn_b[elem_id_b], dtype=int)
3270
+ elem_coords_b = coords_b[elem_nodes_b]
3271
+ if elem_coords_b.shape[0] not in {4, 8, 10, 20, 27}:
3272
+ raise NotImplementedError("surface sym_grad is implemented for tet4/tet10/hex8/hex20/hex27 only")
3273
+ if proj_diag:
3274
+ _proj_diag_set_context(
3275
+ fa=int(fa),
3276
+ fb=int(fb),
3277
+ face_a=_facet_label(facet_a),
3278
+ face_b=_facet_label(facet_b),
3279
+ elem_a=elem_id_a,
3280
+ elem_b=elem_id_b,
3281
+ )
3282
+
3283
+ t_basis = time.perf_counter()
3284
+ if grad_source == "surface":
3285
+ gradNa = np.array(
3286
+ [_surface_gradN(pt, facet_a, coords_a, tol=tol) for pt in x_q],
3287
+ dtype=float,
3288
+ )
3289
+ gradNb = np.array(
3290
+ [_surface_gradN(pt, facet_b, coords_b, tol=tol) for pt in x_q],
3291
+ dtype=float,
3292
+ )
3293
+ if use_elem_a and grad_source == "volume":
3294
+ local_a = _local_indices(elem_nodes_a, facet_a)
3295
+ gradNa = _tet_gradN_at_points(x_q, elem_coords_a, local=local_a, tol=tol)
3296
+
3297
+ if use_elem_b and grad_source == "volume":
3298
+ local_b = _local_indices(elem_nodes_b, facet_b)
3299
+ gradNb = _tet_gradN_at_points(x_q, elem_coords_b, local=local_b, tol=tol)
3300
+
3301
+ if dof_source == "volume":
3302
+ if not use_elem_a or elem_nodes_a is None or elem_coords_a is None:
3303
+ raise ValueError("dof_source 'volume' requires elem_conn_a and facet_to_elem_a")
3304
+ if not use_elem_b or elem_nodes_b is None or elem_coords_b is None:
3305
+ raise ValueError("dof_source 'volume' requires elem_conn_b and facet_to_elem_b")
3306
+ nodes_a = elem_nodes_a
3307
+ nodes_b = elem_nodes_b
3308
+ Na = _volume_shape_values_at_points(x_q, elem_coords_a, tol=tol)
3309
+ Nb = _volume_shape_values_at_points(x_q, elem_coords_b, tol=tol)
3310
+ if grad_source == "volume":
3311
+ gradNa = _tet_gradN_at_points(x_q, elem_coords_a, tol=tol)
3312
+ gradNb = _tet_gradN_at_points(x_q, elem_coords_b, tol=tol)
3313
+ else:
3314
+ Na = np.array([_facet_shape_values(pt, facet_a, coords_a, tol=tol) for pt in x_q], dtype=float)
3315
+ Nb = np.array([_facet_shape_values(pt, facet_b, coords_b, tol=tol) for pt in x_q], dtype=float)
3316
+ if guard and (not np.isfinite(Na).all() or not np.isfinite(Nb).all()):
3317
+ if log_tri:
3318
+ _trace(f"[CONTACT] tri {it} N non-finite; skip")
3319
+ if skip_nonfinite:
3320
+ continue
3321
+ raise RuntimeError(f"[CONTACT] tri {it} N non-finite")
3322
+ if log_tri:
3323
+ _trace_time(f"[CONTACT] tri {it} basis_done", t_basis)
3324
+ _tri_check("basis_done")
3325
+
3326
+ global _DEBUG_CONTACT_MAP_ONCE
3327
+ if diag_map and not _DEBUG_CONTACT_MAP_ONCE:
3328
+ elem_id_a = int(facet_to_elem_a[int(fa)]) if use_elem_a else -1
3329
+ elem_id_b = int(facet_to_elem_b[int(fb)]) if use_elem_b else -1
3330
+ print("[fluxfem][diag][contact-map] first facet")
3331
+ print(f" fa={int(fa)} fb={int(fb)} elem_a={elem_id_a} elem_b={elem_id_b}")
3332
+ print(f" facet_nodes_a={facet_a.tolist()}")
3333
+ print(f" facet_nodes_b={facet_b.tolist()}")
3334
+ print(f" facet_coords_a={coords_a[facet_a].tolist()}")
3335
+ print(f" facet_coords_b={coords_b[facet_b].tolist()}")
3336
+ if elem_nodes_a is not None:
3337
+ if local_a is None:
3338
+ local_a = _local_indices(elem_nodes_a, facet_a)
3339
+ match_a = np.all(elem_nodes_a[local_a] == facet_a)
3340
+ print(f" elem_nodes_a={elem_nodes_a.tolist()}")
3341
+ print(f" local_indices_a={local_a.tolist()} match={bool(match_a)}")
3342
+ if elem_nodes_b is not None:
3343
+ if local_b is None:
3344
+ local_b = _local_indices(elem_nodes_b, facet_b)
3345
+ match_b = np.all(elem_nodes_b[local_b] == facet_b)
3346
+ print(f" elem_nodes_b={elem_nodes_b.tolist()}")
3347
+ print(f" local_indices_b={local_b.tolist()} match={bool(match_b)}")
3348
+ _DEBUG_CONTACT_MAP_ONCE = True
3349
+
3350
+ global _DEBUG_CONTACT_N_ONCE
3351
+ if diag_n and not _DEBUG_CONTACT_N_ONCE:
3352
+ dofs_a = _global_dof_indices(nodes_a, value_dim_a, int(offset_a))
3353
+ dofs_b = _global_dof_indices(nodes_b, value_dim_b, int(offset_b))
3354
+ samples = min(3, Na.shape[0])
3355
+ print("[fluxfem][diag][contact-n] first facet q-points")
3356
+ print(f" nodes_a={nodes_a.tolist()} nodes_b={nodes_b.tolist()}")
3357
+ print(f" dofs_a={dofs_a.tolist()} dofs_b={dofs_b.tolist()}")
3358
+ for qi in range(samples):
3359
+ print(f" q{qi} x={x_q[qi].tolist()} Na={Na[qi].tolist()} Nb={Nb[qi].tolist()}")
3360
+ _DEBUG_CONTACT_N_ONCE = True
3361
+
3362
+ normal = None
3363
+ na = normals_a[int(fa)] if normals_a is not None else None
3364
+ nb = normals_b[int(fb)] if normals_b is not None else None
3365
+ if normal_source == "a":
3366
+ normal = na
3367
+ elif normal_source == "b":
3368
+ normal = nb
3369
+ else:
3370
+ if na is not None and nb is not None:
3371
+ avg = na + nb
3372
+ norm = np.linalg.norm(avg)
3373
+ normal = avg / norm if norm > tol else na
3374
+ else:
3375
+ normal = na if na is not None else nb
3376
+ if normal is not None:
3377
+ normal = normal_sign * normal
3378
+
3379
+ field_a_obj = SurfaceMixedFormField(
3380
+ N=Na,
3381
+ gradN=gradNa,
3382
+ value_dim=value_dim_a,
3383
+ basis=_SurfaceBasis(dofs_per_node=value_dim_a),
3384
+ )
3385
+ field_b_obj = SurfaceMixedFormField(
3386
+ N=Nb,
3387
+ gradN=gradNb,
3388
+ value_dim=value_dim_b,
3389
+ basis=_SurfaceBasis(dofs_per_node=value_dim_b),
3390
+ )
3391
+ fields = {
3392
+ field_a: FieldPair(test=field_a_obj, trial=field_a_obj),
3393
+ field_b: FieldPair(test=field_b_obj, trial=field_b_obj),
3394
+ }
3395
+ normal_q = None if normal is None else np.repeat(normal[None, :], quad_pts.shape[0], axis=0)
3396
+ ctx = SurfaceMixedFormContext(
3397
+ fields=fields,
3398
+ x_q=x_q,
3399
+ w=quad_w,
3400
+ detJ=np.array([detJ], dtype=float),
3401
+ normal=normal_q,
3402
+ trial_fields={field_a: field_a_obj, field_b: field_b_obj},
3403
+ test_fields={field_a: field_a_obj, field_b: field_b_obj},
3404
+ unknown_fields={field_a: field_a_obj, field_b: field_b_obj},
3405
+ )
3406
+
3407
+ u_elem = {
3408
+ field_a: _gather_u_local(u_a, nodes_a, value_dim_a),
3409
+ field_b: _gather_u_local(u_b, nodes_b, value_dim_b),
3410
+ }
3411
+ u_local = np.concatenate([u_elem[field_a], u_elem[field_b]], axis=0)
3412
+ sizes = (u_elem[field_a].shape[0], u_elem[field_b].shape[0])
3413
+ slices = {
3414
+ field_a: slice(0, sizes[0]),
3415
+ field_b: slice(sizes[0], sizes[0] + sizes[1]),
3416
+ }
3417
+
3418
+ def _res_local(u_vec):
3419
+ u_dict = {name: u_vec[slices[name]] for name in (field_a, field_b)}
3420
+ fe_q = res_form(ctx, u_dict, params)
3421
+ res_parts = []
3422
+ for name in (field_a, field_b):
3423
+ fe_field = fe_q[name]
3424
+ if includes_measure.get(name, False):
3425
+ fe = jnp.sum(jnp.asarray(fe_field), axis=0)
3426
+ else:
3427
+ wJ = jnp.asarray(ctx.w) * jnp.asarray(ctx.detJ)
3428
+ fe = jnp.einsum("qi,q->i", jnp.asarray(fe_field), wJ)
3429
+ res_parts.append(fe)
3430
+ return jnp.concatenate(res_parts, axis=0)
3431
+
3432
+ def _res_local_np(u_vec):
3433
+ u_dict = {name: u_vec[slices[name]] for name in (field_a, field_b)}
3434
+ fe_q = res_form(ctx, u_dict, params)
3435
+ res_parts = []
3436
+ for name in (field_a, field_b):
3437
+ fe_field = fe_q[name]
3438
+ if includes_measure.get(name, False):
3439
+ fe = np.sum(np.asarray(fe_field), axis=0)
3440
+ else:
3441
+ wJ = np.asarray(ctx.w) * np.asarray(ctx.detJ)
3442
+ fe = np.einsum("qi...,q->i...", np.asarray(fe_field), wJ)
3443
+ res_parts.append(np.asarray(fe))
3444
+ return np.concatenate(res_parts, axis=0)
3445
+
3446
+ t_jac = time.perf_counter()
3447
+ if backend == "jax":
3448
+ J_local = jax.jacrev(_res_local)(jnp.asarray(u_local))
3449
+ J_local_np = np.asarray(J_local)
3450
+ else:
3451
+ n_ldofs = int(u_local.shape[0])
3452
+ J_local_np = np.zeros((n_ldofs, n_ldofs), dtype=float)
3453
+ u_base = np.asarray(u_local, dtype=float)
3454
+ if log_tri:
3455
+ _trace(f"[CONTACT] tri {it} fd_start n_ldofs={n_ldofs} fd_mode={fd_mode}")
3456
+ r0 = _res_local_np(u_base) if fd_mode == "forward" else None
3457
+ if log_tri and fd_mode == "forward":
3458
+ _trace(f"[CONTACT] tri {it} fd_r0_done")
3459
+ block = max(1, int(fd_block_size))
3460
+ if block <= 1:
3461
+ for i in range(n_ldofs):
3462
+ log_fd = log_tri and i < trace_fd_max
3463
+ if log_fd:
3464
+ _trace(f"[CONTACT] tri {it} fd_col {i} start")
3465
+ u_p = u_base.copy()
3466
+ u_p[i] += fd_eps
3467
+ t_rp = time.perf_counter()
3468
+ r_p = _res_local_np(u_p)
3469
+ if log_fd:
3470
+ _trace_time(f"[CONTACT] tri {it} fd_col {i} r_p", t_rp)
3471
+ if fd_mode == "central":
3472
+ u_m = u_base.copy()
3473
+ u_m[i] -= fd_eps
3474
+ t_rm = time.perf_counter()
3475
+ r_m = _res_local_np(u_m)
3476
+ if log_fd:
3477
+ _trace_time(f"[CONTACT] tri {it} fd_col {i} r_m", t_rm)
3478
+ col = (r_p - r_m) / (2.0 * fd_eps)
3479
+ else:
3480
+ col = (r_p - r0) / fd_eps
3481
+ J_local_np[:, i] = np.asarray(col, dtype=float)
3482
+ if log_fd:
3483
+ _trace(f"[CONTACT] tri {it} fd_col {i} done")
3484
+ else:
3485
+ for i0 in range(0, n_ldofs, block):
3486
+ idxs = np.arange(i0, min(i0 + block, n_ldofs))
3487
+ u_block = np.repeat(u_base[:, None], idxs.size, axis=1)
3488
+ for bi, idx in enumerate(idxs):
3489
+ u_block[idx, bi] += fd_eps
3490
+ t_rp = time.perf_counter()
3491
+ r_p = _res_local_np(u_block)
3492
+ if log_tri and i0 < trace_fd_max:
3493
+ _trace_time(f"[CONTACT] tri {it} fd_block r_p", t_rp)
3494
+ if fd_mode == "central":
3495
+ u_block_m = np.repeat(u_base[:, None], idxs.size, axis=1)
3496
+ for bi, idx in enumerate(idxs):
3497
+ u_block_m[idx, bi] -= fd_eps
3498
+ t_rm = time.perf_counter()
3499
+ r_m = _res_local_np(u_block_m)
3500
+ if log_tri and i0 < trace_fd_max:
3501
+ _trace_time(f"[CONTACT] tri {it} fd_block r_m", t_rm)
3502
+ cols = (r_p - r_m) / (2.0 * fd_eps)
3503
+ else:
3504
+ cols = (r_p - r0[:, None]) / fd_eps
3505
+ J_local_np[:, idxs] = np.asarray(cols, dtype=float)
3506
+ if log_tri:
3507
+ _trace_time(f"[CONTACT] tri {it} jac_done", t_jac)
3508
+ _tri_check("jac_done")
3509
+
3510
+ dofs_a = _global_dof_indices(nodes_a, value_dim_a, int(offset_a))
3511
+ dofs_b = _global_dof_indices(nodes_b, value_dim_b, int(offset_b))
3512
+ if diag_force:
3513
+ _diag_contact_projection(
3514
+ fa=int(fa),
3515
+ fb=int(fb),
3516
+ quad_pts=quad_pts,
3517
+ quad_w=quad_w,
3518
+ x_q=x_q,
3519
+ Na=Na,
3520
+ Nb=Nb,
3521
+ nodes_a=nodes_a,
3522
+ nodes_b=nodes_b,
3523
+ dofs_a=dofs_a,
3524
+ dofs_b=dofs_b,
3525
+ elem_coords_a=elem_coords_a if dof_source == "volume" else None,
3526
+ elem_coords_b=elem_coords_b if dof_source == "volume" else None,
3527
+ na=na,
3528
+ nb=nb,
3529
+ normal=normal,
3530
+ normal_source=normal_source,
3531
+ normal_sign=normal_sign,
3532
+ detJ=detJ,
3533
+ diag_facet=diag_facet,
3534
+ diag_max_q=diag_max_q,
3535
+ quad_source=quad_source,
3536
+ tol=tol,
3537
+ )
3538
+ t_scatter = time.perf_counter()
3539
+ dofs = np.concatenate([dofs_a, dofs_b], axis=0)
3540
+ if sparse:
3541
+ n_ldofs = int(dofs.shape[0])
3542
+ rows.extend(np.repeat(dofs, n_ldofs).tolist())
3543
+ cols.extend(np.tile(dofs, n_ldofs).tolist())
3544
+ data.extend(J_local_np.reshape(-1).tolist())
3545
+ else:
3546
+ K_dense[np.ix_(dofs, dofs)] += J_local_np
3547
+ if log_tri:
3548
+ _trace_time(f"[CONTACT] tri {it} scatter_done", t_scatter)
3549
+ _tri_check("scatter_done")
3550
+
3551
+
3552
+ if proj_diag:
3553
+ _proj_diag_report()
3554
+ if sparse:
3555
+ return np.asarray(rows, dtype=int), np.asarray(cols, dtype=int), np.asarray(data, dtype=float), n_total
3556
+ assert K_dense is not None
3557
+ return K_dense
3558
+
3559
+
3560
def assemble_onesided_bilinear(
    surface_slave: SurfaceMesh,
    u_hat_fn,
    params: "WeakParams",
    *,
    surface_master: SurfaceMesh | None = None,
    u_master: np.ndarray | None = None,
    value_dim: int = 3,
    elem_conn: np.ndarray | None = None,
    facet_to_elem: np.ndarray | None = None,
    elem_conn_master: np.ndarray | None = None,
    facet_to_elem_master: np.ndarray | None = None,
    grad_source: str = "volume",
    dof_source: str = "volume",
    quad_order: int = 2,
    normal_sign: float = 1.0,
    tol: float = 1e-8,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Assemble one-sided (slave-only) Nitsche matrices without supermesh.

    The master side is treated as prescribed displacement u_hat(x). Provide
    either u_hat_fn(x_q) or u_master with master element mappings to evaluate
    u_hat at slave quadrature points.

    Per slave facet the integrated form is
        -v.t(u) - t(v).u + (alpha/h) v.u + t(v).u_hat - (alpha/h) v.u_hat
    with h = sqrt(facet area) (floored at ``tol``).

    Note: this implementation currently assumes volume-trace bases for both
    gradients and DOFs. Surface-only bases are not supported here yet.

    Parameters
    ----------
    surface_slave : SurfaceMesh
        Slave surface. Its ``coords`` are indexed directly by ``elem_conn``,
        so they must use the volume node numbering.
    u_hat_fn : callable or None
        ``u_hat_fn(x_q)`` returning ``(n_q, value_dim)`` values. It is used
        only when ``u_master`` is None.
    params : WeakParams
        Provides ``lam``, ``mu`` and the Nitsche penalty ``alpha``.
    surface_master, u_master, elem_conn_master, facet_to_elem_master
        Master data used to interpolate u_hat from ``u_master``. The master
        element is looked up with the *slave* facet index, so this path
        assumes slave facet ``i`` corresponds to master facet ``i``.
    value_dim : int
        Components per node.
    elem_conn, facet_to_elem
        Slave volume connectivity and facet -> element map (required).
    grad_source, dof_source : str
        Only ``"volume"`` is supported.
    quad_order : int
        Triangle quadrature order. Values ``<= 0`` use a single centroid point.
    normal_sign : float
        Factor applied to the slave facet normals.
    tol : float
        Facets/triangles with area ``<= tol`` are skipped.

    Returns
    -------
    K : ndarray, shape (n_nodes * value_dim, n_nodes * value_dim)
        Dense tangent. Column j is ``r(e_j) - r(0)``, which is exact because
        the form is affine in u.
    f : ndarray, shape (n_nodes * value_dim,)
        Residual at u = 0, so the assembled residual is ``K @ u + f``.

    Raises
    ------
    ValueError
        Raised when required inputs are missing, when grad_source/dof_source
        is unsupported, or when a facet -> element map has a negative entry.
    """
    from ..core.forms import FieldPair
    from ..core.weakform import (
        Params,
        compile_mixed_surface_residual_numpy,
        param_ref,
        test_ref,
        unknown_ref,
    )
    import fluxfem.helpers_wf as h_wf

    # ---- input validation (done once, before any assembly work) ----
    use_master = u_master is not None
    if use_master:
        if surface_master is None:
            raise ValueError("surface_master is required when u_master is provided")
        if elem_conn_master is None or facet_to_elem_master is None:
            raise ValueError("elem_conn_master and facet_to_elem_master are required when u_master is provided")
    elif u_hat_fn is None:
        raise ValueError("u_hat_fn or u_master must be provided")
    if grad_source != "volume" or dof_source != "volume":
        raise ValueError("one-sided Nitsche currently supports only volume/volume")
    # volume DOFs need the slave element mapping; previously this was only
    # detected at the first non-degenerate triangle.
    if elem_conn is None or facet_to_elem is None:
        raise ValueError("dof_source 'volume' requires elem_conn and facet_to_elem")

    coords_s = np.asarray(surface_slave.coords, dtype=float)
    facets_s = np.asarray(surface_slave.conn, dtype=int)
    coords_m = np.asarray(surface_master.coords, dtype=float) if use_master else None
    n_s = int(coords_s.shape[0] * value_dim)
    K = np.zeros((n_s, n_s), dtype=float)
    f = np.zeros((n_s,), dtype=float)
    normals_s = surface_slave.facet_normals() if hasattr(surface_slave, "facet_normals") else None

    # ---- symbolic weak form, compiled once ----
    u = unknown_ref("u")
    v = test_ref("u")
    p = param_ref()
    n = h_wf.normal()
    t_u = h_wf.traction(u, n, p)
    t_v = h_wf.traction(v, n, p)
    sym_term = h_wf.einsum("qia,qi->qa", t_v, u.val)
    sym_term_hat = h_wf.einsum("qia,qi->qa", t_v, p.u_hat)
    expr = (
        -h_wf.dot(v, t_u)
        - sym_term
        + (p.alpha * p.inv_h) * h_wf.dot(v, u.val)
        + sym_term_hat
        - (p.alpha * p.inv_h) * h_wf.dot(v, p.u_hat)
    ) * h_wf.ds()
    res_form = compile_mixed_surface_residual_numpy({"u": expr})
    includes_measure = res_form._includes_measure

    def _integrated_residual(ctx, params_local, u_vec: np.ndarray) -> np.ndarray:
        # Evaluate the compiled form for one local coefficient vector and
        # integrate over the triangle unless the form already applied ds.
        fe_q = np.asarray(res_form(ctx, {"u": u_vec}, params_local)["u"])
        if includes_measure.get("u", False):
            return np.sum(fe_q, axis=0)
        wJ = np.asarray(ctx.w) * np.asarray(ctx.detJ)
        return np.einsum("qi,q->i", fe_q, wJ)

    if quad_order > 0:
        quad_pts, quad_w = _tri_quadrature(quad_order)
    else:
        quad_pts, quad_w = np.array([[1.0 / 3.0, 1.0 / 3.0]]), np.array([0.5])
    n_q = quad_pts.shape[0]

    for f_id, facet in enumerate(facets_s):
        triangles = _facet_triangles(coords_s, facet)
        if not triangles:
            continue
        area_f = _facet_area_estimate(facet, coords_s)
        if area_f <= tol:
            continue
        inv_h = 1.0 / max(np.sqrt(area_f), tol)

        elem_id = int(facet_to_elem[int(f_id)])
        if elem_id < 0:
            raise ValueError("facet_to_elem has invalid mapping")
        elem_nodes = np.asarray(elem_conn[elem_id], dtype=int)
        elem_coords = coords_s[elem_nodes]

        # Per-facet invariants: local DOF layout, unit basis vectors, normal.
        n_ldofs = int(len(elem_nodes) * value_dim)
        u_zero = np.zeros((n_ldofs,), dtype=float)
        unit_vectors = np.eye(n_ldofs, dtype=float)
        dofs = _global_dof_indices(elem_nodes, value_dim, 0)
        normal_q = None
        if normals_s is not None:
            normal = normal_sign * normals_s[int(f_id)]
            normal_q = np.repeat(normal[None, :], n_q, axis=0)

        for a, b, c in triangles:
            area = _tri_area(a, b, c)
            if area <= tol:
                continue
            detJ = 2.0 * area
            x_q = np.array([a + r * (b - a) + s * (c - a) for r, s in quad_pts], dtype=float)

            # Prescribed master displacement at the slave quadrature points.
            if use_master:
                elem_id_m = int(facet_to_elem_master[int(f_id)])
                if elem_id_m < 0:
                    raise ValueError("facet_to_elem_master has invalid mapping")
                elem_nodes_m = np.asarray(elem_conn_master[elem_id_m], dtype=int)
                u_master_local = _gather_u_local(u_master, elem_nodes_m, value_dim).reshape(-1, value_dim)
                N_master = _volume_shape_values_at_points(x_q, coords_m[elem_nodes_m], tol=tol)
                u_hat = N_master @ u_master_local
            else:
                u_hat = np.asarray(u_hat_fn(x_q), dtype=float)
                if u_hat.shape[0] != x_q.shape[0]:
                    raise ValueError("u_hat_fn must return shape (n_q, value_dim)")

            # Volume-trace basis: values and full-element gradients at x_q.
            N = _volume_shape_values_at_points(x_q, elem_coords, tol=tol)
            gradN = _tet_gradN_at_points(x_q, elem_coords, tol=tol)

            field = SurfaceMixedFormField(
                N=N,
                gradN=gradN,
                value_dim=value_dim,
                basis=_SurfaceBasis(dofs_per_node=value_dim),
            )
            ctx = SurfaceMixedFormContext(
                fields={"u": FieldPair(test=field, trial=field)},
                x_q=x_q,
                w=quad_w,
                detJ=np.array([detJ], dtype=float),
                normal=normal_q,
                trial_fields={"u": field},
                test_fields={"u": field},
                unknown_fields={"u": field},
            )
            params_local = Params(
                lam=params.lam,
                mu=params.mu,
                alpha=params.alpha,
                inv_h=inv_h,
                u_hat=u_hat,
            )

            # The form is affine in u, so r(0) is the load vector and
            # r(e_j) - r(0) is exactly column j of the tangent.
            f_local = _integrated_residual(ctx, params_local, u_zero)
            k_local = np.empty((n_ldofs, n_ldofs), dtype=float)
            for j in range(n_ldofs):
                k_local[:, j] = _integrated_residual(ctx, params_local, unit_vectors[:, j]) - f_local

            f[dofs] += f_local
            K[np.ix_(dofs, dofs)] += k_local

    return K, f
3780
+
3781
+
3782
def assemble_contact_onesided_floor(
    surface_slave: SurfaceMesh,
    u: np.ndarray,
    *,
    n: np.ndarray | None = None,
    c: float,
    k: float,
    beta: float,
    value_dim: int = 3,
    elem_conn: np.ndarray | None = None,
    facet_to_elem: np.ndarray | None = None,
    quad_order: int = 2,
    normal_sign: float = 1.0,
    tol: float = 1e-8,
    return_metrics: bool = False,
) -> tuple[np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray, dict[str, float]]:
    """
    Assemble one-sided contact penalty against a rigid plane g = n·x - c.

    Uses softplus for a smooth contact pressure:
        p(g) = k * softplus(-g; beta)
    with softplus(z; beta) = (1 / beta) * log(1 + exp(beta z)).

    The gap is evaluated at the current position x = X + u of each slave
    quadrature point, with u interpolated by the volume shape functions of
    the element owning the facet. The nodal force is the integral of
    N * p(g) * n over the facet.

    The Jacobian is computed in closed form (float64):
        dp/dg = -k * sigmoid(-beta g)
        K_e = integral of dp/dg * (N ⊗ n)(N ⊗ n)^T
    The normal is held fixed: it is either ``n`` or the reference facet
    normal, and is not differentiated. Each local tangent is therefore
    symmetric and negative semi-definite in this sign convention. Check how
    it is combined with the structural tangent before choosing a solver.

    Parameters
    ----------
    surface_slave : SurfaceMesh
        Slave surface. Its ``coords`` are indexed by ``elem_conn``.
    u : ndarray
        Global displacement vector, gathered per element.
    n : ndarray of shape (3,), optional
        Plane normal. It is normalised and multiplied by ``normal_sign``.
        When omitted, the per-facet normals are used (times ``normal_sign``).
    c : float
        Plane offset in g = n·x - c.
    k, beta : float
        Penalty stiffness and softplus sharpness (beta > 0).
    value_dim : int
        Components per node. It must match the 3 normal components.
    elem_conn, facet_to_elem
        Volume connectivity and facet -> element map (required).
    quad_order : int
        Triangle quadrature order. Values ``<= 0`` use a single centroid point.
    tol : float
        Facets/triangles with area ``<= tol`` are skipped.
    return_metrics : bool
        Also return ``{"penetration": ∫ softplus(-g; beta) dA, "min_g": min g}``.
        ``min_g`` is 0.0 when no quadrature point was visited.

    Returns
    -------
    (K, f) or (K, f, metrics)
        Dense tangent and force vector of size ``n_nodes * value_dim``.

    Raises
    ------
    ValueError
        Raised when the mappings are missing, when ``beta <= 0``, when ``n``
        is not a non-zero 3-vector, or when no normals are available.
    """
    if elem_conn is None or facet_to_elem is None:
        raise ValueError("elem_conn and facet_to_elem are required")
    if beta <= 0.0:
        raise ValueError("beta must be positive")

    coords_s = np.asarray(surface_slave.coords, dtype=float)
    facets_s = np.asarray(surface_slave.conn, dtype=int)
    n_s = int(coords_s.shape[0] * value_dim)
    K = np.zeros((n_s, n_s), dtype=float)
    f = np.zeros((n_s,), dtype=float)

    normals_s = surface_slave.facet_normals() if hasattr(surface_slave, "facet_normals") else None
    if n is not None:
        n = np.asarray(n, dtype=float).reshape(-1)
        if n.shape[0] != 3:
            raise ValueError("n must be a 3-vector")
        n_norm = np.linalg.norm(n)
        if n_norm <= tol:
            raise ValueError("n must be non-zero")
        n = (n / n_norm) * float(normal_sign)
    elif normals_s is None:
        raise ValueError("surface normals are required when n is not provided")

    c_val = float(c)
    k_val = float(k)
    beta_val = float(beta)

    penetration = 0.0
    min_g = float("inf")
    if quad_order > 0:
        quad_pts, quad_w = _tri_quadrature(quad_order)
    else:
        quad_pts, quad_w = np.array([[1.0 / 3.0, 1.0 / 3.0]]), np.array([0.5])
    quad_w = np.asarray(quad_w, dtype=float)
    n_q = quad_pts.shape[0]

    for f_id, facet in enumerate(facets_s):
        triangles = _facet_triangles(coords_s, facet)
        if not triangles:
            continue
        area_f = _facet_area_estimate(facet, coords_s)
        if area_f <= tol:
            continue

        elem_id = int(facet_to_elem[int(f_id)])
        if elem_id < 0:
            raise ValueError("facet_to_elem has invalid mapping")
        elem_nodes = np.asarray(elem_conn[elem_id], dtype=int)
        elem_coords = coords_s[elem_nodes]
        u_local = _gather_u_local(u, elem_nodes, value_dim).reshape(-1, value_dim)
        dofs = np.asarray(_global_dof_indices(elem_nodes, value_dim, 0), dtype=int)

        normal = n if n is not None else normal_sign * normals_s[int(f_id)]
        normal_q = np.repeat(np.asarray(normal, dtype=float)[None, :], n_q, axis=0)

        for a, b, c_tri in triangles:
            area = _tri_area(a, b, c_tri)
            if area <= tol:
                continue
            detJ = 2.0 * area
            x_q_ref = np.array([a + r * (b - a) + s * (c_tri - a) for r, s in quad_pts], dtype=float)
            N = np.asarray(_volume_shape_values_at_points(x_q_ref, elem_coords, tol=tol), dtype=float)

            # Gap at the current position of each quadrature point.
            x_q_cur = x_q_ref + N @ u_local
            g = np.sum(normal_q * x_q_cur, axis=1) - c_val
            min_g = min(min_g, float(np.min(g)))
            z = -beta_val * g

            # Penetration metric (same clipped softplus as before, without k).
            softplus_metric = np.where(z > 30.0, z, np.log1p(np.exp(np.minimum(z, 30.0)))) / beta_val
            penetration += float(np.sum(softplus_metric * quad_w) * detJ)

            # Pressure and its derivative. logaddexp(0, z) is a stable
            # softplus; 0.5 * (1 + tanh(z / 2)) is a stable sigmoid(z).
            wJ = quad_w * detJ
            pressure = k_val * np.logaddexp(0.0, z) / beta_val
            dpressure_dg = -k_val * 0.5 * (1.0 + np.tanh(0.5 * z))

            # Local DOF index is node * value_dim + component (row-major).
            f_local = np.einsum("qi,qa,q->ia", N, normal_q, pressure * wJ).reshape(-1)
            n_ldofs = f_local.shape[0]
            k_local = np.einsum(
                "qi,qa,qj,qb,q->iajb", N, normal_q, N, normal_q, dpressure_dg * wJ
            ).reshape(n_ldofs, n_ldofs)

            # Unbuffered accumulation (correct even with repeated indices).
            np.add.at(f, dofs, f_local)
            np.add.at(K, np.ix_(dofs, dofs), k_local)

    if return_metrics:
        if min_g == float("inf"):
            min_g = 0.0
        metrics = {
            "penetration": float(penetration),
            "min_g": float(min_g),
        }
        return K, f, metrics
    return K, f