fluxfem 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +68 -0
- fluxfem/core/__init__.py +115 -10
- fluxfem/core/assembly.py +676 -91
- fluxfem/core/basis.py +73 -52
- fluxfem/core/dtypes.py +9 -1
- fluxfem/core/forms.py +10 -0
- fluxfem/core/mixed_assembly.py +263 -0
- fluxfem/core/mixed_space.py +348 -0
- fluxfem/core/mixed_weakform.py +97 -0
- fluxfem/core/solver.py +2 -0
- fluxfem/core/space.py +262 -17
- fluxfem/core/weakform.py +768 -7
- fluxfem/helpers_wf.py +49 -0
- fluxfem/mesh/__init__.py +54 -2
- fluxfem/mesh/base.py +316 -7
- fluxfem/mesh/contact.py +825 -0
- fluxfem/mesh/dtypes.py +12 -0
- fluxfem/mesh/hex.py +17 -16
- fluxfem/mesh/io.py +6 -4
- fluxfem/mesh/mortar.py +3907 -0
- fluxfem/mesh/supermesh.py +316 -0
- fluxfem/mesh/surface.py +22 -4
- fluxfem/mesh/tet.py +10 -4
- fluxfem/physics/diffusion.py +3 -0
- fluxfem/physics/elasticity/hyperelastic.py +3 -0
- fluxfem/physics/elasticity/linear.py +9 -2
- fluxfem/solver/__init__.py +42 -2
- fluxfem/solver/bc.py +38 -2
- fluxfem/solver/block_matrix.py +132 -0
- fluxfem/solver/block_system.py +454 -0
- fluxfem/solver/cg.py +115 -33
- fluxfem/solver/dirichlet.py +334 -4
- fluxfem/solver/newton.py +237 -60
- fluxfem/solver/petsc.py +439 -0
- fluxfem/solver/preconditioner.py +106 -0
- fluxfem/solver/result.py +18 -0
- fluxfem/solver/solve_runner.py +168 -1
- fluxfem/solver/solver.py +12 -1
- fluxfem/solver/sparse.py +124 -9
- fluxfem-0.2.0.dist-info/METADATA +303 -0
- fluxfem-0.2.0.dist-info/RECORD +59 -0
- fluxfem-0.1.4.dist-info/METADATA +0 -127
- fluxfem-0.1.4.dist-info/RECORD +0 -48
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/LICENSE +0 -0
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/WHEEL +0 -0
fluxfem/solver/newton.py
CHANGED
|
@@ -14,6 +14,7 @@ from ..core.assembly import (
|
|
|
14
14
|
)
|
|
15
15
|
from ..core.solver import spdirect_solve_cpu, spdirect_solve_gpu
|
|
16
16
|
from .cg import cg_solve, cg_solve_jax
|
|
17
|
+
from .preconditioner import make_block_jacobi_preconditioner
|
|
17
18
|
from .result import SolverResult
|
|
18
19
|
from .sparse import SparsityPattern, FluxSparseMatrix
|
|
19
20
|
from .dirichlet import _normalize_dirichlet
|
|
@@ -28,10 +29,12 @@ def newton_solve(
|
|
|
28
29
|
tol: float = 1e-8,
|
|
29
30
|
atol: float = 0.0,
|
|
30
31
|
maxiter: int = 20,
|
|
31
|
-
linear_solver: str = "spsolve", # "spsolve", "spdirect_solve_gpu", "cg"
|
|
32
|
+
linear_solver: str = "spsolve", # "spsolve", "spdirect_solve_gpu", "cg", "cg_jax", "cg_custom", or "cg_matfree"
|
|
32
33
|
linear_maxiter: int | None = None,
|
|
33
34
|
linear_tol: float | None = None,
|
|
34
35
|
linear_preconditioner=None,
|
|
36
|
+
matfree_mode: str = "linearize",
|
|
37
|
+
matfree_cache: dict | None = None,
|
|
35
38
|
dirichlet=None,
|
|
36
39
|
callback=None,
|
|
37
40
|
line_search: bool = False,
|
|
@@ -39,6 +42,7 @@ def newton_solve(
|
|
|
39
42
|
ls_c: float = 1e-4,
|
|
40
43
|
external_vector=None,
|
|
41
44
|
jacobian_pattern=None,
|
|
45
|
+
extra_terms=None,
|
|
42
46
|
):
|
|
43
47
|
"""
|
|
44
48
|
Gridap-style Newton–Raphson solver on free DOFs only.
|
|
@@ -48,9 +52,12 @@ def newton_solve(
|
|
|
48
52
|
- Convergence: ||R_free||_inf < max(atol, tol * ||R_free0||_inf).
|
|
49
53
|
- external_vector: optional global RHS (internal - external).
|
|
50
54
|
- CG path accepts an operator with matvec that acts on free DOFs via a wrapper.
|
|
51
|
-
-
|
|
55
|
+
- cg_matfree uses JVP/linearize to form a matrix-free matvec (no global Jacobian).
|
|
56
|
+
- linear_preconditioner: forwarded to cg_solve/cg_solve_jax (None | "jacobi" | "block_jacobi" | "diag0" | callable).
|
|
57
|
+
- matfree_cache: optional dict for reusing matrix-free preconditioners across calls.
|
|
52
58
|
- linear_tol: CG tolerance (defaults to 0.1 * tol if not provided).
|
|
53
59
|
- jacobian_pattern: optional SparsityPattern to reuse sparsity across load steps.
|
|
60
|
+
- extra_terms: optional list of callbacks returning (K, f[, metrics]) for extra terms.
|
|
54
61
|
"""
|
|
55
62
|
|
|
56
63
|
if dirichlet is not None:
|
|
@@ -66,6 +73,13 @@ def newton_solve(
|
|
|
66
73
|
dir_dofs = dir_vals = None
|
|
67
74
|
free_dofs = np.arange(space.n_dofs, dtype=int)
|
|
68
75
|
|
|
76
|
+
use_matfree = linear_solver in ("cg_matfree", "cg_jvp")
|
|
77
|
+
if use_matfree and matfree_mode not in ("linearize", "jvp"):
|
|
78
|
+
raise ValueError("matfree_mode must be 'linearize' or 'jvp'")
|
|
79
|
+
|
|
80
|
+
if extra_terms is not None and linear_solver in ("cg", "cg_jax", "cg_custom", "cg_matfree", "cg_jvp"):
|
|
81
|
+
raise ValueError("extra_terms may yield nonsymmetric K; avoid CG-based solvers")
|
|
82
|
+
|
|
69
83
|
free_dofs_j = jnp.asarray(free_dofs, dtype=jnp.int32)
|
|
70
84
|
# For block-Jacobi (3x3 per node) we keep node ids of free dofs.
|
|
71
85
|
node_ids = free_dofs // 3
|
|
@@ -109,33 +123,7 @@ def newton_solve(
|
|
|
109
123
|
Build 3x3 block-Jacobi inverse per free node.
|
|
110
124
|
Assumes DOF ordering per node is [ux, uy, uz].
|
|
111
125
|
"""
|
|
112
|
-
|
|
113
|
-
raise ValueError("block_jacobi assumes 3 DOFs per node.")
|
|
114
|
-
rows = np.asarray(J_free.rows)
|
|
115
|
-
cols = np.asarray(J_free.cols)
|
|
116
|
-
data = np.asarray(J_free.data)
|
|
117
|
-
block_rows = node_ids_inv[rows]
|
|
118
|
-
block_cols = node_ids_inv[cols]
|
|
119
|
-
local_r = rows % 3
|
|
120
|
-
local_c = cols % 3
|
|
121
|
-
mask_blk = block_rows == block_cols
|
|
122
|
-
blk_rows = block_rows[mask_blk]
|
|
123
|
-
blk_lr = local_r[mask_blk]
|
|
124
|
-
blk_lc = local_c[mask_blk]
|
|
125
|
-
blk_data = data[mask_blk]
|
|
126
|
-
inv_blocks = np.zeros((n_block, 3, 3), dtype=blk_data.dtype)
|
|
127
|
-
inv_blocks[blk_rows, blk_lr, blk_lc] += blk_data
|
|
128
|
-
inv_blocks = jnp.asarray(inv_blocks)
|
|
129
|
-
# Add tiny damping to avoid singular blocks
|
|
130
|
-
inv_blocks = inv_blocks + 1e-12 * jnp.eye(3)[None, :, :]
|
|
131
|
-
inv_blocks = jnp.linalg.inv(inv_blocks)
|
|
132
|
-
|
|
133
|
-
def precon(r):
|
|
134
|
-
r_blocks = r.reshape((n_block, 3))
|
|
135
|
-
z_blocks = jnp.einsum("bij,bj->bi", inv_blocks, r_blocks)
|
|
136
|
-
return z_blocks.reshape((-1,))
|
|
137
|
-
|
|
138
|
-
return precon
|
|
126
|
+
return make_block_jacobi_preconditioner(J_free, dof_per_node=3)
|
|
139
127
|
|
|
140
128
|
def expand_full(u_free: jnp.ndarray) -> jnp.ndarray:
|
|
141
129
|
if dir_dofs is None:
|
|
@@ -145,6 +133,39 @@ def newton_solve(
|
|
|
145
133
|
u_full = u_full.at[dir_dofs_j].set(dir_vals_j)
|
|
146
134
|
return u_full
|
|
147
135
|
|
|
136
|
+
extra_metrics = None
|
|
137
|
+
|
|
138
|
+
def _call_extra(u_full_vec):
|
|
139
|
+
if extra_terms is None:
|
|
140
|
+
return None
|
|
141
|
+
K_sum = None
|
|
142
|
+
f_sum = None
|
|
143
|
+
metrics_sum = {}
|
|
144
|
+
for term in extra_terms:
|
|
145
|
+
out = term(np.asarray(u_full_vec))
|
|
146
|
+
if out is None:
|
|
147
|
+
continue
|
|
148
|
+
if len(out) == 2:
|
|
149
|
+
Kc, fc = out
|
|
150
|
+
metrics = None
|
|
151
|
+
else:
|
|
152
|
+
Kc, fc, metrics = out
|
|
153
|
+
Kc = np.asarray(Kc, dtype=float)
|
|
154
|
+
fc = np.asarray(fc, dtype=float)
|
|
155
|
+
if K_sum is None:
|
|
156
|
+
K_sum = Kc
|
|
157
|
+
else:
|
|
158
|
+
K_sum = K_sum + Kc
|
|
159
|
+
if f_sum is None:
|
|
160
|
+
f_sum = fc
|
|
161
|
+
else:
|
|
162
|
+
f_sum = f_sum + fc
|
|
163
|
+
if isinstance(metrics, dict):
|
|
164
|
+
metrics_sum.update(metrics)
|
|
165
|
+
if K_sum is None or f_sum is None:
|
|
166
|
+
return None
|
|
167
|
+
return K_sum, f_sum, (metrics_sum or None)
|
|
168
|
+
|
|
148
169
|
def eval_residual(u_free_vec):
|
|
149
170
|
"""Residual on free DOFs only."""
|
|
150
171
|
u_full = expand_full(u_free_vec)
|
|
@@ -156,17 +177,32 @@ def newton_solve(
|
|
|
156
177
|
res_two = float(jnp.linalg.norm(R_free, ord=2))
|
|
157
178
|
return R_free, res_inf, res_two, u_full
|
|
158
179
|
|
|
180
|
+
def residual_free(u_free_vec):
|
|
181
|
+
u_full = expand_full(u_free_vec)
|
|
182
|
+
R_full = assemble_R(u_full)
|
|
183
|
+
if external_vector is not None:
|
|
184
|
+
R_full = R_full - external_vector
|
|
185
|
+
return R_full[free_dofs_j]
|
|
186
|
+
|
|
159
187
|
# Pre-jitted element kernels to avoid recompiling inside Newton
|
|
160
188
|
res_kernel = make_element_residual_kernel(res_form, params)
|
|
161
189
|
jac_kernel = make_element_jacobian_kernel(res_form, params)
|
|
162
190
|
|
|
163
191
|
def assemble_R(u_full_vec):
|
|
164
|
-
|
|
192
|
+
nonlocal extra_metrics
|
|
193
|
+
R = assemble_residual_scatter(space, res_form, u_full_vec, params, kernel=res_kernel)
|
|
194
|
+
extra_out = _call_extra(u_full_vec)
|
|
195
|
+
if extra_out is not None:
|
|
196
|
+
_Kc, fc, metrics = extra_out
|
|
197
|
+
extra_metrics = metrics
|
|
198
|
+
R = R + jnp.asarray(fc, dtype=R.dtype)
|
|
199
|
+
return R
|
|
165
200
|
|
|
166
201
|
eff_linear_tol = linear_tol if linear_tol is not None else max(0.1 * tol, 1e-12)
|
|
167
202
|
|
|
168
203
|
def assemble_J(u_full_vec):
|
|
169
|
-
|
|
204
|
+
nonlocal extra_metrics
|
|
205
|
+
J = assemble_jacobian_scatter(
|
|
170
206
|
space,
|
|
171
207
|
res_form,
|
|
172
208
|
u_full_vec,
|
|
@@ -176,6 +212,40 @@ def newton_solve(
|
|
|
176
212
|
return_flux_matrix=True,
|
|
177
213
|
pattern=J_pattern,
|
|
178
214
|
)
|
|
215
|
+
extra_out = _call_extra(u_full_vec)
|
|
216
|
+
if extra_out is not None:
|
|
217
|
+
Kc, _fc, metrics = extra_out
|
|
218
|
+
extra_metrics = metrics
|
|
219
|
+
rows = np.asarray(J.pattern.rows, dtype=int)
|
|
220
|
+
cols = np.asarray(J.pattern.cols, dtype=int)
|
|
221
|
+
data = jnp.asarray(J.data) + jnp.asarray(Kc[rows, cols], dtype=J.data.dtype)
|
|
222
|
+
J = J.with_data(data)
|
|
223
|
+
return J
|
|
224
|
+
|
|
225
|
+
matfree_precon = None
|
|
226
|
+
if use_matfree and linear_preconditioner == "diag0":
|
|
227
|
+
cached = matfree_cache.get("inv_diag0") if matfree_cache is not None else None
|
|
228
|
+
if cached is not None:
|
|
229
|
+
inv_diag0 = cached
|
|
230
|
+
matfree_precon = lambda r: inv_diag0 * r
|
|
231
|
+
print("[PRECOND] reuse diag0", flush=True)
|
|
232
|
+
else:
|
|
233
|
+
print("[PRECOND] build diag0", flush=True)
|
|
234
|
+
t_pre0 = time.perf_counter()
|
|
235
|
+
J0 = assemble_J(expand_full(u))
|
|
236
|
+
J0_free = restrict_free_matrix(J0)
|
|
237
|
+
diag0 = jnp.asarray(J0_free.diag(), dtype=u.dtype)
|
|
238
|
+
diag0 = jax.block_until_ready(diag0)
|
|
239
|
+
inv_diag0 = jnp.where(diag0 != 0.0, 1.0 / diag0, 0.0)
|
|
240
|
+
|
|
241
|
+
def precon(r):
|
|
242
|
+
return inv_diag0 * r
|
|
243
|
+
|
|
244
|
+
matfree_precon = precon
|
|
245
|
+
if matfree_cache is not None:
|
|
246
|
+
matfree_cache["inv_diag0"] = inv_diag0
|
|
247
|
+
pre_dt0 = time.perf_counter() - t_pre0
|
|
248
|
+
print(f"[PRECOND] diag0 ready dt={pre_dt0:.3f}s", flush=True)
|
|
179
249
|
|
|
180
250
|
# Initial residual/Jacobian
|
|
181
251
|
R_full_init = assemble_R(expand_full(u))
|
|
@@ -210,14 +280,25 @@ def newton_solve(
|
|
|
210
280
|
)
|
|
211
281
|
|
|
212
282
|
if callback is not None:
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
283
|
+
payload = {"iter": 0, "res_inf": res0_inf, "res_two": res0_two, "rel_residual": 1.0, "alpha": 1.0, "step_norm": np.nan}
|
|
284
|
+
if extra_metrics is not None:
|
|
285
|
+
payload["extra_metrics"] = extra_metrics
|
|
286
|
+
callback(payload)
|
|
287
|
+
|
|
288
|
+
if not use_matfree:
|
|
289
|
+
J = assemble_J(u_full)
|
|
290
|
+
finite_j = jnp.all(jnp.isfinite(J.data))
|
|
291
|
+
if not bool(jax.block_until_ready(finite_j)):
|
|
292
|
+
n_bad = int(jnp.size(J.data) - jnp.count_nonzero(jnp.isfinite(J.data)))
|
|
293
|
+
raise RuntimeError(f"[newton] init Jacobian has non-finite entries: {n_bad}")
|
|
294
|
+
J_free = restrict_free_matrix(J)
|
|
295
|
+
else:
|
|
296
|
+
J = None
|
|
297
|
+
J_free = None
|
|
298
|
+
lin_info = {}
|
|
299
|
+
step_norm = float("nan")
|
|
300
|
+
linear_converged = True
|
|
301
|
+
lr = None
|
|
221
302
|
for k in range(maxiter):
|
|
222
303
|
# --- Newton residual (iteration start) ---
|
|
223
304
|
t_iter0 = time.perf_counter()
|
|
@@ -232,13 +313,34 @@ def newton_solve(
|
|
|
232
313
|
raise RuntimeError("[newton] residual became non-finite; aborting.")
|
|
233
314
|
|
|
234
315
|
crit = max(atol, tol * res0_inf)
|
|
316
|
+
contact_log = ""
|
|
317
|
+
if extra_metrics is not None and isinstance(extra_metrics, dict):
|
|
318
|
+
min_g = extra_metrics.get("min_g")
|
|
319
|
+
pen = extra_metrics.get("penetration")
|
|
320
|
+
if min_g is not None and pen is not None:
|
|
321
|
+
contact_log = f" min_g={float(min_g):.3e} pen={float(pen):.3e}"
|
|
235
322
|
print(
|
|
236
|
-
f"[newton] k={k:02d} START |R|inf={res_prev_inf_f:.3e} |R|2={res_prev_two_f:.3e} crit={crit:.3e}",
|
|
323
|
+
f"[newton] k={k:02d} START |R|inf={res_prev_inf_f:.3e} |R|2={res_prev_two_f:.3e} crit={crit:.3e}{contact_log}",
|
|
237
324
|
flush=True,
|
|
238
325
|
)
|
|
239
326
|
|
|
240
327
|
# --- Linear solve (J_free * du = -R_free) ---
|
|
328
|
+
t_rhs0 = time.perf_counter()
|
|
241
329
|
rhs = jnp.asarray(-R_free, dtype=u.dtype)
|
|
330
|
+
rhs_norm = jnp.linalg.norm(rhs)
|
|
331
|
+
rhs_norm_f = float(jax.block_until_ready(rhs_norm))
|
|
332
|
+
rhs_dt = time.perf_counter() - t_rhs0
|
|
333
|
+
if rhs_norm_f <= atol:
|
|
334
|
+
print(
|
|
335
|
+
f"[newton] k={k:02d} CONVERGED rhs<=atol ({rhs_norm_f:.3e} <= {atol:.3e})",
|
|
336
|
+
flush=True,
|
|
337
|
+
)
|
|
338
|
+
return expand_full(u), SolverResult(
|
|
339
|
+
converged=True,
|
|
340
|
+
iters=k,
|
|
341
|
+
stop_reason="rhs_atol",
|
|
342
|
+
nan_detected=False,
|
|
343
|
+
)
|
|
242
344
|
|
|
243
345
|
# Separate preconditioner build time from linear solve time.
|
|
244
346
|
t_pre0 = time.perf_counter()
|
|
@@ -247,7 +349,54 @@ def newton_solve(
|
|
|
247
349
|
linear_residual = None
|
|
248
350
|
lin_iters = None
|
|
249
351
|
|
|
250
|
-
|
|
352
|
+
linearize_dt = 0.0
|
|
353
|
+
if use_matfree:
|
|
354
|
+
if linear_preconditioner in ("jacobi", "block_jacobi"):
|
|
355
|
+
raise ValueError("cg_matfree does not support jacobi preconditioners")
|
|
356
|
+
if linear_preconditioner == "diag0":
|
|
357
|
+
cg_precon = matfree_precon
|
|
358
|
+
elif linear_preconditioner is not None and not callable(linear_preconditioner):
|
|
359
|
+
raise ValueError("cg_matfree preconditioner must be callable or None")
|
|
360
|
+
pre_dt = 0.0
|
|
361
|
+
if linear_preconditioner not in ("diag0", None):
|
|
362
|
+
cg_precon = linear_preconditioner
|
|
363
|
+
print(f"[linear] k={k:02d} {linear_solver}: linearize...", flush=True)
|
|
364
|
+
t_lin0 = time.perf_counter()
|
|
365
|
+
if matfree_mode == "linearize":
|
|
366
|
+
_res, lin_fun = jax.linearize(residual_free, u)
|
|
367
|
+
mv = lambda v: lin_fun(v)
|
|
368
|
+
else:
|
|
369
|
+
mv = lambda v: jax.jvp(residual_free, (u,), (v,))[1]
|
|
370
|
+
linearize_dt = time.perf_counter() - t_lin0
|
|
371
|
+
t_mv0 = 0.0
|
|
372
|
+
mv0_norm_f = None
|
|
373
|
+
t_mv0_0 = time.perf_counter()
|
|
374
|
+
mv0 = mv(rhs)
|
|
375
|
+
mv0_norm = jnp.linalg.norm(mv0)
|
|
376
|
+
mv0_norm_f = float(jax.block_until_ready(mv0_norm))
|
|
377
|
+
t_mv0 = time.perf_counter() - t_mv0_0
|
|
378
|
+
print(
|
|
379
|
+
f"[linear] k={k:02d} {linear_solver}: rhs_dt={rhs_dt:.3f}s "
|
|
380
|
+
f"mv0_dt={t_mv0:.3f}s ||b||={rhs_norm_f:.3e} ||Jb||={mv0_norm_f:.3e}",
|
|
381
|
+
flush=True,
|
|
382
|
+
)
|
|
383
|
+
cg_solver = cg_solve
|
|
384
|
+
print(f"[linear] k={k:02d} {linear_solver}: solve...", flush=True)
|
|
385
|
+
t_cg0 = time.perf_counter()
|
|
386
|
+
du_free, lin_info = cg_solver(
|
|
387
|
+
mv,
|
|
388
|
+
rhs,
|
|
389
|
+
tol=eff_linear_tol,
|
|
390
|
+
maxiter=linear_maxiter,
|
|
391
|
+
preconditioner=cg_precon,
|
|
392
|
+
)
|
|
393
|
+
du_free = jax.block_until_ready(du_free)
|
|
394
|
+
lin_dt = time.perf_counter() - t_cg0
|
|
395
|
+
linear_residual = lin_info.get("residual_norm")
|
|
396
|
+
linear_converged = bool(lin_info.get("converged", True))
|
|
397
|
+
lin_iters = lin_info.get("iters", None)
|
|
398
|
+
|
|
399
|
+
elif linear_solver in ("cg", "cg_jax", "cg_custom"):
|
|
251
400
|
# Preconditioner build
|
|
252
401
|
if linear_preconditioner == "jacobi":
|
|
253
402
|
print(f"[newton] k={k:02d} PRECOND jacobi: diag...", flush=True)
|
|
@@ -307,11 +456,18 @@ def newton_solve(
|
|
|
307
456
|
raise ValueError(f"Unknown linear solver: {linear_solver}")
|
|
308
457
|
|
|
309
458
|
lr = float(linear_residual) if linear_residual is not None else float("nan")
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
459
|
+
if use_matfree:
|
|
460
|
+
print(
|
|
461
|
+
f"[linear] k={k:02d} done iters={lin_iters} conv={linear_converged} lin_res={lr:.3e} "
|
|
462
|
+
f"linz_dt={linearize_dt:.3f}s cg_dt={lin_dt:.3f}s",
|
|
463
|
+
flush=True,
|
|
464
|
+
)
|
|
465
|
+
else:
|
|
466
|
+
print(
|
|
467
|
+
f"[linear] k={k:02d} done iters={lin_iters} conv={linear_converged} lin_res={lr:.3e} "
|
|
468
|
+
f"pre_dt={pre_dt:.3f}s lin_dt={lin_dt:.3f}s",
|
|
469
|
+
flush=True,
|
|
470
|
+
)
|
|
315
471
|
|
|
316
472
|
# --- Trial update & residual evaluation ---
|
|
317
473
|
# Start with alpha=1 and eval_residual (if heavy, assemble_R is heavy/compiled).
|
|
@@ -362,20 +518,21 @@ def newton_solve(
|
|
|
362
518
|
|
|
363
519
|
# callback
|
|
364
520
|
if callback is not None:
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
521
|
+
payload = {
|
|
522
|
+
"iter": k + 1,
|
|
523
|
+
"res_inf": res_trial_inf,
|
|
524
|
+
"res_two": res_trial_two,
|
|
525
|
+
"rel_residual": res_trial_inf / res0_inf,
|
|
526
|
+
"alpha": alpha,
|
|
527
|
+
"step_norm": step_norm,
|
|
528
|
+
"linear_iters": lin_info.get("iters"),
|
|
529
|
+
"linear_converged": linear_converged,
|
|
530
|
+
"linear_residual": lr,
|
|
531
|
+
"nan_detected": bool(np.isnan(res_trial_inf)),
|
|
532
|
+
}
|
|
533
|
+
if extra_metrics is not None:
|
|
534
|
+
payload["extra_metrics"] = extra_metrics
|
|
535
|
+
callback(payload)
|
|
379
536
|
|
|
380
537
|
# --- Convergence check ---
|
|
381
538
|
if res_trial_inf < crit and linear_converged and not np.isnan(res_trial_inf):
|
|
@@ -398,3 +555,23 @@ def newton_solve(
|
|
|
398
555
|
stop_reason="converged",
|
|
399
556
|
nan_detected=bool(np.isnan(res_trial_inf)),
|
|
400
557
|
)
|
|
558
|
+
|
|
559
|
+
res_final_inf = float(jnp.linalg.norm(R_free, ord=jnp.inf))
|
|
560
|
+
res_final_two = float(jnp.linalg.norm(R_free, ord=2))
|
|
561
|
+
return u_full, SolverResult(
|
|
562
|
+
converged=False,
|
|
563
|
+
iters=maxiter,
|
|
564
|
+
residual_norm=res_final_inf,
|
|
565
|
+
residual0=res0_inf,
|
|
566
|
+
rel_residual=(res_final_inf / res0_inf if res0_inf != 0.0 else float("inf")),
|
|
567
|
+
line_search_steps=0,
|
|
568
|
+
linear_iters=lin_info.get("iters"),
|
|
569
|
+
linear_converged=linear_converged,
|
|
570
|
+
linear_residual=lr,
|
|
571
|
+
tol=tol,
|
|
572
|
+
atol=atol,
|
|
573
|
+
stopping_criterion=crit,
|
|
574
|
+
step_norm=step_norm,
|
|
575
|
+
stop_reason="maxiter",
|
|
576
|
+
nan_detected=bool(np.isnan(res_final_inf)),
|
|
577
|
+
)
|