fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +69 -13
- fluxfem/core/__init__.py +140 -53
- fluxfem/core/assembly.py +691 -97
- fluxfem/core/basis.py +75 -54
- fluxfem/core/context_types.py +36 -12
- fluxfem/core/dtypes.py +9 -1
- fluxfem/core/forms.py +10 -0
- fluxfem/core/mixed_assembly.py +263 -0
- fluxfem/core/mixed_space.py +382 -0
- fluxfem/core/mixed_weakform.py +97 -0
- fluxfem/core/solver.py +2 -0
- fluxfem/core/space.py +315 -30
- fluxfem/core/weakform.py +821 -42
- fluxfem/helpers_wf.py +49 -0
- fluxfem/mesh/__init__.py +54 -2
- fluxfem/mesh/base.py +318 -9
- fluxfem/mesh/contact.py +841 -0
- fluxfem/mesh/dtypes.py +12 -0
- fluxfem/mesh/hex.py +17 -16
- fluxfem/mesh/io.py +9 -6
- fluxfem/mesh/mortar.py +3970 -0
- fluxfem/mesh/supermesh.py +318 -0
- fluxfem/mesh/surface.py +104 -26
- fluxfem/mesh/tet.py +16 -7
- fluxfem/physics/diffusion.py +3 -0
- fluxfem/physics/elasticity/hyperelastic.py +35 -3
- fluxfem/physics/elasticity/linear.py +22 -4
- fluxfem/physics/elasticity/stress.py +9 -5
- fluxfem/physics/operators.py +12 -5
- fluxfem/physics/postprocess.py +29 -3
- fluxfem/solver/__init__.py +47 -2
- fluxfem/solver/bc.py +38 -2
- fluxfem/solver/block_matrix.py +284 -0
- fluxfem/solver/block_system.py +477 -0
- fluxfem/solver/cg.py +150 -55
- fluxfem/solver/dirichlet.py +358 -5
- fluxfem/solver/history.py +15 -3
- fluxfem/solver/newton.py +260 -70
- fluxfem/solver/petsc.py +445 -0
- fluxfem/solver/preconditioner.py +109 -0
- fluxfem/solver/result.py +18 -0
- fluxfem/solver/solve_runner.py +208 -23
- fluxfem/solver/solver.py +35 -12
- fluxfem/solver/sparse.py +149 -15
- fluxfem/tools/jit.py +19 -7
- fluxfem/tools/timer.py +14 -12
- fluxfem/tools/visualizer.py +16 -4
- fluxfem-0.2.1.dist-info/METADATA +314 -0
- fluxfem-0.2.1.dist-info/RECORD +59 -0
- fluxfem-0.1.4.dist-info/METADATA +0 -127
- fluxfem-0.1.4.dist-info/RECORD +0 -48
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/solver/newton.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
import time
|
|
3
3
|
|
|
4
|
+
from typing import Any, Callable, Mapping, TYPE_CHECKING, TypeAlias
|
|
5
|
+
|
|
4
6
|
import numpy as np
|
|
5
7
|
import jax
|
|
6
8
|
import jax.numpy as jnp
|
|
@@ -8,38 +10,51 @@ import jax.numpy as jnp
|
|
|
8
10
|
from ..core.assembly import (
|
|
9
11
|
assemble_residual_scatter,
|
|
10
12
|
assemble_jacobian_scatter,
|
|
13
|
+
ResidualForm,
|
|
11
14
|
make_element_residual_kernel,
|
|
12
15
|
make_element_jacobian_kernel,
|
|
13
16
|
make_sparsity_pattern,
|
|
14
17
|
)
|
|
15
18
|
from ..core.solver import spdirect_solve_cpu, spdirect_solve_gpu
|
|
16
19
|
from .cg import cg_solve, cg_solve_jax
|
|
20
|
+
from .preconditioner import make_block_jacobi_preconditioner
|
|
17
21
|
from .result import SolverResult
|
|
18
22
|
from .sparse import SparsityPattern, FluxSparseMatrix
|
|
19
|
-
from .dirichlet import _normalize_dirichlet
|
|
23
|
+
from .dirichlet import DirichletBC, _normalize_dirichlet
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from jax import Array as JaxArray
|
|
27
|
+
|
|
28
|
+
ArrayLike: TypeAlias = np.ndarray | JaxArray
|
|
29
|
+
else:
|
|
30
|
+
ArrayLike: TypeAlias = np.ndarray
|
|
31
|
+
ExtraTerm: TypeAlias = Callable[[np.ndarray], tuple[np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray, Mapping[str, Any]] | None]
|
|
20
32
|
|
|
21
33
|
|
|
22
34
|
def newton_solve(
|
|
23
35
|
space,
|
|
24
|
-
res_form,
|
|
25
|
-
u0,
|
|
26
|
-
params,
|
|
36
|
+
res_form: ResidualForm[Any],
|
|
37
|
+
u0: ArrayLike,
|
|
38
|
+
params: Any,
|
|
27
39
|
*,
|
|
28
40
|
tol: float = 1e-8,
|
|
29
41
|
atol: float = 0.0,
|
|
30
42
|
maxiter: int = 20,
|
|
31
|
-
linear_solver: str = "spsolve", # "spsolve", "spdirect_solve_gpu", "cg"
|
|
43
|
+
linear_solver: str = "spsolve", # "spsolve", "spdirect_solve_gpu", "cg", "cg_jax", "cg_custom", or "cg_matfree"
|
|
32
44
|
linear_maxiter: int | None = None,
|
|
33
45
|
linear_tol: float | None = None,
|
|
34
|
-
linear_preconditioner=None,
|
|
35
|
-
|
|
36
|
-
|
|
46
|
+
linear_preconditioner: object | None = None,
|
|
47
|
+
matfree_mode: str = "linearize",
|
|
48
|
+
matfree_cache: dict[str, Any] | None = None,
|
|
49
|
+
dirichlet: tuple[np.ndarray, np.ndarray] | None = None,
|
|
50
|
+
callback: Callable[[np.ndarray, SolverResult], Any] | None = None,
|
|
37
51
|
line_search: bool = False,
|
|
38
52
|
max_ls: int = 10,
|
|
39
53
|
ls_c: float = 1e-4,
|
|
40
|
-
external_vector=None,
|
|
41
|
-
jacobian_pattern=None,
|
|
42
|
-
|
|
54
|
+
external_vector: np.ndarray | None = None,
|
|
55
|
+
jacobian_pattern: SparsityPattern | None = None,
|
|
56
|
+
extra_terms: list[ExtraTerm] | None = None,
|
|
57
|
+
) -> tuple[np.ndarray, SolverResult]:
|
|
43
58
|
"""
|
|
44
59
|
Gridap-style Newton–Raphson solver on free DOFs only.
|
|
45
60
|
|
|
@@ -48,12 +63,17 @@ def newton_solve(
|
|
|
48
63
|
- Convergence: ||R_free||_inf < max(atol, tol * ||R_free0||_inf).
|
|
49
64
|
- external_vector: optional global RHS (internal - external).
|
|
50
65
|
- CG path accepts an operator with matvec that acts on free DOFs via a wrapper.
|
|
51
|
-
-
|
|
66
|
+
- cg_matfree uses JVP/linearize to form a matrix-free matvec (no global Jacobian).
|
|
67
|
+
- linear_preconditioner: forwarded to cg_solve/cg_solve_jax (None | "jacobi" | "block_jacobi" | "diag0" | callable).
|
|
68
|
+
- matfree_cache: optional dict for reusing matrix-free preconditioners across calls.
|
|
52
69
|
- linear_tol: CG tolerance (defaults to 0.1 * tol if not provided).
|
|
53
70
|
- jacobian_pattern: optional SparsityPattern to reuse sparsity across load steps.
|
|
71
|
+
- extra_terms: optional list of callbacks returning (K, f[, metrics]) for extra terms.
|
|
54
72
|
"""
|
|
55
73
|
|
|
56
74
|
if dirichlet is not None:
|
|
75
|
+
if isinstance(dirichlet, DirichletBC):
|
|
76
|
+
dirichlet = dirichlet.as_tuple()
|
|
57
77
|
dir_dofs, dir_vals = dirichlet
|
|
58
78
|
dir_dofs, dir_vals = _normalize_dirichlet(dir_dofs, dir_vals)
|
|
59
79
|
if dir_vals.ndim == 0:
|
|
@@ -66,6 +86,13 @@ def newton_solve(
|
|
|
66
86
|
dir_dofs = dir_vals = None
|
|
67
87
|
free_dofs = np.arange(space.n_dofs, dtype=int)
|
|
68
88
|
|
|
89
|
+
use_matfree = linear_solver in ("cg_matfree", "cg_jvp")
|
|
90
|
+
if use_matfree and matfree_mode not in ("linearize", "jvp"):
|
|
91
|
+
raise ValueError("matfree_mode must be 'linearize' or 'jvp'")
|
|
92
|
+
|
|
93
|
+
if extra_terms is not None and linear_solver in ("cg", "cg_jax", "cg_custom", "cg_matfree", "cg_jvp"):
|
|
94
|
+
raise ValueError("extra_terms may yield nonsymmetric K; avoid CG-based solvers")
|
|
95
|
+
|
|
69
96
|
free_dofs_j = jnp.asarray(free_dofs, dtype=jnp.int32)
|
|
70
97
|
# For block-Jacobi (3x3 per node) we keep node ids of free dofs.
|
|
71
98
|
node_ids = free_dofs // 3
|
|
@@ -109,33 +136,7 @@ def newton_solve(
|
|
|
109
136
|
Build 3x3 block-Jacobi inverse per free node.
|
|
110
137
|
Assumes DOF ordering per node is [ux, uy, uz].
|
|
111
138
|
"""
|
|
112
|
-
|
|
113
|
-
raise ValueError("block_jacobi assumes 3 DOFs per node.")
|
|
114
|
-
rows = np.asarray(J_free.rows)
|
|
115
|
-
cols = np.asarray(J_free.cols)
|
|
116
|
-
data = np.asarray(J_free.data)
|
|
117
|
-
block_rows = node_ids_inv[rows]
|
|
118
|
-
block_cols = node_ids_inv[cols]
|
|
119
|
-
local_r = rows % 3
|
|
120
|
-
local_c = cols % 3
|
|
121
|
-
mask_blk = block_rows == block_cols
|
|
122
|
-
blk_rows = block_rows[mask_blk]
|
|
123
|
-
blk_lr = local_r[mask_blk]
|
|
124
|
-
blk_lc = local_c[mask_blk]
|
|
125
|
-
blk_data = data[mask_blk]
|
|
126
|
-
inv_blocks = np.zeros((n_block, 3, 3), dtype=blk_data.dtype)
|
|
127
|
-
inv_blocks[blk_rows, blk_lr, blk_lc] += blk_data
|
|
128
|
-
inv_blocks = jnp.asarray(inv_blocks)
|
|
129
|
-
# Add tiny damping to avoid singular blocks
|
|
130
|
-
inv_blocks = inv_blocks + 1e-12 * jnp.eye(3)[None, :, :]
|
|
131
|
-
inv_blocks = jnp.linalg.inv(inv_blocks)
|
|
132
|
-
|
|
133
|
-
def precon(r):
|
|
134
|
-
r_blocks = r.reshape((n_block, 3))
|
|
135
|
-
z_blocks = jnp.einsum("bij,bj->bi", inv_blocks, r_blocks)
|
|
136
|
-
return z_blocks.reshape((-1,))
|
|
137
|
-
|
|
138
|
-
return precon
|
|
139
|
+
return make_block_jacobi_preconditioner(J_free, dof_per_node=3)
|
|
139
140
|
|
|
140
141
|
def expand_full(u_free: jnp.ndarray) -> jnp.ndarray:
|
|
141
142
|
if dir_dofs is None:
|
|
@@ -145,6 +146,39 @@ def newton_solve(
|
|
|
145
146
|
u_full = u_full.at[dir_dofs_j].set(dir_vals_j)
|
|
146
147
|
return u_full
|
|
147
148
|
|
|
149
|
+
extra_metrics = None
|
|
150
|
+
|
|
151
|
+
def _call_extra(u_full_vec):
    """
    Evaluate every callback in ``extra_terms`` at the given full DOF vector.

    Each callback is called with ``np.asarray(u_full_vec)`` and may return
    ``None`` (the callback is skipped), ``(K, f)`` or ``(K, f, metrics)``.
    ``K`` and ``f`` are converted to float NumPy arrays and summed over all
    contributing callbacks. Mapping-valued ``metrics`` are merged into one
    dict; on duplicate keys the later callback wins.

    Returns:
        ``None`` if ``extra_terms`` is None or no callback contributed,
        otherwise ``(K_sum, f_sum, metrics)`` where ``metrics`` is ``None``
        when no callback reported any metrics.

    Raises:
        ValueError: if a callback returns a sequence whose length is not 2 or 3.
    """
    if extra_terms is None:
        return None

    # The input is the same for every callback, so convert it only once.
    u_np = np.asarray(u_full_vec)

    K_sum = None
    f_sum = None
    metrics_sum = {}
    for term in extra_terms:
        out = term(u_np)
        if out is None:
            continue
        if len(out) == 2:
            Kc, fc = out
            metrics = None
        elif len(out) == 3:
            Kc, fc, metrics = out
        else:
            raise ValueError(
                "extra_terms callbacks must return None, (K, f) or (K, f, metrics); "
                f"got a sequence of length {len(out)}"
            )
        Kc = np.asarray(Kc, dtype=float)
        fc = np.asarray(fc, dtype=float)
        K_sum = Kc if K_sum is None else K_sum + Kc
        f_sum = fc if f_sum is None else f_sum + fc
        # Accept any Mapping (not only dict) so that metrics typed as
        # Mapping[str, Any] are not silently discarded.
        if isinstance(metrics, Mapping):
            metrics_sum.update(metrics)
    if K_sum is None or f_sum is None:
        return None
    return K_sum, f_sum, (metrics_sum or None)
|
|
181
|
+
|
|
148
182
|
def eval_residual(u_free_vec):
|
|
149
183
|
"""Residual on free DOFs only."""
|
|
150
184
|
u_full = expand_full(u_free_vec)
|
|
@@ -156,17 +190,32 @@ def newton_solve(
|
|
|
156
190
|
res_two = float(jnp.linalg.norm(R_free, ord=2))
|
|
157
191
|
return R_free, res_inf, res_two, u_full
|
|
158
192
|
|
|
193
|
+
def residual_free(u_free_vec):
    """
    Residual restricted to the free DOFs for a given free-DOF vector.

    The vector is expanded with ``expand_full``, the residual is assembled
    with ``assemble_R``, ``external_vector`` is subtracted when one was
    supplied, and the result is indexed with ``free_dofs_j``.
    """
    full_residual = assemble_R(expand_full(u_free_vec))
    if external_vector is None:
        return full_residual[free_dofs_j]
    return (full_residual - external_vector)[free_dofs_j]
|
|
199
|
+
|
|
159
200
|
# Pre-jitted element kernels to avoid recompiling inside Newton
|
|
160
201
|
res_kernel = make_element_residual_kernel(res_form, params)
|
|
161
202
|
jac_kernel = make_element_jacobian_kernel(res_form, params)
|
|
162
203
|
|
|
163
204
|
def assemble_R(u_full_vec):
|
|
164
|
-
|
|
205
|
+
nonlocal extra_metrics
|
|
206
|
+
R = assemble_residual_scatter(space, res_form, u_full_vec, params, kernel=res_kernel)
|
|
207
|
+
extra_out = _call_extra(u_full_vec)
|
|
208
|
+
if extra_out is not None:
|
|
209
|
+
_Kc, fc, metrics = extra_out
|
|
210
|
+
extra_metrics = metrics
|
|
211
|
+
R = R + jnp.asarray(fc, dtype=R.dtype)
|
|
212
|
+
return R
|
|
165
213
|
|
|
166
214
|
eff_linear_tol = linear_tol if linear_tol is not None else max(0.1 * tol, 1e-12)
|
|
167
215
|
|
|
168
216
|
def assemble_J(u_full_vec):
|
|
169
|
-
|
|
217
|
+
nonlocal extra_metrics
|
|
218
|
+
J = assemble_jacobian_scatter(
|
|
170
219
|
space,
|
|
171
220
|
res_form,
|
|
172
221
|
u_full_vec,
|
|
@@ -176,6 +225,40 @@ def newton_solve(
|
|
|
176
225
|
return_flux_matrix=True,
|
|
177
226
|
pattern=J_pattern,
|
|
178
227
|
)
|
|
228
|
+
extra_out = _call_extra(u_full_vec)
|
|
229
|
+
if extra_out is not None:
|
|
230
|
+
Kc, _fc, metrics = extra_out
|
|
231
|
+
extra_metrics = metrics
|
|
232
|
+
rows = np.asarray(J.pattern.rows, dtype=int)
|
|
233
|
+
cols = np.asarray(J.pattern.cols, dtype=int)
|
|
234
|
+
data = jnp.asarray(J.data) + jnp.asarray(Kc[rows, cols], dtype=J.data.dtype)
|
|
235
|
+
J = J.with_data(data)
|
|
236
|
+
return J
|
|
237
|
+
|
|
238
|
+
matfree_precon = None
|
|
239
|
+
if use_matfree and linear_preconditioner == "diag0":
|
|
240
|
+
cached = matfree_cache.get("inv_diag0") if matfree_cache is not None else None
|
|
241
|
+
if cached is not None:
|
|
242
|
+
inv_diag0 = cached
|
|
243
|
+
matfree_precon = lambda r: inv_diag0 * r
|
|
244
|
+
print("[PRECOND] reuse diag0", flush=True)
|
|
245
|
+
else:
|
|
246
|
+
print("[PRECOND] build diag0", flush=True)
|
|
247
|
+
t_pre0 = time.perf_counter()
|
|
248
|
+
J0 = assemble_J(expand_full(u))
|
|
249
|
+
J0_free = restrict_free_matrix(J0)
|
|
250
|
+
diag0 = jnp.asarray(J0_free.diag(), dtype=u.dtype)
|
|
251
|
+
diag0 = jax.block_until_ready(diag0)
|
|
252
|
+
inv_diag0 = jnp.where(diag0 != 0.0, 1.0 / diag0, 0.0)
|
|
253
|
+
|
|
254
|
+
def precon(r):
|
|
255
|
+
return inv_diag0 * r
|
|
256
|
+
|
|
257
|
+
matfree_precon = precon
|
|
258
|
+
if matfree_cache is not None:
|
|
259
|
+
matfree_cache["inv_diag0"] = inv_diag0
|
|
260
|
+
pre_dt0 = time.perf_counter() - t_pre0
|
|
261
|
+
print(f"[PRECOND] diag0 ready dt={pre_dt0:.3f}s", flush=True)
|
|
179
262
|
|
|
180
263
|
# Initial residual/Jacobian
|
|
181
264
|
R_full_init = assemble_R(expand_full(u))
|
|
@@ -210,14 +293,25 @@ def newton_solve(
|
|
|
210
293
|
)
|
|
211
294
|
|
|
212
295
|
if callback is not None:
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
296
|
+
payload = {"iter": 0, "res_inf": res0_inf, "res_two": res0_two, "rel_residual": 1.0, "alpha": 1.0, "step_norm": np.nan}
|
|
297
|
+
if extra_metrics is not None:
|
|
298
|
+
payload["extra_metrics"] = extra_metrics
|
|
299
|
+
callback(payload)
|
|
300
|
+
|
|
301
|
+
if not use_matfree:
|
|
302
|
+
J = assemble_J(u_full)
|
|
303
|
+
finite_j = jnp.all(jnp.isfinite(J.data))
|
|
304
|
+
if not bool(jax.block_until_ready(finite_j)):
|
|
305
|
+
n_bad = int(jnp.size(J.data) - jnp.count_nonzero(jnp.isfinite(J.data)))
|
|
306
|
+
raise RuntimeError(f"[newton] init Jacobian has non-finite entries: {n_bad}")
|
|
307
|
+
J_free = restrict_free_matrix(J)
|
|
308
|
+
else:
|
|
309
|
+
J = None
|
|
310
|
+
J_free = None
|
|
311
|
+
lin_info = {}
|
|
312
|
+
step_norm = float("nan")
|
|
313
|
+
linear_converged = True
|
|
314
|
+
lr = None
|
|
221
315
|
for k in range(maxiter):
|
|
222
316
|
# --- Newton residual (iteration start) ---
|
|
223
317
|
t_iter0 = time.perf_counter()
|
|
@@ -232,13 +326,34 @@ def newton_solve(
|
|
|
232
326
|
raise RuntimeError("[newton] residual became non-finite; aborting.")
|
|
233
327
|
|
|
234
328
|
crit = max(atol, tol * res0_inf)
|
|
329
|
+
contact_log = ""
|
|
330
|
+
if extra_metrics is not None and isinstance(extra_metrics, dict):
|
|
331
|
+
min_g = extra_metrics.get("min_g")
|
|
332
|
+
pen = extra_metrics.get("penetration")
|
|
333
|
+
if min_g is not None and pen is not None:
|
|
334
|
+
contact_log = f" min_g={float(min_g):.3e} pen={float(pen):.3e}"
|
|
235
335
|
print(
|
|
236
|
-
f"[newton] k={k:02d} START |R|inf={res_prev_inf_f:.3e} |R|2={res_prev_two_f:.3e} crit={crit:.3e}",
|
|
336
|
+
f"[newton] k={k:02d} START |R|inf={res_prev_inf_f:.3e} |R|2={res_prev_two_f:.3e} crit={crit:.3e}{contact_log}",
|
|
237
337
|
flush=True,
|
|
238
338
|
)
|
|
239
339
|
|
|
240
340
|
# --- Linear solve (J_free * du = -R_free) ---
|
|
341
|
+
t_rhs0 = time.perf_counter()
|
|
241
342
|
rhs = jnp.asarray(-R_free, dtype=u.dtype)
|
|
343
|
+
rhs_norm = jnp.linalg.norm(rhs)
|
|
344
|
+
rhs_norm_f = float(jax.block_until_ready(rhs_norm))
|
|
345
|
+
rhs_dt = time.perf_counter() - t_rhs0
|
|
346
|
+
if rhs_norm_f <= atol:
|
|
347
|
+
print(
|
|
348
|
+
f"[newton] k={k:02d} CONVERGED rhs<=atol ({rhs_norm_f:.3e} <= {atol:.3e})",
|
|
349
|
+
flush=True,
|
|
350
|
+
)
|
|
351
|
+
return expand_full(u), SolverResult(
|
|
352
|
+
converged=True,
|
|
353
|
+
iters=k,
|
|
354
|
+
stop_reason="rhs_atol",
|
|
355
|
+
nan_detected=False,
|
|
356
|
+
)
|
|
242
357
|
|
|
243
358
|
# Separate preconditioner build time from linear solve time.
|
|
244
359
|
t_pre0 = time.perf_counter()
|
|
@@ -247,7 +362,54 @@ def newton_solve(
|
|
|
247
362
|
linear_residual = None
|
|
248
363
|
lin_iters = None
|
|
249
364
|
|
|
250
|
-
|
|
365
|
+
linearize_dt = 0.0
|
|
366
|
+
if use_matfree:
|
|
367
|
+
if linear_preconditioner in ("jacobi", "block_jacobi"):
|
|
368
|
+
raise ValueError("cg_matfree does not support jacobi preconditioners")
|
|
369
|
+
if linear_preconditioner == "diag0":
|
|
370
|
+
cg_precon = matfree_precon
|
|
371
|
+
elif linear_preconditioner is not None and not callable(linear_preconditioner):
|
|
372
|
+
raise ValueError("cg_matfree preconditioner must be callable or None")
|
|
373
|
+
pre_dt = 0.0
|
|
374
|
+
if linear_preconditioner not in ("diag0", None):
|
|
375
|
+
cg_precon = linear_preconditioner
|
|
376
|
+
print(f"[linear] k={k:02d} {linear_solver}: linearize...", flush=True)
|
|
377
|
+
t_lin0 = time.perf_counter()
|
|
378
|
+
if matfree_mode == "linearize":
|
|
379
|
+
_res, lin_fun = jax.linearize(residual_free, u)
|
|
380
|
+
mv = lambda v: lin_fun(v)
|
|
381
|
+
else:
|
|
382
|
+
mv = lambda v: jax.jvp(residual_free, (u,), (v,))[1]
|
|
383
|
+
linearize_dt = time.perf_counter() - t_lin0
|
|
384
|
+
t_mv0 = 0.0
|
|
385
|
+
mv0_norm_f = None
|
|
386
|
+
t_mv0_0 = time.perf_counter()
|
|
387
|
+
mv0 = mv(rhs)
|
|
388
|
+
mv0_norm = jnp.linalg.norm(mv0)
|
|
389
|
+
mv0_norm_f = float(jax.block_until_ready(mv0_norm))
|
|
390
|
+
t_mv0 = time.perf_counter() - t_mv0_0
|
|
391
|
+
print(
|
|
392
|
+
f"[linear] k={k:02d} {linear_solver}: rhs_dt={rhs_dt:.3f}s "
|
|
393
|
+
f"mv0_dt={t_mv0:.3f}s ||b||={rhs_norm_f:.3e} ||Jb||={mv0_norm_f:.3e}",
|
|
394
|
+
flush=True,
|
|
395
|
+
)
|
|
396
|
+
cg_solver = cg_solve
|
|
397
|
+
print(f"[linear] k={k:02d} {linear_solver}: solve...", flush=True)
|
|
398
|
+
t_cg0 = time.perf_counter()
|
|
399
|
+
du_free, lin_info = cg_solver(
|
|
400
|
+
mv,
|
|
401
|
+
rhs,
|
|
402
|
+
tol=eff_linear_tol,
|
|
403
|
+
maxiter=linear_maxiter,
|
|
404
|
+
preconditioner=cg_precon,
|
|
405
|
+
)
|
|
406
|
+
du_free = jax.block_until_ready(du_free)
|
|
407
|
+
lin_dt = time.perf_counter() - t_cg0
|
|
408
|
+
linear_residual = lin_info.get("residual_norm")
|
|
409
|
+
linear_converged = bool(lin_info.get("converged", True))
|
|
410
|
+
lin_iters = lin_info.get("iters", None)
|
|
411
|
+
|
|
412
|
+
elif linear_solver in ("cg", "cg_jax", "cg_custom"):
|
|
251
413
|
# Preconditioner build
|
|
252
414
|
if linear_preconditioner == "jacobi":
|
|
253
415
|
print(f"[newton] k={k:02d} PRECOND jacobi: diag...", flush=True)
|
|
@@ -307,11 +469,18 @@ def newton_solve(
|
|
|
307
469
|
raise ValueError(f"Unknown linear solver: {linear_solver}")
|
|
308
470
|
|
|
309
471
|
lr = float(linear_residual) if linear_residual is not None else float("nan")
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
472
|
+
if use_matfree:
|
|
473
|
+
print(
|
|
474
|
+
f"[linear] k={k:02d} done iters={lin_iters} conv={linear_converged} lin_res={lr:.3e} "
|
|
475
|
+
f"linz_dt={linearize_dt:.3f}s cg_dt={lin_dt:.3f}s",
|
|
476
|
+
flush=True,
|
|
477
|
+
)
|
|
478
|
+
else:
|
|
479
|
+
print(
|
|
480
|
+
f"[linear] k={k:02d} done iters={lin_iters} conv={linear_converged} lin_res={lr:.3e} "
|
|
481
|
+
f"pre_dt={pre_dt:.3f}s lin_dt={lin_dt:.3f}s",
|
|
482
|
+
flush=True,
|
|
483
|
+
)
|
|
315
484
|
|
|
316
485
|
# --- Trial update & residual evaluation ---
|
|
317
486
|
# Start with alpha=1 and eval_residual (if heavy, assemble_R is heavy/compiled).
|
|
@@ -362,20 +531,21 @@ def newton_solve(
|
|
|
362
531
|
|
|
363
532
|
# callback
|
|
364
533
|
if callback is not None:
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
534
|
+
payload = {
|
|
535
|
+
"iter": k + 1,
|
|
536
|
+
"res_inf": res_trial_inf,
|
|
537
|
+
"res_two": res_trial_two,
|
|
538
|
+
"rel_residual": res_trial_inf / res0_inf,
|
|
539
|
+
"alpha": alpha,
|
|
540
|
+
"step_norm": step_norm,
|
|
541
|
+
"linear_iters": lin_info.get("iters"),
|
|
542
|
+
"linear_converged": linear_converged,
|
|
543
|
+
"linear_residual": lr,
|
|
544
|
+
"nan_detected": bool(np.isnan(res_trial_inf)),
|
|
545
|
+
}
|
|
546
|
+
if extra_metrics is not None:
|
|
547
|
+
payload["extra_metrics"] = extra_metrics
|
|
548
|
+
callback(payload)
|
|
379
549
|
|
|
380
550
|
# --- Convergence check ---
|
|
381
551
|
if res_trial_inf < crit and linear_converged and not np.isnan(res_trial_inf):
|
|
@@ -398,3 +568,23 @@ def newton_solve(
|
|
|
398
568
|
stop_reason="converged",
|
|
399
569
|
nan_detected=bool(np.isnan(res_trial_inf)),
|
|
400
570
|
)
|
|
571
|
+
|
|
572
|
+
res_final_inf = float(jnp.linalg.norm(R_free, ord=jnp.inf))
|
|
573
|
+
res_final_two = float(jnp.linalg.norm(R_free, ord=2))
|
|
574
|
+
return u_full, SolverResult(
|
|
575
|
+
converged=False,
|
|
576
|
+
iters=maxiter,
|
|
577
|
+
residual_norm=res_final_inf,
|
|
578
|
+
residual0=res0_inf,
|
|
579
|
+
rel_residual=(res_final_inf / res0_inf if res0_inf != 0.0 else float("inf")),
|
|
580
|
+
line_search_steps=0,
|
|
581
|
+
linear_iters=lin_info.get("iters"),
|
|
582
|
+
linear_converged=linear_converged,
|
|
583
|
+
linear_residual=lr,
|
|
584
|
+
tol=tol,
|
|
585
|
+
atol=atol,
|
|
586
|
+
stopping_criterion=crit,
|
|
587
|
+
step_norm=step_norm,
|
|
588
|
+
stop_reason="maxiter",
|
|
589
|
+
nan_detected=bool(np.isnan(res_final_inf)),
|
|
590
|
+
)
|