fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +69 -13
- fluxfem/core/__init__.py +140 -53
- fluxfem/core/assembly.py +691 -97
- fluxfem/core/basis.py +75 -54
- fluxfem/core/context_types.py +36 -12
- fluxfem/core/dtypes.py +9 -1
- fluxfem/core/forms.py +10 -0
- fluxfem/core/mixed_assembly.py +263 -0
- fluxfem/core/mixed_space.py +382 -0
- fluxfem/core/mixed_weakform.py +97 -0
- fluxfem/core/solver.py +2 -0
- fluxfem/core/space.py +315 -30
- fluxfem/core/weakform.py +821 -42
- fluxfem/helpers_wf.py +49 -0
- fluxfem/mesh/__init__.py +54 -2
- fluxfem/mesh/base.py +318 -9
- fluxfem/mesh/contact.py +841 -0
- fluxfem/mesh/dtypes.py +12 -0
- fluxfem/mesh/hex.py +17 -16
- fluxfem/mesh/io.py +9 -6
- fluxfem/mesh/mortar.py +3970 -0
- fluxfem/mesh/supermesh.py +318 -0
- fluxfem/mesh/surface.py +104 -26
- fluxfem/mesh/tet.py +16 -7
- fluxfem/physics/diffusion.py +3 -0
- fluxfem/physics/elasticity/hyperelastic.py +35 -3
- fluxfem/physics/elasticity/linear.py +22 -4
- fluxfem/physics/elasticity/stress.py +9 -5
- fluxfem/physics/operators.py +12 -5
- fluxfem/physics/postprocess.py +29 -3
- fluxfem/solver/__init__.py +47 -2
- fluxfem/solver/bc.py +38 -2
- fluxfem/solver/block_matrix.py +284 -0
- fluxfem/solver/block_system.py +477 -0
- fluxfem/solver/cg.py +150 -55
- fluxfem/solver/dirichlet.py +358 -5
- fluxfem/solver/history.py +15 -3
- fluxfem/solver/newton.py +260 -70
- fluxfem/solver/petsc.py +445 -0
- fluxfem/solver/preconditioner.py +109 -0
- fluxfem/solver/result.py +18 -0
- fluxfem/solver/solve_runner.py +208 -23
- fluxfem/solver/solver.py +35 -12
- fluxfem/solver/sparse.py +149 -15
- fluxfem/tools/jit.py +19 -7
- fluxfem/tools/timer.py +14 -12
- fluxfem/tools/visualizer.py +16 -4
- fluxfem-0.2.1.dist-info/METADATA +314 -0
- fluxfem-0.2.1.dist-info/RECORD +59 -0
- fluxfem-0.1.4.dist-info/METADATA +0 -127
- fluxfem-0.1.4.dist-info/RECORD +0 -48
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/solver/solve_runner.py
CHANGED
|
@@ -2,21 +2,35 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import time
|
|
4
4
|
from dataclasses import dataclass, field
|
|
5
|
-
|
|
5
|
+
import warnings
|
|
6
|
+
from typing import Any, Callable, Iterable, List, Sequence, TYPE_CHECKING, TypeAlias
|
|
6
7
|
|
|
7
8
|
import numpy as np
|
|
8
9
|
import jax.numpy as jnp
|
|
9
10
|
|
|
10
|
-
from ..core.assembly import assemble_bilinear_form
|
|
11
|
+
from ..core.assembly import FormKernel, ResidualForm, assemble_bilinear_form
|
|
11
12
|
from ..core.solver import spdirect_solve_cpu, spdirect_solve_gpu
|
|
12
13
|
from .cg import cg_solve, cg_solve_jax
|
|
14
|
+
from .petsc import petsc_shell_solve
|
|
13
15
|
from .sparse import FluxSparseMatrix
|
|
14
|
-
from .dirichlet import expand_dirichlet_solution
|
|
16
|
+
from .dirichlet import DirichletBC, expand_dirichlet_solution
|
|
15
17
|
from .newton import newton_solve
|
|
16
18
|
from .result import SolverResult
|
|
17
19
|
from .history import NewtonIterRecord, LoadStepResult
|
|
18
20
|
from ..tools.timer import SectionTimer, NullTimer
|
|
19
21
|
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from jax import Array as JaxArray
|
|
24
|
+
|
|
25
|
+
ArrayLike: TypeAlias = np.ndarray | JaxArray
|
|
26
|
+
else:
|
|
27
|
+
ArrayLike: TypeAlias = np.ndarray
|
|
28
|
+
DirichletLike: TypeAlias = tuple[np.ndarray, np.ndarray]
|
|
29
|
+
ExtraTerm: TypeAlias = Callable[
|
|
30
|
+
[np.ndarray],
|
|
31
|
+
tuple[np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray, dict[str, Any]] | None,
|
|
32
|
+
]
|
|
33
|
+
|
|
20
34
|
|
|
21
35
|
@dataclass
|
|
22
36
|
class NonlinearAnalysis:
|
|
@@ -35,6 +49,8 @@ class NonlinearAnalysis:
|
|
|
35
49
|
Unscaled external load vector (scaled by load factor in `external_for_load`).
|
|
36
50
|
dirichlet : tuple | None
|
|
37
51
|
(dofs, values) for Dirichlet boundary conditions.
|
|
52
|
+
extra_terms : list[callable] | None
|
|
53
|
+
Optional extra term assemblers returning (K, f[, metrics]).
|
|
38
54
|
jacobian_pattern : Any | None
|
|
39
55
|
Optional sparsity pattern to reuse between load steps.
|
|
40
56
|
dtype : Any
|
|
@@ -42,14 +58,19 @@ class NonlinearAnalysis:
|
|
|
42
58
|
"""
|
|
43
59
|
|
|
44
60
|
space: Any
|
|
45
|
-
residual_form: Any
|
|
61
|
+
residual_form: ResidualForm[Any]
|
|
46
62
|
params: Any
|
|
47
|
-
base_external_vector:
|
|
48
|
-
dirichlet:
|
|
63
|
+
base_external_vector: ArrayLike | None = None
|
|
64
|
+
dirichlet: DirichletLike | None = None
|
|
65
|
+
extra_terms: list[ExtraTerm] | None = None
|
|
49
66
|
jacobian_pattern: Any | None = None
|
|
50
67
|
dtype: Any = jnp.float64
|
|
51
68
|
|
|
52
|
-
def
|
|
69
|
+
def __post_init__(self) -> None:
|
|
70
|
+
if isinstance(self.dirichlet, DirichletBC):
|
|
71
|
+
self.dirichlet = self.dirichlet.as_tuple()
|
|
72
|
+
|
|
73
|
+
def external_for_load(self, load_factor: float) -> ArrayLike | None:
|
|
53
74
|
if self.base_external_vector is None:
|
|
54
75
|
return None
|
|
55
76
|
return jnp.asarray(load_factor * self.base_external_vector, dtype=self.dtype)
|
|
@@ -71,6 +92,7 @@ class NewtonLoopConfig:
|
|
|
71
92
|
linear_maxiter: int | None = None
|
|
72
93
|
linear_tol: float | None = None
|
|
73
94
|
linear_preconditioner: Any | None = None
|
|
95
|
+
matfree_mode: str = "linearize"
|
|
74
96
|
load_sequence: Sequence[float] | None = None
|
|
75
97
|
n_steps: int = 1
|
|
76
98
|
|
|
@@ -95,17 +117,18 @@ class NewtonSolveRunner:
|
|
|
95
117
|
def __init__(self, analysis: NonlinearAnalysis, config: NewtonLoopConfig):
|
|
96
118
|
self.analysis = analysis
|
|
97
119
|
self.config = config
|
|
120
|
+
self._matfree_cache: dict[str, Any] = {}
|
|
98
121
|
|
|
99
122
|
def run(
|
|
100
123
|
self,
|
|
101
|
-
u0=None,
|
|
124
|
+
u0: ArrayLike | None = None,
|
|
102
125
|
*,
|
|
103
126
|
load_sequence: Sequence[float] | None = None,
|
|
104
|
-
newton_callback: Callable | None = None,
|
|
127
|
+
newton_callback: Callable[[dict[str, Any]], None] | None = None,
|
|
105
128
|
step_callback: Callable[[LoadStepResult], None] | None = None,
|
|
106
129
|
timer: "SectionTimer | None" = None,
|
|
107
130
|
report_timing: bool = True
|
|
108
|
-
):
|
|
131
|
+
) -> tuple[np.ndarray, list[LoadStepResult]]:
|
|
109
132
|
"""
|
|
110
133
|
Execute Newton solves over the configured load schedule.
|
|
111
134
|
|
|
@@ -154,6 +177,16 @@ class NewtonSolveRunner:
|
|
|
154
177
|
schedule.append(lf_clamped)
|
|
155
178
|
prev = lf_clamped
|
|
156
179
|
history: List[LoadStepResult] = []
|
|
180
|
+
matfree_cache = None
|
|
181
|
+
if self.config.linear_preconditioner == "diag0":
|
|
182
|
+
n_free = self.analysis.space.n_dofs
|
|
183
|
+
if self.analysis.dirichlet is not None:
|
|
184
|
+
n_free -= len(self.analysis.dirichlet[0])
|
|
185
|
+
cached_free = self._matfree_cache.get("n_free_dofs")
|
|
186
|
+
if cached_free is not None and cached_free != n_free:
|
|
187
|
+
self._matfree_cache.clear()
|
|
188
|
+
self._matfree_cache["n_free_dofs"] = n_free
|
|
189
|
+
matfree_cache = self._matfree_cache
|
|
157
190
|
for step_i, load_factor in enumerate(schedule, start=1):
|
|
158
191
|
with timer.section("step"):
|
|
159
192
|
external = self.analysis.external_for_load(load_factor)
|
|
@@ -203,6 +236,8 @@ class NewtonSolveRunner:
|
|
|
203
236
|
linear_maxiter=self.config.linear_maxiter,
|
|
204
237
|
linear_tol=self.config.linear_tol,
|
|
205
238
|
linear_preconditioner=self.config.linear_preconditioner,
|
|
239
|
+
matfree_cache=matfree_cache,
|
|
240
|
+
matfree_mode=self.config.matfree_mode,
|
|
206
241
|
dirichlet=self.analysis.dirichlet,
|
|
207
242
|
line_search=self.config.line_search,
|
|
208
243
|
max_ls=self.config.max_ls,
|
|
@@ -210,6 +245,7 @@ class NewtonSolveRunner:
|
|
|
210
245
|
external_vector=external,
|
|
211
246
|
callback=cb,
|
|
212
247
|
jacobian_pattern=self.analysis.jacobian_pattern,
|
|
248
|
+
extra_terms=self.analysis.extra_terms,
|
|
213
249
|
)
|
|
214
250
|
exception = None
|
|
215
251
|
except Exception as e: # pragma: no cover - defensive
|
|
@@ -256,7 +292,9 @@ class NewtonSolveRunner:
|
|
|
256
292
|
return u, history
|
|
257
293
|
|
|
258
294
|
|
|
259
|
-
def _condense_flux_dirichlet(
|
|
295
|
+
def _condense_flux_dirichlet(
|
|
296
|
+
K: FluxSparseMatrix, F: ArrayLike, dirichlet: DirichletLike
|
|
297
|
+
) -> tuple[Any, np.ndarray, np.ndarray | None, np.ndarray, np.ndarray, np.ndarray]:
|
|
260
298
|
dir_dofs, dir_vals = dirichlet
|
|
261
299
|
dir_arr = np.asarray(dir_dofs, dtype=int)
|
|
262
300
|
dir_vals_arr = np.asarray(dir_vals, dtype=float)
|
|
@@ -274,11 +312,12 @@ def _condense_flux_dirichlet(K: FluxSparseMatrix, F, dirichlet):
|
|
|
274
312
|
|
|
275
313
|
def solve_nonlinear(
|
|
276
314
|
space,
|
|
277
|
-
residual_form,
|
|
278
|
-
params,
|
|
315
|
+
residual_form: ResidualForm[Any],
|
|
316
|
+
params: Any,
|
|
279
317
|
*,
|
|
280
|
-
dirichlet:
|
|
281
|
-
base_external_vector=None,
|
|
318
|
+
dirichlet: DirichletLike | None = None,
|
|
319
|
+
base_external_vector: ArrayLike | None = None,
|
|
320
|
+
extra_terms: list[ExtraTerm] | None = None,
|
|
282
321
|
dtype=jnp.float64,
|
|
283
322
|
maxiter: int = 20,
|
|
284
323
|
tol: float = 1e-8,
|
|
@@ -287,13 +326,14 @@ def solve_nonlinear(
|
|
|
287
326
|
linear_maxiter: int | None = None,
|
|
288
327
|
linear_tol: float | None = None,
|
|
289
328
|
linear_preconditioner=None,
|
|
329
|
+
matfree_mode: str = "linearize",
|
|
290
330
|
line_search: bool = False,
|
|
291
331
|
max_ls: int = 10,
|
|
292
332
|
ls_c: float = 1e-4,
|
|
293
333
|
n_steps: int = 1,
|
|
294
334
|
jacobian_pattern=None,
|
|
295
|
-
u0=None,
|
|
296
|
-
):
|
|
335
|
+
u0: ArrayLike | None = None,
|
|
336
|
+
) -> tuple[np.ndarray, list[LoadStepResult]]:
|
|
297
337
|
"""
|
|
298
338
|
Convenience wrapper: build NonlinearAnalysis and run NewtonSolveRunner.
|
|
299
339
|
"""
|
|
@@ -303,6 +343,7 @@ def solve_nonlinear(
|
|
|
303
343
|
params=params,
|
|
304
344
|
base_external_vector=base_external_vector,
|
|
305
345
|
dirichlet=dirichlet,
|
|
346
|
+
extra_terms=extra_terms,
|
|
306
347
|
dtype=dtype,
|
|
307
348
|
jacobian_pattern=jacobian_pattern,
|
|
308
349
|
)
|
|
@@ -314,6 +355,7 @@ def solve_nonlinear(
|
|
|
314
355
|
linear_maxiter=linear_maxiter,
|
|
315
356
|
linear_tol=linear_tol,
|
|
316
357
|
linear_preconditioner=linear_preconditioner,
|
|
358
|
+
matfree_mode=matfree_mode,
|
|
317
359
|
line_search=line_search,
|
|
318
360
|
max_ls=max_ls,
|
|
319
361
|
ls_c=ls_c,
|
|
@@ -335,10 +377,10 @@ class LinearAnalysis:
|
|
|
335
377
|
"""
|
|
336
378
|
|
|
337
379
|
space: Any
|
|
338
|
-
bilinear_form: Any
|
|
380
|
+
bilinear_form: FormKernel[Any]
|
|
339
381
|
params: Any
|
|
340
|
-
base_rhs_vector:
|
|
341
|
-
dirichlet:
|
|
382
|
+
base_rhs_vector: ArrayLike
|
|
383
|
+
dirichlet: DirichletLike | None = None
|
|
342
384
|
pattern: Any | None = None
|
|
343
385
|
dtype: Any = jnp.float64
|
|
344
386
|
|
|
@@ -349,7 +391,7 @@ class LinearAnalysis:
|
|
|
349
391
|
pattern=self.pattern,
|
|
350
392
|
)
|
|
351
393
|
|
|
352
|
-
def rhs_for_load(self, load_factor: float):
|
|
394
|
+
def rhs_for_load(self, load_factor: float) -> ArrayLike:
|
|
353
395
|
return jnp.asarray(load_factor * self.base_rhs_vector, dtype=self.dtype)
|
|
354
396
|
|
|
355
397
|
|
|
@@ -359,10 +401,42 @@ class LinearSolveConfig:
|
|
|
359
401
|
Control parameters for the linear solve with optional load scaling.
|
|
360
402
|
"""
|
|
361
403
|
|
|
362
|
-
method: str = "spsolve" # "spsolve" | "spdirect_solve_gpu" | "cg" | "cg_custom"
|
|
404
|
+
method: str = "spsolve" # "spsolve" | "spdirect_solve_gpu" | "cg" | "cg_custom" | "petsc_shell"
|
|
363
405
|
tol: float = 1e-8
|
|
364
406
|
maxiter: int | None = None
|
|
365
407
|
preconditioner: Any | None = None
|
|
408
|
+
ksp_type: str | None = None
|
|
409
|
+
pc_type: str | None = None
|
|
410
|
+
ksp_rtol: float | None = None
|
|
411
|
+
ksp_atol: float | None = None
|
|
412
|
+
ksp_max_it: int | None = None
|
|
413
|
+
petsc_ksp_norm_type: str | None = None
|
|
414
|
+
petsc_ksp_monitor_true_residual: bool = False
|
|
415
|
+
petsc_ksp_converged_reason: bool = False
|
|
416
|
+
petsc_ksp_monitor_short: bool = False
|
|
417
|
+
petsc_shell_pmat: bool = False
|
|
418
|
+
petsc_shell_pmat_mode: str = "full"
|
|
419
|
+
petsc_shell_pmat_rebuild_iters: int | None = None
|
|
420
|
+
petsc_shell_fallback: bool = False
|
|
421
|
+
petsc_shell_fallback_ksp_types: tuple[str, ...] = ("bcgs", "gmres")
|
|
422
|
+
petsc_shell_fallback_rebuild_pmat: bool = True
|
|
423
|
+
|
|
424
|
+
@classmethod
|
|
425
|
+
def from_preset(cls, name: str) -> "LinearSolveConfig":
|
|
426
|
+
preset = name.lower()
|
|
427
|
+
if preset == "contact":
|
|
428
|
+
return cls(
|
|
429
|
+
method="petsc_shell",
|
|
430
|
+
ksp_type="bcgs",
|
|
431
|
+
pc_type="ilu",
|
|
432
|
+
petsc_shell_pmat=True,
|
|
433
|
+
petsc_shell_pmat_mode="full",
|
|
434
|
+
petsc_ksp_norm_type="unpreconditioned",
|
|
435
|
+
petsc_ksp_monitor_true_residual=True,
|
|
436
|
+
petsc_ksp_converged_reason=True,
|
|
437
|
+
petsc_shell_fallback=True,
|
|
438
|
+
)
|
|
439
|
+
raise ValueError(f"Unknown LinearSolveConfig preset: {name}")
|
|
366
440
|
|
|
367
441
|
|
|
368
442
|
@dataclass
|
|
@@ -381,7 +455,7 @@ class LinearStepResult:
|
|
|
381
455
|
"""
|
|
382
456
|
info: SolverResult
|
|
383
457
|
solve_time: float
|
|
384
|
-
u:
|
|
458
|
+
u: ArrayLike
|
|
385
459
|
|
|
386
460
|
|
|
387
461
|
class LinearSolveRunner:
|
|
@@ -392,6 +466,9 @@ class LinearSolveRunner:
|
|
|
392
466
|
def __init__(self, analysis: LinearAnalysis, config: LinearSolveConfig):
|
|
393
467
|
self.analysis = analysis
|
|
394
468
|
self.config = config
|
|
469
|
+
self._petsc_shell_pmat = None
|
|
470
|
+
self._petsc_shell_last_iters = None
|
|
471
|
+
self._petsc_shell_pmat_rebuilds = 0
|
|
395
472
|
|
|
396
473
|
def run(
|
|
397
474
|
self,
|
|
@@ -488,6 +565,114 @@ class LinearSolveRunner:
|
|
|
488
565
|
stop_reason=("converged" if lin_conv else "linfail"),
|
|
489
566
|
nan_detected=bool(np.isnan(lin_res)) if lin_res is not None else False,
|
|
490
567
|
)
|
|
568
|
+
elif self.config.method == "petsc_shell":
|
|
569
|
+
base_ksp_type = self.config.ksp_type or "gmres"
|
|
570
|
+
pc_type = self.config.pc_type if self.config.pc_type is not None else "none"
|
|
571
|
+
ksp_rtol = self.config.ksp_rtol if self.config.ksp_rtol is not None else self.config.tol
|
|
572
|
+
ksp_atol = self.config.ksp_atol
|
|
573
|
+
ksp_max_it = self.config.ksp_max_it if self.config.ksp_max_it is not None else self.config.maxiter
|
|
574
|
+
petsc_options = {}
|
|
575
|
+
if self.config.petsc_ksp_norm_type:
|
|
576
|
+
petsc_options["fluxfem_ksp_norm_type"] = self.config.petsc_ksp_norm_type
|
|
577
|
+
if self.config.petsc_ksp_monitor_true_residual:
|
|
578
|
+
petsc_options["fluxfem_ksp_monitor_true_residual"] = ""
|
|
579
|
+
if self.config.petsc_ksp_converged_reason:
|
|
580
|
+
petsc_options["fluxfem_ksp_converged_reason"] = ""
|
|
581
|
+
if self.config.petsc_ksp_monitor_short:
|
|
582
|
+
petsc_options["fluxfem_ksp_monitor_short"] = ""
|
|
583
|
+
if not petsc_options:
|
|
584
|
+
petsc_options = None
|
|
585
|
+
use_pmat = bool(self.config.petsc_shell_pmat)
|
|
586
|
+
rebuild_thresh = self.config.petsc_shell_pmat_rebuild_iters
|
|
587
|
+
if use_pmat:
|
|
588
|
+
pmat_mode = (self.config.petsc_shell_pmat_mode or "full").lower()
|
|
589
|
+
if pmat_mode == "none":
|
|
590
|
+
use_pmat = False
|
|
591
|
+
pmat = None
|
|
592
|
+
elif pmat_mode == "full":
|
|
593
|
+
pmat = K_ff
|
|
594
|
+
else:
|
|
595
|
+
warnings.warn(
|
|
596
|
+
f"petsc_shell_pmat_mode='{pmat_mode}' is not supported in runner; "
|
|
597
|
+
"falling back to 'full'.",
|
|
598
|
+
RuntimeWarning,
|
|
599
|
+
)
|
|
600
|
+
pmat = K_ff
|
|
601
|
+
if use_pmat:
|
|
602
|
+
if self._petsc_shell_pmat is None:
|
|
603
|
+
self._petsc_shell_pmat = pmat
|
|
604
|
+
self._petsc_shell_pmat_rebuilds += 1
|
|
605
|
+
elif rebuild_thresh is not None and self._petsc_shell_last_iters is not None:
|
|
606
|
+
if self._petsc_shell_last_iters > rebuild_thresh:
|
|
607
|
+
self._petsc_shell_pmat = pmat
|
|
608
|
+
self._petsc_shell_pmat_rebuilds += 1
|
|
609
|
+
pmat = self._petsc_shell_pmat
|
|
610
|
+
if not use_pmat:
|
|
611
|
+
pmat = None
|
|
612
|
+
|
|
613
|
+
def _attempt_solve(ksp_type: str):
|
|
614
|
+
return petsc_shell_solve(
|
|
615
|
+
K_ff,
|
|
616
|
+
F_free,
|
|
617
|
+
preconditioner=self.config.preconditioner,
|
|
618
|
+
ksp_type=ksp_type,
|
|
619
|
+
pc_type=pc_type,
|
|
620
|
+
rtol=ksp_rtol,
|
|
621
|
+
atol=ksp_atol,
|
|
622
|
+
max_it=ksp_max_it,
|
|
623
|
+
pmat=pmat,
|
|
624
|
+
options=petsc_options,
|
|
625
|
+
return_info=True,
|
|
626
|
+
)
|
|
627
|
+
|
|
628
|
+
fallback_ksp = [base_ksp_type]
|
|
629
|
+
if self.config.petsc_shell_fallback:
|
|
630
|
+
for ksp in self.config.petsc_shell_fallback_ksp_types:
|
|
631
|
+
if ksp not in fallback_ksp:
|
|
632
|
+
fallback_ksp.append(ksp)
|
|
633
|
+
fallback_attempts = []
|
|
634
|
+
petsc_info = None
|
|
635
|
+
u_free = None
|
|
636
|
+
for ksp in fallback_ksp:
|
|
637
|
+
fallback_attempts.append(ksp)
|
|
638
|
+
u_free, petsc_info = _attempt_solve(ksp)
|
|
639
|
+
lin_conv = petsc_info.get("converged")
|
|
640
|
+
reason = petsc_info.get("reason")
|
|
641
|
+
if lin_conv is None and reason is not None:
|
|
642
|
+
lin_conv = reason > 0
|
|
643
|
+
if lin_conv:
|
|
644
|
+
break
|
|
645
|
+
if self.config.petsc_shell_fallback and use_pmat and self.config.petsc_shell_fallback_rebuild_pmat:
|
|
646
|
+
self._petsc_shell_pmat = pmat
|
|
647
|
+
self._petsc_shell_pmat_rebuilds += 1
|
|
648
|
+
lin_iters = petsc_info.get("iters")
|
|
649
|
+
lin_res = petsc_info.get("residual_norm")
|
|
650
|
+
lin_solve_dt = petsc_info.get("solve_time")
|
|
651
|
+
pc_setup_dt = petsc_info.get("pc_setup_time")
|
|
652
|
+
pmat_dt = petsc_info.get("pmat_build_time")
|
|
653
|
+
lin_conv = petsc_info.get("converged")
|
|
654
|
+
if lin_conv is None and petsc_info.get("reason") is not None:
|
|
655
|
+
lin_conv = petsc_info.get("reason") > 0
|
|
656
|
+
if lin_conv is None:
|
|
657
|
+
lin_conv = True
|
|
658
|
+
self._petsc_shell_last_iters = lin_iters
|
|
659
|
+
info = SolverResult(
|
|
660
|
+
converged=bool(lin_conv),
|
|
661
|
+
iters=int(lin_iters) if lin_iters is not None else 0,
|
|
662
|
+
linear_iters=int(lin_iters) if lin_iters is not None else None,
|
|
663
|
+
linear_converged=bool(lin_conv),
|
|
664
|
+
linear_residual=float(lin_res) if lin_res is not None else None,
|
|
665
|
+
linear_solve_time=float(lin_solve_dt) if lin_solve_dt is not None else None,
|
|
666
|
+
pc_setup_time=float(pc_setup_dt) if pc_setup_dt is not None else None,
|
|
667
|
+
pmat_build_time=float(pmat_dt) if pmat_dt is not None else None,
|
|
668
|
+
pmat_rebuilds=self._petsc_shell_pmat_rebuilds if use_pmat else None,
|
|
669
|
+
pmat_mode=self.config.petsc_shell_pmat_mode if use_pmat else None,
|
|
670
|
+
tol=self.config.tol,
|
|
671
|
+
stop_reason=("converged" if lin_conv else "linfail"),
|
|
672
|
+
nan_detected=bool(np.isnan(lin_res)) if lin_res is not None else False,
|
|
673
|
+
)
|
|
674
|
+
if len(fallback_attempts) > 1:
|
|
675
|
+
info.linear_fallbacks = fallback_attempts
|
|
491
676
|
else:
|
|
492
677
|
raise ValueError(f"Unknown linear solve method: {self.config.method}")
|
|
493
678
|
|
fluxfem/solver/solver.py
CHANGED
|
@@ -1,12 +1,16 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from typing import Any, TYPE_CHECKING, TypeAlias
|
|
4
|
+
|
|
3
5
|
import numpy as np
|
|
4
6
|
import jax.numpy as jnp
|
|
5
7
|
|
|
6
8
|
from .cg import cg_solve, cg_solve_jax
|
|
7
9
|
from .newton import newton_solve
|
|
10
|
+
from .petsc import petsc_solve, petsc_shell_solve
|
|
8
11
|
from ..core.solver import spdirect_solve_cpu, spdirect_solve_gpu
|
|
9
12
|
from .dirichlet import (
|
|
13
|
+
DirichletBC,
|
|
10
14
|
condense_dirichlet_dense,
|
|
11
15
|
condense_dirichlet_fluxsparse,
|
|
12
16
|
expand_dirichlet_solution,
|
|
@@ -16,6 +20,16 @@ from .dirichlet import (
|
|
|
16
20
|
from .sparse import FluxSparseMatrix
|
|
17
21
|
from ..core.space import FESpace
|
|
18
22
|
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from jax import Array as JaxArray
|
|
25
|
+
|
|
26
|
+
ArrayLike: TypeAlias = np.ndarray | JaxArray
|
|
27
|
+
else:
|
|
28
|
+
ArrayLike: TypeAlias = np.ndarray
|
|
29
|
+
DirichletLike: TypeAlias = DirichletBC | tuple[np.ndarray, np.ndarray]
|
|
30
|
+
SolveInfo: TypeAlias = dict[str, Any]
|
|
31
|
+
SolveReturn: TypeAlias = tuple[np.ndarray, SolveInfo]
|
|
32
|
+
|
|
19
33
|
|
|
20
34
|
class LinearSolver:
|
|
21
35
|
"""
|
|
@@ -30,7 +44,7 @@ class LinearSolver:
|
|
|
30
44
|
self.tol = tol
|
|
31
45
|
self.maxiter = maxiter
|
|
32
46
|
|
|
33
|
-
def _solve_free(self, A, b):
|
|
47
|
+
def _solve_free(self, A: Any, b: Any) -> SolveReturn:
|
|
34
48
|
if self.method == "cg":
|
|
35
49
|
x, info = cg_solve_jax(A, b, tol=self.tol, maxiter=self.maxiter)
|
|
36
50
|
return np.asarray(x), {"iters": info.get("iters"), "converged": info.get("converged", True)}
|
|
@@ -46,25 +60,34 @@ class LinearSolver:
|
|
|
46
60
|
elif self.method == "spdirect_solve_gpu":
|
|
47
61
|
x = spdirect_solve_gpu(A, b)
|
|
48
62
|
return np.asarray(x), {"iters": 1, "converged": True}
|
|
63
|
+
elif self.method == "petsc":
|
|
64
|
+
x = petsc_solve(A, b)
|
|
65
|
+
return np.asarray(x), {"iters": None, "converged": True}
|
|
66
|
+
elif self.method == "petsc_shell":
|
|
67
|
+
x = petsc_shell_solve(A, b)
|
|
68
|
+
return np.asarray(x), {"iters": None, "converged": True}
|
|
49
69
|
else:
|
|
50
70
|
raise ValueError(f"Unknown linear method: {self.method}")
|
|
51
71
|
|
|
52
72
|
def solve(
|
|
53
73
|
self,
|
|
54
|
-
A,
|
|
55
|
-
b,
|
|
74
|
+
A: Any,
|
|
75
|
+
b: Any,
|
|
56
76
|
*,
|
|
57
|
-
dirichlet=None,
|
|
77
|
+
dirichlet: DirichletLike | None = None,
|
|
58
78
|
dirichlet_mode: str = "condense",
|
|
59
79
|
n_total: int | None = None,
|
|
60
|
-
):
|
|
80
|
+
) -> SolveReturn:
|
|
61
81
|
if dirichlet is None:
|
|
62
82
|
return self._solve_free(A, b)
|
|
63
83
|
|
|
64
84
|
if dirichlet_mode not in ("condense", "enforce"):
|
|
65
85
|
raise ValueError("dirichlet_mode must be 'condense' or 'enforce'.")
|
|
66
86
|
|
|
67
|
-
|
|
87
|
+
if isinstance(dirichlet, DirichletBC):
|
|
88
|
+
dir_dofs, dir_vals = dirichlet.as_tuple()
|
|
89
|
+
else:
|
|
90
|
+
dir_dofs, dir_vals = dirichlet
|
|
68
91
|
if dirichlet_mode == "enforce":
|
|
69
92
|
if isinstance(A, FluxSparseMatrix):
|
|
70
93
|
A_bc, b_bc = enforce_dirichlet_sparse(A, b, dir_dofs, dir_vals)
|
|
@@ -98,8 +121,8 @@ class NonlinearSolver:
|
|
|
98
121
|
def __init__(
|
|
99
122
|
self,
|
|
100
123
|
space: FESpace,
|
|
101
|
-
res_form,
|
|
102
|
-
params,
|
|
124
|
+
res_form: Any,
|
|
125
|
+
params: Any,
|
|
103
126
|
*,
|
|
104
127
|
tol: float = 1e-8,
|
|
105
128
|
maxiter: int = 20,
|
|
@@ -108,10 +131,10 @@ class NonlinearSolver:
|
|
|
108
131
|
max_ls: int = 10,
|
|
109
132
|
ls_c: float = 1e-4,
|
|
110
133
|
linear_tol: float | None = None,
|
|
111
|
-
dirichlet=None,
|
|
112
|
-
external_vector=None,
|
|
134
|
+
dirichlet: DirichletLike | None = None,
|
|
135
|
+
external_vector: ArrayLike | None = None,
|
|
113
136
|
linear_maxiter: int | None = None,
|
|
114
|
-
jacobian_pattern=None,
|
|
137
|
+
jacobian_pattern: Any | None = None,
|
|
115
138
|
):
|
|
116
139
|
self.space = space
|
|
117
140
|
self.res_form = res_form
|
|
@@ -128,7 +151,7 @@ class NonlinearSolver:
|
|
|
128
151
|
self.linear_maxiter = linear_maxiter
|
|
129
152
|
self.jacobian_pattern = jacobian_pattern
|
|
130
153
|
|
|
131
|
-
def solve(self, u0):
|
|
154
|
+
def solve(self, u0: ArrayLike):
|
|
132
155
|
return newton_solve(
|
|
133
156
|
self.space,
|
|
134
157
|
self.res_form,
|