fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. fluxfem/__init__.py +69 -13
  2. fluxfem/core/__init__.py +140 -53
  3. fluxfem/core/assembly.py +691 -97
  4. fluxfem/core/basis.py +75 -54
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/dtypes.py +9 -1
  7. fluxfem/core/forms.py +10 -0
  8. fluxfem/core/mixed_assembly.py +263 -0
  9. fluxfem/core/mixed_space.py +382 -0
  10. fluxfem/core/mixed_weakform.py +97 -0
  11. fluxfem/core/solver.py +2 -0
  12. fluxfem/core/space.py +315 -30
  13. fluxfem/core/weakform.py +821 -42
  14. fluxfem/helpers_wf.py +49 -0
  15. fluxfem/mesh/__init__.py +54 -2
  16. fluxfem/mesh/base.py +318 -9
  17. fluxfem/mesh/contact.py +841 -0
  18. fluxfem/mesh/dtypes.py +12 -0
  19. fluxfem/mesh/hex.py +17 -16
  20. fluxfem/mesh/io.py +9 -6
  21. fluxfem/mesh/mortar.py +3970 -0
  22. fluxfem/mesh/supermesh.py +318 -0
  23. fluxfem/mesh/surface.py +104 -26
  24. fluxfem/mesh/tet.py +16 -7
  25. fluxfem/physics/diffusion.py +3 -0
  26. fluxfem/physics/elasticity/hyperelastic.py +35 -3
  27. fluxfem/physics/elasticity/linear.py +22 -4
  28. fluxfem/physics/elasticity/stress.py +9 -5
  29. fluxfem/physics/operators.py +12 -5
  30. fluxfem/physics/postprocess.py +29 -3
  31. fluxfem/solver/__init__.py +47 -2
  32. fluxfem/solver/bc.py +38 -2
  33. fluxfem/solver/block_matrix.py +284 -0
  34. fluxfem/solver/block_system.py +477 -0
  35. fluxfem/solver/cg.py +150 -55
  36. fluxfem/solver/dirichlet.py +358 -5
  37. fluxfem/solver/history.py +15 -3
  38. fluxfem/solver/newton.py +260 -70
  39. fluxfem/solver/petsc.py +445 -0
  40. fluxfem/solver/preconditioner.py +109 -0
  41. fluxfem/solver/result.py +18 -0
  42. fluxfem/solver/solve_runner.py +208 -23
  43. fluxfem/solver/solver.py +35 -12
  44. fluxfem/solver/sparse.py +149 -15
  45. fluxfem/tools/jit.py +19 -7
  46. fluxfem/tools/timer.py +14 -12
  47. fluxfem/tools/visualizer.py +16 -4
  48. fluxfem-0.2.1.dist-info/METADATA +314 -0
  49. fluxfem-0.2.1.dist-info/RECORD +59 -0
  50. fluxfem-0.1.4.dist-info/METADATA +0 -127
  51. fluxfem-0.1.4.dist-info/RECORD +0 -48
  52. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  53. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/solver/newton.py CHANGED
@@ -1,6 +1,8 @@
1
1
  from __future__ import annotations
2
2
  import time
3
3
 
4
+ from typing import Any, Callable, Mapping, TYPE_CHECKING, TypeAlias
5
+
4
6
  import numpy as np
5
7
  import jax
6
8
  import jax.numpy as jnp
@@ -8,38 +10,51 @@ import jax.numpy as jnp
8
10
  from ..core.assembly import (
9
11
  assemble_residual_scatter,
10
12
  assemble_jacobian_scatter,
13
+ ResidualForm,
11
14
  make_element_residual_kernel,
12
15
  make_element_jacobian_kernel,
13
16
  make_sparsity_pattern,
14
17
  )
15
18
  from ..core.solver import spdirect_solve_cpu, spdirect_solve_gpu
16
19
  from .cg import cg_solve, cg_solve_jax
20
+ from .preconditioner import make_block_jacobi_preconditioner
17
21
  from .result import SolverResult
18
22
  from .sparse import SparsityPattern, FluxSparseMatrix
19
- from .dirichlet import _normalize_dirichlet
23
+ from .dirichlet import DirichletBC, _normalize_dirichlet
24
+
25
+ if TYPE_CHECKING:
26
+ from jax import Array as JaxArray
27
+
28
+ ArrayLike: TypeAlias = np.ndarray | JaxArray
29
+ else:
30
+ ArrayLike: TypeAlias = np.ndarray
31
+ ExtraTerm: TypeAlias = Callable[[np.ndarray], tuple[np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray, Mapping[str, Any]] | None]
20
32
 
21
33
 
22
34
  def newton_solve(
23
35
  space,
24
- res_form,
25
- u0,
26
- params,
36
+ res_form: ResidualForm[Any],
37
+ u0: ArrayLike,
38
+ params: Any,
27
39
  *,
28
40
  tol: float = 1e-8,
29
41
  atol: float = 0.0,
30
42
  maxiter: int = 20,
31
- linear_solver: str = "spsolve", # "spsolve", "spdirect_solve_gpu", "cg" (jax), "cg_jax", or "cg_custom"
43
+ linear_solver: str = "spsolve", # "spsolve", "spdirect_solve_gpu", "cg", "cg_jax", "cg_custom", or "cg_matfree"
32
44
  linear_maxiter: int | None = None,
33
45
  linear_tol: float | None = None,
34
- linear_preconditioner=None,
35
- dirichlet=None,
36
- callback=None,
46
+ linear_preconditioner: object | None = None,
47
+ matfree_mode: str = "linearize",
48
+ matfree_cache: dict[str, Any] | None = None,
49
+ dirichlet: tuple[np.ndarray, np.ndarray] | None = None,
50
+ callback: Callable[[np.ndarray, SolverResult], Any] | None = None,
37
51
  line_search: bool = False,
38
52
  max_ls: int = 10,
39
53
  ls_c: float = 1e-4,
40
- external_vector=None,
41
- jacobian_pattern=None,
42
- ):
54
+ external_vector: np.ndarray | None = None,
55
+ jacobian_pattern: SparsityPattern | None = None,
56
+ extra_terms: list[ExtraTerm] | None = None,
57
+ ) -> tuple[np.ndarray, SolverResult]:
43
58
  """
44
59
  Gridap-style Newton–Raphson solver on free DOFs only.
45
60
 
@@ -48,12 +63,17 @@ def newton_solve(
48
63
  - Convergence: ||R_free||_inf < max(atol, tol * ||R_free0||_inf).
49
64
  - external_vector: optional global RHS (internal - external).
50
65
  - CG path accepts an operator with matvec that acts on free DOFs via a wrapper.
51
- - linear_preconditioner: forwarded to cg_solve/cg_solve_jax (None | "jacobi" | "block_jacobi" | callable).
66
+ - cg_matfree uses JVP/linearize to form a matrix-free matvec (no global Jacobian).
67
+ - linear_preconditioner: forwarded to cg_solve/cg_solve_jax (None | "jacobi" | "block_jacobi" | "diag0" | callable).
68
+ - matfree_cache: optional dict for reusing matrix-free preconditioners across calls.
52
69
  - linear_tol: CG tolerance (defaults to 0.1 * tol if not provided).
53
70
  - jacobian_pattern: optional SparsityPattern to reuse sparsity across load steps.
71
+ - extra_terms: optional list of callbacks returning (K, f[, metrics]) for extra terms.
54
72
  """
55
73
 
56
74
  if dirichlet is not None:
75
+ if isinstance(dirichlet, DirichletBC):
76
+ dirichlet = dirichlet.as_tuple()
57
77
  dir_dofs, dir_vals = dirichlet
58
78
  dir_dofs, dir_vals = _normalize_dirichlet(dir_dofs, dir_vals)
59
79
  if dir_vals.ndim == 0:
@@ -66,6 +86,13 @@ def newton_solve(
66
86
  dir_dofs = dir_vals = None
67
87
  free_dofs = np.arange(space.n_dofs, dtype=int)
68
88
 
89
+ use_matfree = linear_solver in ("cg_matfree", "cg_jvp")
90
+ if use_matfree and matfree_mode not in ("linearize", "jvp"):
91
+ raise ValueError("matfree_mode must be 'linearize' or 'jvp'")
92
+
93
+ if extra_terms is not None and linear_solver in ("cg", "cg_jax", "cg_custom", "cg_matfree", "cg_jvp"):
94
+ raise ValueError("extra_terms may yield nonsymmetric K; avoid CG-based solvers")
95
+
69
96
  free_dofs_j = jnp.asarray(free_dofs, dtype=jnp.int32)
70
97
  # For block-Jacobi (3x3 per node) we keep node ids of free dofs.
71
98
  node_ids = free_dofs // 3
@@ -109,33 +136,7 @@ def newton_solve(
109
136
  Build 3x3 block-Jacobi inverse per free node.
110
137
  Assumes DOF ordering per node is [ux, uy, uz].
111
138
  """
112
- if len(free_dofs) % 3 != 0:
113
- raise ValueError("block_jacobi assumes 3 DOFs per node.")
114
- rows = np.asarray(J_free.rows)
115
- cols = np.asarray(J_free.cols)
116
- data = np.asarray(J_free.data)
117
- block_rows = node_ids_inv[rows]
118
- block_cols = node_ids_inv[cols]
119
- local_r = rows % 3
120
- local_c = cols % 3
121
- mask_blk = block_rows == block_cols
122
- blk_rows = block_rows[mask_blk]
123
- blk_lr = local_r[mask_blk]
124
- blk_lc = local_c[mask_blk]
125
- blk_data = data[mask_blk]
126
- inv_blocks = np.zeros((n_block, 3, 3), dtype=blk_data.dtype)
127
- inv_blocks[blk_rows, blk_lr, blk_lc] += blk_data
128
- inv_blocks = jnp.asarray(inv_blocks)
129
- # Add tiny damping to avoid singular blocks
130
- inv_blocks = inv_blocks + 1e-12 * jnp.eye(3)[None, :, :]
131
- inv_blocks = jnp.linalg.inv(inv_blocks)
132
-
133
- def precon(r):
134
- r_blocks = r.reshape((n_block, 3))
135
- z_blocks = jnp.einsum("bij,bj->bi", inv_blocks, r_blocks)
136
- return z_blocks.reshape((-1,))
137
-
138
- return precon
139
+ return make_block_jacobi_preconditioner(J_free, dof_per_node=3)
139
140
 
140
141
  def expand_full(u_free: jnp.ndarray) -> jnp.ndarray:
141
142
  if dir_dofs is None:
@@ -145,6 +146,39 @@ def newton_solve(
145
146
  u_full = u_full.at[dir_dofs_j].set(dir_vals_j)
146
147
  return u_full
147
148
 
149
+ extra_metrics = None
150
+
151
+ def _call_extra(u_full_vec):
152
+ if extra_terms is None:
153
+ return None
154
+ K_sum = None
155
+ f_sum = None
156
+ metrics_sum = {}
157
+ for term in extra_terms:
158
+ out = term(np.asarray(u_full_vec))
159
+ if out is None:
160
+ continue
161
+ if len(out) == 2:
162
+ Kc, fc = out
163
+ metrics = None
164
+ else:
165
+ Kc, fc, metrics = out
166
+ Kc = np.asarray(Kc, dtype=float)
167
+ fc = np.asarray(fc, dtype=float)
168
+ if K_sum is None:
169
+ K_sum = Kc
170
+ else:
171
+ K_sum = K_sum + Kc
172
+ if f_sum is None:
173
+ f_sum = fc
174
+ else:
175
+ f_sum = f_sum + fc
176
+ if isinstance(metrics, dict):
177
+ metrics_sum.update(metrics)
178
+ if K_sum is None or f_sum is None:
179
+ return None
180
+ return K_sum, f_sum, (metrics_sum or None)
181
+
148
182
  def eval_residual(u_free_vec):
149
183
  """Residual on free DOFs only."""
150
184
  u_full = expand_full(u_free_vec)
@@ -156,17 +190,32 @@ def newton_solve(
156
190
  res_two = float(jnp.linalg.norm(R_free, ord=2))
157
191
  return R_free, res_inf, res_two, u_full
158
192
 
193
+ def residual_free(u_free_vec):
194
+ u_full = expand_full(u_free_vec)
195
+ R_full = assemble_R(u_full)
196
+ if external_vector is not None:
197
+ R_full = R_full - external_vector
198
+ return R_full[free_dofs_j]
199
+
159
200
  # Pre-jitted element kernels to avoid recompiling inside Newton
160
201
  res_kernel = make_element_residual_kernel(res_form, params)
161
202
  jac_kernel = make_element_jacobian_kernel(res_form, params)
162
203
 
163
204
  def assemble_R(u_full_vec):
164
- return assemble_residual_scatter(space, res_form, u_full_vec, params, kernel=res_kernel)
205
+ nonlocal extra_metrics
206
+ R = assemble_residual_scatter(space, res_form, u_full_vec, params, kernel=res_kernel)
207
+ extra_out = _call_extra(u_full_vec)
208
+ if extra_out is not None:
209
+ _Kc, fc, metrics = extra_out
210
+ extra_metrics = metrics
211
+ R = R + jnp.asarray(fc, dtype=R.dtype)
212
+ return R
165
213
 
166
214
  eff_linear_tol = linear_tol if linear_tol is not None else max(0.1 * tol, 1e-12)
167
215
 
168
216
  def assemble_J(u_full_vec):
169
- return assemble_jacobian_scatter(
217
+ nonlocal extra_metrics
218
+ J = assemble_jacobian_scatter(
170
219
  space,
171
220
  res_form,
172
221
  u_full_vec,
@@ -176,6 +225,40 @@ def newton_solve(
176
225
  return_flux_matrix=True,
177
226
  pattern=J_pattern,
178
227
  )
228
+ extra_out = _call_extra(u_full_vec)
229
+ if extra_out is not None:
230
+ Kc, _fc, metrics = extra_out
231
+ extra_metrics = metrics
232
+ rows = np.asarray(J.pattern.rows, dtype=int)
233
+ cols = np.asarray(J.pattern.cols, dtype=int)
234
+ data = jnp.asarray(J.data) + jnp.asarray(Kc[rows, cols], dtype=J.data.dtype)
235
+ J = J.with_data(data)
236
+ return J
237
+
238
+ matfree_precon = None
239
+ if use_matfree and linear_preconditioner == "diag0":
240
+ cached = matfree_cache.get("inv_diag0") if matfree_cache is not None else None
241
+ if cached is not None:
242
+ inv_diag0 = cached
243
+ matfree_precon = lambda r: inv_diag0 * r
244
+ print("[PRECOND] reuse diag0", flush=True)
245
+ else:
246
+ print("[PRECOND] build diag0", flush=True)
247
+ t_pre0 = time.perf_counter()
248
+ J0 = assemble_J(expand_full(u))
249
+ J0_free = restrict_free_matrix(J0)
250
+ diag0 = jnp.asarray(J0_free.diag(), dtype=u.dtype)
251
+ diag0 = jax.block_until_ready(diag0)
252
+ inv_diag0 = jnp.where(diag0 != 0.0, 1.0 / diag0, 0.0)
253
+
254
+ def precon(r):
255
+ return inv_diag0 * r
256
+
257
+ matfree_precon = precon
258
+ if matfree_cache is not None:
259
+ matfree_cache["inv_diag0"] = inv_diag0
260
+ pre_dt0 = time.perf_counter() - t_pre0
261
+ print(f"[PRECOND] diag0 ready dt={pre_dt0:.3f}s", flush=True)
179
262
 
180
263
  # Initial residual/Jacobian
181
264
  R_full_init = assemble_R(expand_full(u))
@@ -210,14 +293,25 @@ def newton_solve(
210
293
  )
211
294
 
212
295
  if callback is not None:
213
- callback({"iter": 0, "res_inf": res0_inf, "res_two": res0_two, "rel_residual": 1.0, "alpha": 1.0, "step_norm": np.nan})
214
-
215
- J = assemble_J(u_full)
216
- finite_j = jnp.all(jnp.isfinite(J.data))
217
- if not bool(jax.block_until_ready(finite_j)):
218
- n_bad = int(jnp.size(J.data) - jnp.count_nonzero(jnp.isfinite(J.data)))
219
- raise RuntimeError(f"[newton] init Jacobian has non-finite entries: {n_bad}")
220
- J_free = restrict_free_matrix(J)
296
+ payload = {"iter": 0, "res_inf": res0_inf, "res_two": res0_two, "rel_residual": 1.0, "alpha": 1.0, "step_norm": np.nan}
297
+ if extra_metrics is not None:
298
+ payload["extra_metrics"] = extra_metrics
299
+ callback(payload)
300
+
301
+ if not use_matfree:
302
+ J = assemble_J(u_full)
303
+ finite_j = jnp.all(jnp.isfinite(J.data))
304
+ if not bool(jax.block_until_ready(finite_j)):
305
+ n_bad = int(jnp.size(J.data) - jnp.count_nonzero(jnp.isfinite(J.data)))
306
+ raise RuntimeError(f"[newton] init Jacobian has non-finite entries: {n_bad}")
307
+ J_free = restrict_free_matrix(J)
308
+ else:
309
+ J = None
310
+ J_free = None
311
+ lin_info = {}
312
+ step_norm = float("nan")
313
+ linear_converged = True
314
+ lr = None
221
315
  for k in range(maxiter):
222
316
  # --- Newton residual (iteration start) ---
223
317
  t_iter0 = time.perf_counter()
@@ -232,13 +326,34 @@ def newton_solve(
232
326
  raise RuntimeError("[newton] residual became non-finite; aborting.")
233
327
 
234
328
  crit = max(atol, tol * res0_inf)
329
+ contact_log = ""
330
+ if extra_metrics is not None and isinstance(extra_metrics, dict):
331
+ min_g = extra_metrics.get("min_g")
332
+ pen = extra_metrics.get("penetration")
333
+ if min_g is not None and pen is not None:
334
+ contact_log = f" min_g={float(min_g):.3e} pen={float(pen):.3e}"
235
335
  print(
236
- f"[newton] k={k:02d} START |R|inf={res_prev_inf_f:.3e} |R|2={res_prev_two_f:.3e} crit={crit:.3e}",
336
+ f"[newton] k={k:02d} START |R|inf={res_prev_inf_f:.3e} |R|2={res_prev_two_f:.3e} crit={crit:.3e}{contact_log}",
237
337
  flush=True,
238
338
  )
239
339
 
240
340
  # --- Linear solve (J_free * du = -R_free) ---
341
+ t_rhs0 = time.perf_counter()
241
342
  rhs = jnp.asarray(-R_free, dtype=u.dtype)
343
+ rhs_norm = jnp.linalg.norm(rhs)
344
+ rhs_norm_f = float(jax.block_until_ready(rhs_norm))
345
+ rhs_dt = time.perf_counter() - t_rhs0
346
+ if rhs_norm_f <= atol:
347
+ print(
348
+ f"[newton] k={k:02d} CONVERGED rhs<=atol ({rhs_norm_f:.3e} <= {atol:.3e})",
349
+ flush=True,
350
+ )
351
+ return expand_full(u), SolverResult(
352
+ converged=True,
353
+ iters=k,
354
+ stop_reason="rhs_atol",
355
+ nan_detected=False,
356
+ )
242
357
 
243
358
  # Separate preconditioner build time from linear solve time.
244
359
  t_pre0 = time.perf_counter()
@@ -247,7 +362,54 @@ def newton_solve(
247
362
  linear_residual = None
248
363
  lin_iters = None
249
364
 
250
- if linear_solver in ("cg", "cg_jax", "cg_custom"):
365
+ linearize_dt = 0.0
366
+ if use_matfree:
367
+ if linear_preconditioner in ("jacobi", "block_jacobi"):
368
+ raise ValueError("cg_matfree does not support jacobi preconditioners")
369
+ if linear_preconditioner == "diag0":
370
+ cg_precon = matfree_precon
371
+ elif linear_preconditioner is not None and not callable(linear_preconditioner):
372
+ raise ValueError("cg_matfree preconditioner must be callable or None")
373
+ pre_dt = 0.0
374
+ if linear_preconditioner not in ("diag0", None):
375
+ cg_precon = linear_preconditioner
376
+ print(f"[linear] k={k:02d} {linear_solver}: linearize...", flush=True)
377
+ t_lin0 = time.perf_counter()
378
+ if matfree_mode == "linearize":
379
+ _res, lin_fun = jax.linearize(residual_free, u)
380
+ mv = lambda v: lin_fun(v)
381
+ else:
382
+ mv = lambda v: jax.jvp(residual_free, (u,), (v,))[1]
383
+ linearize_dt = time.perf_counter() - t_lin0
384
+ t_mv0 = 0.0
385
+ mv0_norm_f = None
386
+ t_mv0_0 = time.perf_counter()
387
+ mv0 = mv(rhs)
388
+ mv0_norm = jnp.linalg.norm(mv0)
389
+ mv0_norm_f = float(jax.block_until_ready(mv0_norm))
390
+ t_mv0 = time.perf_counter() - t_mv0_0
391
+ print(
392
+ f"[linear] k={k:02d} {linear_solver}: rhs_dt={rhs_dt:.3f}s "
393
+ f"mv0_dt={t_mv0:.3f}s ||b||={rhs_norm_f:.3e} ||Jb||={mv0_norm_f:.3e}",
394
+ flush=True,
395
+ )
396
+ cg_solver = cg_solve
397
+ print(f"[linear] k={k:02d} {linear_solver}: solve...", flush=True)
398
+ t_cg0 = time.perf_counter()
399
+ du_free, lin_info = cg_solver(
400
+ mv,
401
+ rhs,
402
+ tol=eff_linear_tol,
403
+ maxiter=linear_maxiter,
404
+ preconditioner=cg_precon,
405
+ )
406
+ du_free = jax.block_until_ready(du_free)
407
+ lin_dt = time.perf_counter() - t_cg0
408
+ linear_residual = lin_info.get("residual_norm")
409
+ linear_converged = bool(lin_info.get("converged", True))
410
+ lin_iters = lin_info.get("iters", None)
411
+
412
+ elif linear_solver in ("cg", "cg_jax", "cg_custom"):
251
413
  # Preconditioner build
252
414
  if linear_preconditioner == "jacobi":
253
415
  print(f"[newton] k={k:02d} PRECOND jacobi: diag...", flush=True)
@@ -307,11 +469,18 @@ def newton_solve(
307
469
  raise ValueError(f"Unknown linear solver: {linear_solver}")
308
470
 
309
471
  lr = float(linear_residual) if linear_residual is not None else float("nan")
310
- print(
311
- f"[linear] k={k:02d} done iters={lin_iters} conv={linear_converged} lin_res={lr:.3e} "
312
- f"pre_dt={pre_dt:.3f}s lin_dt={lin_dt:.3f}s",
313
- flush=True,
314
- )
472
+ if use_matfree:
473
+ print(
474
+ f"[linear] k={k:02d} done iters={lin_iters} conv={linear_converged} lin_res={lr:.3e} "
475
+ f"linz_dt={linearize_dt:.3f}s cg_dt={lin_dt:.3f}s",
476
+ flush=True,
477
+ )
478
+ else:
479
+ print(
480
+ f"[linear] k={k:02d} done iters={lin_iters} conv={linear_converged} lin_res={lr:.3e} "
481
+ f"pre_dt={pre_dt:.3f}s lin_dt={lin_dt:.3f}s",
482
+ flush=True,
483
+ )
315
484
 
316
485
  # --- Trial update & residual evaluation ---
317
486
  # Start with alpha=1 and eval_residual (if heavy, assemble_R is heavy/compiled).
@@ -362,20 +531,21 @@ def newton_solve(
362
531
 
363
532
  # callback
364
533
  if callback is not None:
365
- callback(
366
- {
367
- "iter": k + 1,
368
- "res_inf": res_trial_inf,
369
- "res_two": res_trial_two,
370
- "rel_residual": res_trial_inf / res0_inf,
371
- "alpha": alpha,
372
- "step_norm": step_norm,
373
- "linear_iters": lin_info.get("iters"),
374
- "linear_converged": linear_converged,
375
- "linear_residual": lr,
376
- "nan_detected": bool(np.isnan(res_trial_inf)),
377
- }
378
- )
534
+ payload = {
535
+ "iter": k + 1,
536
+ "res_inf": res_trial_inf,
537
+ "res_two": res_trial_two,
538
+ "rel_residual": res_trial_inf / res0_inf,
539
+ "alpha": alpha,
540
+ "step_norm": step_norm,
541
+ "linear_iters": lin_info.get("iters"),
542
+ "linear_converged": linear_converged,
543
+ "linear_residual": lr,
544
+ "nan_detected": bool(np.isnan(res_trial_inf)),
545
+ }
546
+ if extra_metrics is not None:
547
+ payload["extra_metrics"] = extra_metrics
548
+ callback(payload)
379
549
 
380
550
  # --- Convergence check ---
381
551
  if res_trial_inf < crit and linear_converged and not np.isnan(res_trial_inf):
@@ -398,3 +568,23 @@ def newton_solve(
398
568
  stop_reason="converged",
399
569
  nan_detected=bool(np.isnan(res_trial_inf)),
400
570
  )
571
+
572
+ res_final_inf = float(jnp.linalg.norm(R_free, ord=jnp.inf))
573
+ res_final_two = float(jnp.linalg.norm(R_free, ord=2))
574
+ return u_full, SolverResult(
575
+ converged=False,
576
+ iters=maxiter,
577
+ residual_norm=res_final_inf,
578
+ residual0=res0_inf,
579
+ rel_residual=(res_final_inf / res0_inf if res0_inf != 0.0 else float("inf")),
580
+ line_search_steps=0,
581
+ linear_iters=lin_info.get("iters"),
582
+ linear_converged=linear_converged,
583
+ linear_residual=lr,
584
+ tol=tol,
585
+ atol=atol,
586
+ stopping_criterion=crit,
587
+ step_norm=step_norm,
588
+ stop_reason="maxiter",
589
+ nan_detected=bool(np.isnan(res_final_inf)),
590
+ )