fluxfem 0.1.3a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fluxfem might be problematic. Click here for more details.
- fluxfem/__init__.py +343 -0
- fluxfem/core/__init__.py +318 -0
- fluxfem/core/assembly.py +788 -0
- fluxfem/core/basis.py +996 -0
- fluxfem/core/data.py +64 -0
- fluxfem/core/dtypes.py +4 -0
- fluxfem/core/forms.py +234 -0
- fluxfem/core/interp.py +55 -0
- fluxfem/core/solver.py +113 -0
- fluxfem/core/space.py +419 -0
- fluxfem/core/weakform.py +828 -0
- fluxfem/helpers_ts.py +11 -0
- fluxfem/helpers_wf.py +44 -0
- fluxfem/mesh/__init__.py +29 -0
- fluxfem/mesh/base.py +244 -0
- fluxfem/mesh/hex.py +327 -0
- fluxfem/mesh/io.py +87 -0
- fluxfem/mesh/predicate.py +45 -0
- fluxfem/mesh/surface.py +257 -0
- fluxfem/mesh/tet.py +246 -0
- fluxfem/physics/__init__.py +53 -0
- fluxfem/physics/diffusion.py +18 -0
- fluxfem/physics/elasticity/__init__.py +39 -0
- fluxfem/physics/elasticity/hyperelastic.py +99 -0
- fluxfem/physics/elasticity/linear.py +58 -0
- fluxfem/physics/elasticity/materials.py +32 -0
- fluxfem/physics/elasticity/stress.py +46 -0
- fluxfem/physics/operators.py +109 -0
- fluxfem/physics/postprocess.py +113 -0
- fluxfem/solver/__init__.py +47 -0
- fluxfem/solver/bc.py +439 -0
- fluxfem/solver/cg.py +326 -0
- fluxfem/solver/dirichlet.py +126 -0
- fluxfem/solver/history.py +31 -0
- fluxfem/solver/newton.py +400 -0
- fluxfem/solver/result.py +62 -0
- fluxfem/solver/solve_runner.py +534 -0
- fluxfem/solver/solver.py +148 -0
- fluxfem/solver/sparse.py +188 -0
- fluxfem/tools/__init__.py +7 -0
- fluxfem/tools/jit.py +51 -0
- fluxfem/tools/timer.py +659 -0
- fluxfem/tools/visualizer.py +101 -0
- fluxfem-0.1.3a0.dist-info/LICENSE +201 -0
- fluxfem-0.1.3a0.dist-info/METADATA +125 -0
- fluxfem-0.1.3a0.dist-info/RECORD +47 -0
- fluxfem-0.1.3a0.dist-info/WHEEL +4 -0
fluxfem/solver/newton.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import time
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
|
|
8
|
+
from ..core.assembly import (
|
|
9
|
+
assemble_residual_scatter,
|
|
10
|
+
assemble_jacobian_scatter,
|
|
11
|
+
make_element_residual_kernel,
|
|
12
|
+
make_element_jacobian_kernel,
|
|
13
|
+
make_sparsity_pattern,
|
|
14
|
+
)
|
|
15
|
+
from ..core.solver import spdirect_solve_cpu, spdirect_solve_gpu
|
|
16
|
+
from .cg import cg_solve, cg_solve_jax
|
|
17
|
+
from .result import SolverResult
|
|
18
|
+
from .sparse import SparsityPattern, FluxSparseMatrix
|
|
19
|
+
from .dirichlet import _normalize_dirichlet
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def newton_solve(
    space,
    res_form,
    u0,
    params,
    *,
    tol: float = 1e-8,
    atol: float = 0.0,
    maxiter: int = 20,
    linear_solver: str = "spsolve",  # "spsolve", "spdirect_solve_gpu", "cg" (jax), "cg_jax", or "cg_custom"
    linear_maxiter: int | None = None,
    linear_tol: float | None = None,
    linear_preconditioner=None,
    dirichlet=None,
    callback=None,
    line_search: bool = False,
    max_ls: int = 10,
    ls_c: float = 1e-4,
    external_vector=None,
    jacobian_pattern=None,
):
    """
    Gridap-style Newton–Raphson solver on free DOFs only.

    - Unknown vector = free DOFs (Dirichlet eliminated).
    - Residual/Jacobian are assembled on full DOFs; we slice to free DOFs.
    - Convergence: ||R_free||_inf < max(atol, tol * ||R_free0||_inf).
    - external_vector: optional global RHS (internal - external).
    - CG path accepts an operator with matvec that acts on free DOFs via a wrapper.
    - linear_preconditioner: forwarded to cg_solve/cg_solve_jax (None | "jacobi" | "block_jacobi" | callable).
    - linear_tol: CG tolerance (defaults to max(0.1 * tol, 1e-12) if not provided).
    - jacobian_pattern: optional SparsityPattern to reuse sparsity across load steps.
    - dirichlet: optional ``(dofs, values)`` pair; a scalar value is broadcast to all listed DOFs.
    - callback: optional callable receiving one dict per iteration (iteration 0 = initial state).
    - line_search / max_ls / ls_c: backtracking (step halving) on the residual 2-norm with
      the sufficient-decrease test ``|R_trial|_2 <= (1 - ls_c * alpha) * |R_prev|_2``.

    Returns
    -------
    (u_full, SolverResult)
        The full-length DOF vector (Dirichlet values inserted) and a summary of the solve.
        If the iteration budget is exhausted without meeting the criterion, the last iterate
        is returned with ``converged=False`` and ``stop_reason`` set to ``"maxiter"``,
        ``"linfail"`` or ``"nan"``.

    Raises
    ------
    RuntimeError
        If the initial residual/Jacobian, a re-assembled Jacobian, or the residual at the
        start of an iteration contains non-finite entries.
    ValueError
        If ``linear_solver`` is unknown, or ``"block_jacobi"`` is requested while some node
        does not have exactly 3 free DOFs.
    """

    # ------------------------------------------------------------------
    # Split DOFs into Dirichlet-constrained and free sets.
    # ------------------------------------------------------------------
    if dirichlet is not None:
        dir_dofs, dir_vals = dirichlet
        dir_dofs, dir_vals = _normalize_dirichlet(dir_dofs, dir_vals)
        if dir_vals.ndim == 0:
            # Scalar prescribed value: broadcast to every constrained DOF.
            dir_vals = np.full(dir_dofs.shape[0], float(dir_vals))
        all_dofs = np.arange(space.n_dofs, dtype=int)
        mask = np.ones(space.n_dofs, dtype=bool)
        mask[dir_dofs] = False
        free_dofs = all_dofs[mask]
    else:
        dir_dofs = dir_vals = None
        free_dofs = np.arange(space.n_dofs, dtype=int)

    free_dofs_j = jnp.asarray(free_dofs, dtype=jnp.int32)
    # For block-Jacobi (3x3 per node) we keep node ids of free dofs.
    node_ids = free_dofs // 3
    node_ids_unique, node_ids_inv = np.unique(node_ids, return_inverse=True)
    n_block = len(node_ids_unique)
    dir_dofs_j = jnp.asarray(dir_dofs, dtype=jnp.int32) if dir_dofs is not None else None
    dir_vals_j = jnp.asarray(dir_vals, dtype=jnp.asarray(u0).dtype) if dir_vals is not None else None

    # Unknown is free DOFs only
    u = jnp.asarray(u0)[free_dofs]

    # Sparsity pattern does not depend on u; cache once
    J_pattern = jacobian_pattern if jacobian_pattern is not None else make_sparsity_pattern(
        space, with_idx=True
    )

    # Build free-DOF subpattern once to avoid scatter/gather in every matvec.
    # free_map sends a global DOF to its index among the free DOFs (-1 if constrained).
    free_map = -np.ones(space.n_dofs, dtype=np.int32)
    free_map[free_dofs] = np.arange(len(free_dofs), dtype=np.int32)
    pat_rows = np.asarray(J_pattern.rows)
    pat_cols = np.asarray(J_pattern.cols)
    mask_free = (free_map[pat_rows] >= 0) & (free_map[pat_cols] >= 0)
    free_data_idx = jnp.asarray(np.nonzero(mask_free)[0], dtype=jnp.int32)
    rows_f = free_map[pat_rows[mask_free]]
    cols_f = free_map[pat_cols[mask_free]]
    diag_idx_f = np.nonzero(rows_f == cols_f)[0].astype(np.int32)
    J_free_pattern = SparsityPattern(
        rows=jnp.asarray(rows_f, dtype=jnp.int32),
        cols=jnp.asarray(cols_f, dtype=jnp.int32),
        n_dofs=int(len(free_dofs)),
        idx=None,
        diag_idx=jnp.asarray(diag_idx_f, dtype=jnp.int32),
    )

    def restrict_free_matrix(J: FluxSparseMatrix) -> FluxSparseMatrix:
        """Keep only the entries of J whose row and column are both free DOFs."""
        data_f = jnp.asarray(J.data)[free_data_idx]
        return FluxSparseMatrix(J_free_pattern, data_f)

    def build_block_jacobi(J_free: FluxSparseMatrix):
        """
        Build 3x3 block-Jacobi inverse per free node.
        Assumes DOF ordering per node is [ux, uy, uz].
        """
        # Every free node must own exactly 3 free DOFs; otherwise the (n_block, 3)
        # reshape below and the `index % 3` local numbering are both invalid.
        if n_block * 3 != len(free_dofs):
            raise ValueError("block_jacobi assumes 3 DOFs per node.")
        rows = np.asarray(J_free.rows)
        cols = np.asarray(J_free.cols)
        data = np.asarray(J_free.data)
        block_rows = node_ids_inv[rows]
        block_cols = node_ids_inv[cols]
        local_r = rows % 3
        local_c = cols % 3
        mask_blk = block_rows == block_cols
        blk_rows = block_rows[mask_blk]
        blk_lr = local_r[mask_blk]
        blk_lc = local_c[mask_blk]
        blk_data = data[mask_blk]
        inv_blocks = np.zeros((n_block, 3, 3), dtype=blk_data.dtype)
        # NOTE(review): fancy-index `+=` does not accumulate repeated (row, col) entries;
        # this presumes the pattern holds each matrix position once — confirm.
        inv_blocks[blk_rows, blk_lr, blk_lc] += blk_data
        inv_blocks = jnp.asarray(inv_blocks)
        # Add tiny damping to avoid singular blocks
        inv_blocks = inv_blocks + 1e-12 * jnp.eye(3)[None, :, :]
        inv_blocks = jnp.linalg.inv(inv_blocks)

        def precon(r):
            r_blocks = r.reshape((n_block, 3))
            z_blocks = jnp.einsum("bij,bj->bi", inv_blocks, r_blocks)
            return z_blocks.reshape((-1,))

        return precon

    def expand_full(u_free: jnp.ndarray) -> jnp.ndarray:
        """Scatter free-DOF values into a full-length vector and insert Dirichlet values."""
        if dir_dofs is None:
            return u_free
        u_full = jnp.zeros((space.n_dofs,), dtype=u_free.dtype)
        u_full = u_full.at[free_dofs_j].set(u_free)
        u_full = u_full.at[dir_dofs_j].set(dir_vals_j)
        return u_full

    def eval_residual(u_free_vec):
        """Residual on free DOFs only."""
        u_full = expand_full(u_free_vec)
        R_full = assemble_R(u_full)
        if external_vector is not None:
            R_full = R_full - external_vector
        R_free = R_full[free_dofs_j]
        res_inf = float(jnp.linalg.norm(R_free, ord=jnp.inf))
        res_two = float(jnp.linalg.norm(R_free, ord=2))
        return R_free, res_inf, res_two, u_full

    # Pre-jitted element kernels to avoid recompiling inside Newton
    res_kernel = make_element_residual_kernel(res_form, params)
    jac_kernel = make_element_jacobian_kernel(res_form, params)

    def assemble_R(u_full_vec):
        return assemble_residual_scatter(space, res_form, u_full_vec, params, kernel=res_kernel)

    eff_linear_tol = linear_tol if linear_tol is not None else max(0.1 * tol, 1e-12)

    def assemble_J(u_full_vec):
        return assemble_jacobian_scatter(
            space,
            res_form,
            u_full_vec,
            params,
            kernel=jac_kernel,
            sparse=True,
            return_flux_matrix=True,
            pattern=J_pattern,
        )

    # ------------------------------------------------------------------
    # Initial residual/Jacobian
    # ------------------------------------------------------------------
    u_full = expand_full(u)
    R_full_init = assemble_R(u_full)
    if external_vector is not None:
        R_full_init = R_full_init - external_vector
    finite_init = jnp.all(jnp.isfinite(R_full_init))
    if not bool(jax.block_until_ready(finite_init)):
        n_bad = int(jnp.size(R_full_init) - jnp.count_nonzero(jnp.isfinite(R_full_init)))
        raise RuntimeError(f"[newton] init residual has non-finite entries: {n_bad}")
    R_free = R_full_init[free_dofs_j]
    res0_inf = float(jnp.linalg.norm(R_free, ord=jnp.inf))
    res0_two = float(jnp.linalg.norm(R_free, ord=2))
    if res0_inf == 0.0:
        # Already an exact solution: nothing to iterate on.
        return u_full, SolverResult(
            converged=True,
            iters=0,
            residual_norm=0.0,
            residual0=0.0,
            rel_residual=0.0,
            tol=tol,
            atol=atol,
            stopping_criterion=max(atol, tol * 0.0),
        )

    if callback is not None:
        callback({"iter": 0, "res_inf": res0_inf, "res_two": res0_two, "rel_residual": 1.0, "alpha": 1.0, "step_norm": np.nan})

    J = assemble_J(u_full)
    finite_j = jnp.all(jnp.isfinite(J.data))
    if not bool(jax.block_until_ready(finite_j)):
        n_bad = int(jnp.size(J.data) - jnp.count_nonzero(jnp.isfinite(J.data)))
        raise RuntimeError(f"[newton] init Jacobian has non-finite entries: {n_bad}")
    J_free = restrict_free_matrix(J)

    # Stopping threshold depends only on the initial residual, so compute it once.
    crit = max(atol, tol * res0_inf)

    # State reported if the loop ends without convergence (also covers maxiter <= 0).
    iters_done = 0
    res_last_inf = res0_inf
    ls_used = 0
    lin_iters = None
    linear_converged = None
    lr = None
    step_norm = None

    for k in range(maxiter):
        # --- Newton residual (iteration start) ---
        t_iter0 = time.perf_counter()

        # Always log this to show progress.
        res_prev_inf = jnp.linalg.norm(R_free, ord=jnp.inf)
        res_prev_two = jnp.linalg.norm(R_free, ord=2)
        # JAX is async; synchronize to ensure logs are emitted.
        res_prev_inf_f = float(jax.block_until_ready(res_prev_inf))
        res_prev_two_f = float(jax.block_until_ready(res_prev_two))
        if not (np.isfinite(res_prev_inf_f) and np.isfinite(res_prev_two_f)):
            raise RuntimeError("[newton] residual became non-finite; aborting.")

        print(
            f"[newton] k={k:02d} START |R|inf={res_prev_inf_f:.3e} |R|2={res_prev_two_f:.3e} crit={crit:.3e}",
            flush=True,
        )

        # --- Linear solve (J_free * du = -R_free) ---
        rhs = jnp.asarray(-R_free, dtype=u.dtype)

        # Separate preconditioner build time from linear solve time.
        t_pre0 = time.perf_counter()
        cg_precon = linear_preconditioner
        linear_converged = True
        linear_residual = None
        lin_iters = None

        if linear_solver in ("cg", "cg_jax", "cg_custom"):
            # Preconditioner build
            if linear_preconditioner == "jacobi":
                print(f"[newton] k={k:02d} PRECOND jacobi: diag...", flush=True)
                diag = jnp.asarray(J_free.diag(), dtype=rhs.dtype)
                diag = jax.block_until_ready(diag)  # If it's heavy, it blocks here.
                # Zero diagonal entries get a zero inverse instead of inf.
                inv_diag = jnp.where(diag != 0.0, 1.0 / diag, 0.0)

                def cg_preconditioner_fn(r, inv_diag=inv_diag):
                    return inv_diag * r

                cg_precon = cg_preconditioner_fn

            elif linear_preconditioner == "block_jacobi":
                print(f"[newton] k={k:02d} PRECOND block_jacobi: build...", flush=True)
                cg_precon = build_block_jacobi(J_free)
                # Sync point if build is heavy (apply once).
                _ = jax.block_until_ready(cg_precon(rhs))

            pre_dt = time.perf_counter() - t_pre0

            # Linear solve
            cg_solver = cg_solve_jax if linear_solver in ("cg", "cg_jax") else cg_solve
            print(f"[linear] k={k:02d} {linear_solver}: solve...", flush=True)
            t_lin0 = time.perf_counter()
            du_free, lin_info = cg_solver(
                J_free,
                rhs,
                tol=eff_linear_tol,
                maxiter=linear_maxiter,
                preconditioner=cg_precon,
            )
            du_free = jax.block_until_ready(du_free)
            lin_dt = time.perf_counter() - t_lin0

            linear_residual = lin_info.get("residual_norm")
            linear_converged = bool(lin_info.get("converged", True))
            lin_iters = lin_info.get("iters", None)

        elif linear_solver in ("spsolve", "spdirect_solve_gpu"):
            pre_dt = 0.0
            print(f"[linear] k={k:02d} {linear_solver}: csr/slice...", flush=True)
            t_lin0 = time.perf_counter()
            J_csr = J.to_csr()
            J_ff = J_csr[np.ix_(free_dofs, free_dofs)]
            print(f"[linear] k={k:02d} {linear_solver}: solve...", flush=True)
            if linear_solver == "spdirect_solve_gpu":
                du_free = spdirect_solve_gpu(J_ff, rhs)
            else:
                du_free = spdirect_solve_cpu(J_ff, rhs)
            du_free = jax.block_until_ready(du_free)
            lin_dt = time.perf_counter() - t_lin0
            lin_info = {"iters": 1, "converged": True}
            linear_converged = True
            lin_iters = 1

        else:
            raise ValueError(f"Unknown linear solver: {linear_solver}")

        lr = float(linear_residual) if linear_residual is not None else float("nan")
        print(
            f"[linear] k={k:02d} done iters={lin_iters} conv={linear_converged} lin_res={lr:.3e} "
            f"pre_dt={pre_dt:.3f}s lin_dt={lin_dt:.3f}s",
            flush=True,
        )

        # --- Trial update & residual evaluation ---
        # Start with alpha=1 and eval_residual (if heavy, assemble_R is heavy/compiled).
        alpha = 1.0
        u_trial_free = u + alpha * du_free

        print(f"[newton] k={k:02d} EVAL alpha={alpha:.3e} ...", flush=True)
        t_eval0 = time.perf_counter()
        R_free_trial, res_trial_inf, res_trial_two, u_full_trial = eval_residual(u_trial_free)
        # eval_residual casts to float so it usually syncs, but keep this for safety.
        _ = jax.block_until_ready(R_free_trial)
        eval_dt = time.perf_counter() - t_eval0

        # --- Backtracking line search ---
        ls_used = 0
        if line_search:
            accepted = False
            for ls_iter in range(max_ls):
                ls_used = ls_iter + 1
                # Armijo on 2-norm
                if res_trial_two <= (1.0 - ls_c * alpha) * res_prev_two_f:
                    accepted = True
                    break
                alpha *= 0.5
                u_trial_free = u + alpha * du_free
                t_eval0 = time.perf_counter()
                R_free_trial, res_trial_inf, res_trial_two, u_full_trial = eval_residual(u_trial_free)
                _ = jax.block_until_ready(R_free_trial)
                eval_dt += time.perf_counter() - t_eval0  # Accumulate eval time.

            print(
                f"[newton] k={k:02d} LS done alpha={alpha:.3e} accepted={accepted} steps={ls_used} |R|inf={res_trial_inf:.3e} |R|2={res_trial_two:.3e} eval_dt={eval_dt:.3f}s",
                flush=True,
            )
        else:
            print(
                f"[newton] k={k:02d} STEP alpha={alpha:.3e} |R|inf={res_trial_inf:.3e} |R|2={res_trial_two:.3e} eval_dt={eval_dt:.3f}s",
                flush=True,
            )

        # --- Commit update ---
        u = u_trial_free
        R_free = R_free_trial
        u_full = u_full_trial
        iters_done = k + 1
        res_last_inf = res_trial_inf

        # Step norm (minimize host transfer: compute in jnp then sync to float).
        step_norm = float(jax.block_until_ready(jnp.linalg.norm(alpha * du_free, ord=2)))

        # callback
        if callback is not None:
            callback(
                {
                    "iter": k + 1,
                    "res_inf": res_trial_inf,
                    "res_two": res_trial_two,
                    "rel_residual": res_trial_inf / res0_inf,
                    "alpha": alpha,
                    "step_norm": step_norm,
                    "linear_iters": lin_info.get("iters"),
                    "linear_converged": linear_converged,
                    "linear_residual": lr,
                    "nan_detected": bool(np.isnan(res_trial_inf)),
                }
            )

        # --- Convergence check ---
        if res_trial_inf < crit and linear_converged and not np.isnan(res_trial_inf):
            it_dt = time.perf_counter() - t_iter0
            print(f"[newton] k={k:02d} CONVERGED dt={it_dt:.3f}s", flush=True)
            return u_full, SolverResult(
                converged=True,
                iters=k + 1,
                residual_norm=res_trial_inf,
                residual0=res0_inf,
                rel_residual=res_trial_inf / res0_inf,
                line_search_steps=(0 if not line_search else ls_used),
                linear_iters=lin_info.get("iters"),
                linear_converged=linear_converged,
                linear_residual=lr,
                tol=tol,
                atol=atol,
                stopping_criterion=crit,
                step_norm=step_norm,
                stop_reason="converged",
                nan_detected=bool(np.isnan(res_trial_inf)),
            )

        # --- Re-linearize at the accepted iterate ---
        # Without this the loop keeps solving with the Jacobian of the initial guess.
        # Skipped on the last pass (result would be unused) and when the residual is
        # non-finite (the check at the top of the next pass raises for that case).
        if k + 1 < maxiter and np.isfinite(res_trial_inf):
            J = assemble_J(u_full)
            if not bool(jax.block_until_ready(jnp.all(jnp.isfinite(J.data)))):
                n_bad = int(jnp.size(J.data) - jnp.count_nonzero(jnp.isfinite(J.data)))
                raise RuntimeError(f"[newton] k={k:02d} Jacobian has non-finite entries: {n_bad}")
            J_free = restrict_free_matrix(J)

    # ------------------------------------------------------------------
    # Iteration budget exhausted: report the last iterate instead of returning None.
    # ------------------------------------------------------------------
    nan_detected = bool(np.isnan(res_last_inf))
    if nan_detected:
        stop_reason = "nan"
    elif linear_converged is False:
        stop_reason = "linfail"
    else:
        stop_reason = "maxiter"
    return u_full, SolverResult(
        converged=False,
        iters=iters_done,
        residual_norm=res_last_inf,
        residual0=res0_inf,
        rel_residual=res_last_inf / res0_inf,
        line_search_steps=ls_used,
        linear_iters=lin_iters,
        linear_converged=linear_converged,
        linear_residual=lr,
        tol=tol,
        atol=atol,
        stopping_criterion=crit,
        step_norm=step_norm,
        stop_reason=stop_reason,
        nan_detected=nan_detected,
    )
|
fluxfem/solver/result.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Optional, Any, List
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class SolverResult:
    """Outcome record shared by the linear and nonlinear solvers.

    Only ``converged`` and ``iters`` are required; every other field is an
    optional diagnostic that is omitted from ``str(result)`` when unset.
    """

    # Whether the stopping criterion was met, and how many iterations ran.
    converged: bool
    iters: int

    # Final residual norm, initial residual norm, and their ratio.
    residual_norm: Optional[float] = None
    residual0: Optional[float] = None
    rel_residual: Optional[float] = None

    # Number of line-search steps reported by the nonlinear solver.
    line_search_steps: int = 0

    # linear-solver stats (for Newton inner solve or standalone linear solve)
    linear_iters: Optional[int] = None
    linear_converged: Optional[bool] = None
    linear_residual: Optional[float] = None

    # Tolerances in effect, the resulting threshold, and the last step size.
    tol: Optional[float] = None
    atol: Optional[float] = None
    stopping_criterion: Optional[float] = None
    step_norm: Optional[float] = None

    stop_reason: Optional[str] = None  # converged|maxiter|linfail|nan|exception|unknown
    nan_detected: bool = False

    def __str__(self) -> str:  # pragma: no cover - simple formatting
        """Return a one-line, comma-separated summary of the populated fields."""
        headline = "converged" if self.converged else "not converged"
        # Each slot is either the rendered fragment or None when the field is
        # unset / not worth showing; order here is the order in the output.
        fragments = [
            f"{headline} in {self.iters} iters",
            None if self.residual_norm is None else f"||R||={self.residual_norm:.3e}",
            None if self.residual0 is None else f"||R0||={self.residual0:.3e}",
            None if self.rel_residual is None else f"rel={self.rel_residual:.3e}",
            None if self.tol is None else f"tol={self.tol:.1e}",
            # atol is only shown when it is strictly positive.
            f"atol={self.atol:.1e}" if (self.atol is not None and self.atol > 0) else None,
            None if self.stopping_criterion is None else f"crit={self.stopping_criterion:.3e}",
            None if self.step_norm is None else f"step2={self.step_norm:.3e}",
            f"ls_steps={self.line_search_steps}" if self.line_search_steps else None,
            None if self.linear_iters is None else f"lin_iters={self.linear_iters}",
            None if self.linear_converged is None else f"lin_conv={self.linear_converged}",
            None if self.linear_residual is None else f"lin_res={self.linear_residual:.3e}",
            f"reason={self.stop_reason}" if self.stop_reason else None,
            "nan_detected=True" if self.nan_detected else None,
        ]
        return ", ".join(piece for piece in fragments if piece is not None)
|