fluxfem 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +68 -0
- fluxfem/core/__init__.py +115 -10
- fluxfem/core/assembly.py +676 -91
- fluxfem/core/basis.py +73 -52
- fluxfem/core/dtypes.py +9 -1
- fluxfem/core/forms.py +10 -0
- fluxfem/core/mixed_assembly.py +263 -0
- fluxfem/core/mixed_space.py +348 -0
- fluxfem/core/mixed_weakform.py +97 -0
- fluxfem/core/solver.py +2 -0
- fluxfem/core/space.py +262 -17
- fluxfem/core/weakform.py +768 -7
- fluxfem/helpers_wf.py +49 -0
- fluxfem/mesh/__init__.py +54 -2
- fluxfem/mesh/base.py +316 -7
- fluxfem/mesh/contact.py +825 -0
- fluxfem/mesh/dtypes.py +12 -0
- fluxfem/mesh/hex.py +17 -16
- fluxfem/mesh/io.py +6 -4
- fluxfem/mesh/mortar.py +3907 -0
- fluxfem/mesh/supermesh.py +316 -0
- fluxfem/mesh/surface.py +22 -4
- fluxfem/mesh/tet.py +10 -4
- fluxfem/physics/diffusion.py +3 -0
- fluxfem/physics/elasticity/hyperelastic.py +3 -0
- fluxfem/physics/elasticity/linear.py +9 -2
- fluxfem/solver/__init__.py +42 -2
- fluxfem/solver/bc.py +38 -2
- fluxfem/solver/block_matrix.py +132 -0
- fluxfem/solver/block_system.py +454 -0
- fluxfem/solver/cg.py +115 -33
- fluxfem/solver/dirichlet.py +334 -4
- fluxfem/solver/newton.py +237 -60
- fluxfem/solver/petsc.py +439 -0
- fluxfem/solver/preconditioner.py +106 -0
- fluxfem/solver/result.py +18 -0
- fluxfem/solver/solve_runner.py +168 -1
- fluxfem/solver/solver.py +12 -1
- fluxfem/solver/sparse.py +124 -9
- fluxfem-0.2.0.dist-info/METADATA +303 -0
- fluxfem-0.2.0.dist-info/RECORD +59 -0
- fluxfem-0.1.4.dist-info/METADATA +0 -127
- fluxfem-0.1.4.dist-info/RECORD +0 -48
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/LICENSE +0 -0
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/WHEEL +0 -0
fluxfem/solver/solve_runner.py
CHANGED
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import time
|
|
4
4
|
from dataclasses import dataclass, field
|
|
5
|
+
import warnings
|
|
5
6
|
from typing import Any, Callable, Iterable, List, Sequence
|
|
6
7
|
|
|
7
8
|
import numpy as np
|
|
@@ -10,6 +11,7 @@ import jax.numpy as jnp
|
|
|
10
11
|
from ..core.assembly import assemble_bilinear_form
|
|
11
12
|
from ..core.solver import spdirect_solve_cpu, spdirect_solve_gpu
|
|
12
13
|
from .cg import cg_solve, cg_solve_jax
|
|
14
|
+
from .petsc import petsc_shell_solve
|
|
13
15
|
from .sparse import FluxSparseMatrix
|
|
14
16
|
from .dirichlet import expand_dirichlet_solution
|
|
15
17
|
from .newton import newton_solve
|
|
@@ -35,6 +37,8 @@ class NonlinearAnalysis:
|
|
|
35
37
|
Unscaled external load vector (scaled by load factor in `external_for_load`).
|
|
36
38
|
dirichlet : tuple | None
|
|
37
39
|
(dofs, values) for Dirichlet boundary conditions.
|
|
40
|
+
extra_terms : list[callable] | None
|
|
41
|
+
Optional extra term assemblers returning (K, f[, metrics]).
|
|
38
42
|
jacobian_pattern : Any | None
|
|
39
43
|
Optional sparsity pattern to reuse between load steps.
|
|
40
44
|
dtype : Any
|
|
@@ -46,6 +50,7 @@ class NonlinearAnalysis:
|
|
|
46
50
|
params: Any
|
|
47
51
|
base_external_vector: Any | None = None
|
|
48
52
|
dirichlet: tuple | None = None
|
|
53
|
+
extra_terms: list[Callable] | None = None
|
|
49
54
|
jacobian_pattern: Any | None = None
|
|
50
55
|
dtype: Any = jnp.float64
|
|
51
56
|
|
|
@@ -71,6 +76,7 @@ class NewtonLoopConfig:
|
|
|
71
76
|
linear_maxiter: int | None = None
|
|
72
77
|
linear_tol: float | None = None
|
|
73
78
|
linear_preconditioner: Any | None = None
|
|
79
|
+
matfree_mode: str = "linearize"
|
|
74
80
|
load_sequence: Sequence[float] | None = None
|
|
75
81
|
n_steps: int = 1
|
|
76
82
|
|
|
@@ -95,6 +101,7 @@ class NewtonSolveRunner:
|
|
|
95
101
|
def __init__(self, analysis: NonlinearAnalysis, config: NewtonLoopConfig):
|
|
96
102
|
self.analysis = analysis
|
|
97
103
|
self.config = config
|
|
104
|
+
self._matfree_cache: dict[str, Any] = {}
|
|
98
105
|
|
|
99
106
|
def run(
|
|
100
107
|
self,
|
|
@@ -154,6 +161,16 @@ class NewtonSolveRunner:
|
|
|
154
161
|
schedule.append(lf_clamped)
|
|
155
162
|
prev = lf_clamped
|
|
156
163
|
history: List[LoadStepResult] = []
|
|
164
|
+
matfree_cache = None
|
|
165
|
+
if self.config.linear_preconditioner == "diag0":
|
|
166
|
+
n_free = self.analysis.space.n_dofs
|
|
167
|
+
if self.analysis.dirichlet is not None:
|
|
168
|
+
n_free -= len(self.analysis.dirichlet[0])
|
|
169
|
+
cached_free = self._matfree_cache.get("n_free_dofs")
|
|
170
|
+
if cached_free is not None and cached_free != n_free:
|
|
171
|
+
self._matfree_cache.clear()
|
|
172
|
+
self._matfree_cache["n_free_dofs"] = n_free
|
|
173
|
+
matfree_cache = self._matfree_cache
|
|
157
174
|
for step_i, load_factor in enumerate(schedule, start=1):
|
|
158
175
|
with timer.section("step"):
|
|
159
176
|
external = self.analysis.external_for_load(load_factor)
|
|
@@ -203,6 +220,8 @@ class NewtonSolveRunner:
|
|
|
203
220
|
linear_maxiter=self.config.linear_maxiter,
|
|
204
221
|
linear_tol=self.config.linear_tol,
|
|
205
222
|
linear_preconditioner=self.config.linear_preconditioner,
|
|
223
|
+
matfree_cache=matfree_cache,
|
|
224
|
+
matfree_mode=self.config.matfree_mode,
|
|
206
225
|
dirichlet=self.analysis.dirichlet,
|
|
207
226
|
line_search=self.config.line_search,
|
|
208
227
|
max_ls=self.config.max_ls,
|
|
@@ -210,6 +229,7 @@ class NewtonSolveRunner:
|
|
|
210
229
|
external_vector=external,
|
|
211
230
|
callback=cb,
|
|
212
231
|
jacobian_pattern=self.analysis.jacobian_pattern,
|
|
232
|
+
extra_terms=self.analysis.extra_terms,
|
|
213
233
|
)
|
|
214
234
|
exception = None
|
|
215
235
|
except Exception as e: # pragma: no cover - defensive
|
|
@@ -279,6 +299,7 @@ def solve_nonlinear(
|
|
|
279
299
|
*,
|
|
280
300
|
dirichlet: tuple | None = None,
|
|
281
301
|
base_external_vector=None,
|
|
302
|
+
extra_terms=None,
|
|
282
303
|
dtype=jnp.float64,
|
|
283
304
|
maxiter: int = 20,
|
|
284
305
|
tol: float = 1e-8,
|
|
@@ -287,6 +308,7 @@ def solve_nonlinear(
|
|
|
287
308
|
linear_maxiter: int | None = None,
|
|
288
309
|
linear_tol: float | None = None,
|
|
289
310
|
linear_preconditioner=None,
|
|
311
|
+
matfree_mode: str = "linearize",
|
|
290
312
|
line_search: bool = False,
|
|
291
313
|
max_ls: int = 10,
|
|
292
314
|
ls_c: float = 1e-4,
|
|
@@ -303,6 +325,7 @@ def solve_nonlinear(
|
|
|
303
325
|
params=params,
|
|
304
326
|
base_external_vector=base_external_vector,
|
|
305
327
|
dirichlet=dirichlet,
|
|
328
|
+
extra_terms=extra_terms,
|
|
306
329
|
dtype=dtype,
|
|
307
330
|
jacobian_pattern=jacobian_pattern,
|
|
308
331
|
)
|
|
@@ -314,6 +337,7 @@ def solve_nonlinear(
|
|
|
314
337
|
linear_maxiter=linear_maxiter,
|
|
315
338
|
linear_tol=linear_tol,
|
|
316
339
|
linear_preconditioner=linear_preconditioner,
|
|
340
|
+
matfree_mode=matfree_mode,
|
|
317
341
|
line_search=line_search,
|
|
318
342
|
max_ls=max_ls,
|
|
319
343
|
ls_c=ls_c,
|
|
@@ -359,10 +383,42 @@ class LinearSolveConfig:
|
|
|
359
383
|
Control parameters for the linear solve with optional load scaling.
|
|
360
384
|
"""
|
|
361
385
|
|
|
362
|
-
method: str = "spsolve" # "spsolve" | "spdirect_solve_gpu" | "cg" | "cg_custom"
|
|
386
|
+
method: str = "spsolve" # "spsolve" | "spdirect_solve_gpu" | "cg" | "cg_custom" | "petsc_shell"
|
|
363
387
|
tol: float = 1e-8
|
|
364
388
|
maxiter: int | None = None
|
|
365
389
|
preconditioner: Any | None = None
|
|
390
|
+
ksp_type: str | None = None
|
|
391
|
+
pc_type: str | None = None
|
|
392
|
+
ksp_rtol: float | None = None
|
|
393
|
+
ksp_atol: float | None = None
|
|
394
|
+
ksp_max_it: int | None = None
|
|
395
|
+
petsc_ksp_norm_type: str | None = None
|
|
396
|
+
petsc_ksp_monitor_true_residual: bool = False
|
|
397
|
+
petsc_ksp_converged_reason: bool = False
|
|
398
|
+
petsc_ksp_monitor_short: bool = False
|
|
399
|
+
petsc_shell_pmat: bool = False
|
|
400
|
+
petsc_shell_pmat_mode: str = "full"
|
|
401
|
+
petsc_shell_pmat_rebuild_iters: int | None = None
|
|
402
|
+
petsc_shell_fallback: bool = False
|
|
403
|
+
petsc_shell_fallback_ksp_types: tuple[str, ...] = ("bcgs", "gmres")
|
|
404
|
+
petsc_shell_fallback_rebuild_pmat: bool = True
|
|
405
|
+
|
|
406
|
+
@classmethod
def from_preset(cls, name: str) -> "LinearSolveConfig":
    """Build a configuration from a named preset.

    The preset name is matched case-insensitively. The only preset defined
    here is ``"contact"``, which selects the ``petsc_shell`` method with the
    keyword values listed below; every other field keeps its class default.

    Parameters
    ----------
    name : str
        Preset identifier.

    Returns
    -------
    LinearSolveConfig
        A new instance of ``cls`` populated from the preset.

    Raises
    ------
    ValueError
        If ``name`` does not match a known preset.
    """
    # Guard clause: reject anything that is not the single known preset.
    if name.lower() != "contact":
        raise ValueError(f"Unknown LinearSolveConfig preset: {name}")

    contact_overrides = {
        "method": "petsc_shell",
        "ksp_type": "bcgs",
        "pc_type": "ilu",
        "petsc_shell_pmat": True,
        "petsc_shell_pmat_mode": "full",
        "petsc_ksp_norm_type": "unpreconditioned",
        "petsc_ksp_monitor_true_residual": True,
        "petsc_ksp_converged_reason": True,
        "petsc_shell_fallback": True,
    }
    return cls(**contact_overrides)
|
|
366
422
|
|
|
367
423
|
|
|
368
424
|
@dataclass
|
|
@@ -392,6 +448,9 @@ class LinearSolveRunner:
|
|
|
392
448
|
def __init__(self, analysis: LinearAnalysis, config: LinearSolveConfig):
|
|
393
449
|
self.analysis = analysis
|
|
394
450
|
self.config = config
|
|
451
|
+
self._petsc_shell_pmat = None
|
|
452
|
+
self._petsc_shell_last_iters = None
|
|
453
|
+
self._petsc_shell_pmat_rebuilds = 0
|
|
395
454
|
|
|
396
455
|
def run(
|
|
397
456
|
self,
|
|
@@ -488,6 +547,114 @@ class LinearSolveRunner:
|
|
|
488
547
|
stop_reason=("converged" if lin_conv else "linfail"),
|
|
489
548
|
nan_detected=bool(np.isnan(lin_res)) if lin_res is not None else False,
|
|
490
549
|
)
|
|
550
|
+
elif self.config.method == "petsc_shell":
|
|
551
|
+
base_ksp_type = self.config.ksp_type or "gmres"
|
|
552
|
+
pc_type = self.config.pc_type if self.config.pc_type is not None else "none"
|
|
553
|
+
ksp_rtol = self.config.ksp_rtol if self.config.ksp_rtol is not None else self.config.tol
|
|
554
|
+
ksp_atol = self.config.ksp_atol
|
|
555
|
+
ksp_max_it = self.config.ksp_max_it if self.config.ksp_max_it is not None else self.config.maxiter
|
|
556
|
+
petsc_options = {}
|
|
557
|
+
if self.config.petsc_ksp_norm_type:
|
|
558
|
+
petsc_options["fluxfem_ksp_norm_type"] = self.config.petsc_ksp_norm_type
|
|
559
|
+
if self.config.petsc_ksp_monitor_true_residual:
|
|
560
|
+
petsc_options["fluxfem_ksp_monitor_true_residual"] = ""
|
|
561
|
+
if self.config.petsc_ksp_converged_reason:
|
|
562
|
+
petsc_options["fluxfem_ksp_converged_reason"] = ""
|
|
563
|
+
if self.config.petsc_ksp_monitor_short:
|
|
564
|
+
petsc_options["fluxfem_ksp_monitor_short"] = ""
|
|
565
|
+
if not petsc_options:
|
|
566
|
+
petsc_options = None
|
|
567
|
+
use_pmat = bool(self.config.petsc_shell_pmat)
|
|
568
|
+
rebuild_thresh = self.config.petsc_shell_pmat_rebuild_iters
|
|
569
|
+
if use_pmat:
|
|
570
|
+
pmat_mode = (self.config.petsc_shell_pmat_mode or "full").lower()
|
|
571
|
+
if pmat_mode == "none":
|
|
572
|
+
use_pmat = False
|
|
573
|
+
pmat = None
|
|
574
|
+
elif pmat_mode == "full":
|
|
575
|
+
pmat = K_ff
|
|
576
|
+
else:
|
|
577
|
+
warnings.warn(
|
|
578
|
+
f"petsc_shell_pmat_mode='{pmat_mode}' is not supported in runner; "
|
|
579
|
+
"falling back to 'full'.",
|
|
580
|
+
RuntimeWarning,
|
|
581
|
+
)
|
|
582
|
+
pmat = K_ff
|
|
583
|
+
if use_pmat:
|
|
584
|
+
if self._petsc_shell_pmat is None:
|
|
585
|
+
self._petsc_shell_pmat = pmat
|
|
586
|
+
self._petsc_shell_pmat_rebuilds += 1
|
|
587
|
+
elif rebuild_thresh is not None and self._petsc_shell_last_iters is not None:
|
|
588
|
+
if self._petsc_shell_last_iters > rebuild_thresh:
|
|
589
|
+
self._petsc_shell_pmat = pmat
|
|
590
|
+
self._petsc_shell_pmat_rebuilds += 1
|
|
591
|
+
pmat = self._petsc_shell_pmat
|
|
592
|
+
if not use_pmat:
|
|
593
|
+
pmat = None
|
|
594
|
+
|
|
595
|
+
def _attempt_solve(ksp_type: str):
|
|
596
|
+
return petsc_shell_solve(
|
|
597
|
+
K_ff,
|
|
598
|
+
F_free,
|
|
599
|
+
preconditioner=self.config.preconditioner,
|
|
600
|
+
ksp_type=ksp_type,
|
|
601
|
+
pc_type=pc_type,
|
|
602
|
+
rtol=ksp_rtol,
|
|
603
|
+
atol=ksp_atol,
|
|
604
|
+
max_it=ksp_max_it,
|
|
605
|
+
pmat=pmat,
|
|
606
|
+
options=petsc_options,
|
|
607
|
+
return_info=True,
|
|
608
|
+
)
|
|
609
|
+
|
|
610
|
+
fallback_ksp = [base_ksp_type]
|
|
611
|
+
if self.config.petsc_shell_fallback:
|
|
612
|
+
for ksp in self.config.petsc_shell_fallback_ksp_types:
|
|
613
|
+
if ksp not in fallback_ksp:
|
|
614
|
+
fallback_ksp.append(ksp)
|
|
615
|
+
fallback_attempts = []
|
|
616
|
+
petsc_info = None
|
|
617
|
+
u_free = None
|
|
618
|
+
for ksp in fallback_ksp:
|
|
619
|
+
fallback_attempts.append(ksp)
|
|
620
|
+
u_free, petsc_info = _attempt_solve(ksp)
|
|
621
|
+
lin_conv = petsc_info.get("converged")
|
|
622
|
+
reason = petsc_info.get("reason")
|
|
623
|
+
if lin_conv is None and reason is not None:
|
|
624
|
+
lin_conv = reason > 0
|
|
625
|
+
if lin_conv:
|
|
626
|
+
break
|
|
627
|
+
if self.config.petsc_shell_fallback and use_pmat and self.config.petsc_shell_fallback_rebuild_pmat:
|
|
628
|
+
self._petsc_shell_pmat = pmat
|
|
629
|
+
self._petsc_shell_pmat_rebuilds += 1
|
|
630
|
+
lin_iters = petsc_info.get("iters")
|
|
631
|
+
lin_res = petsc_info.get("residual_norm")
|
|
632
|
+
lin_solve_dt = petsc_info.get("solve_time")
|
|
633
|
+
pc_setup_dt = petsc_info.get("pc_setup_time")
|
|
634
|
+
pmat_dt = petsc_info.get("pmat_build_time")
|
|
635
|
+
lin_conv = petsc_info.get("converged")
|
|
636
|
+
if lin_conv is None and petsc_info.get("reason") is not None:
|
|
637
|
+
lin_conv = petsc_info.get("reason") > 0
|
|
638
|
+
if lin_conv is None:
|
|
639
|
+
lin_conv = True
|
|
640
|
+
self._petsc_shell_last_iters = lin_iters
|
|
641
|
+
info = SolverResult(
|
|
642
|
+
converged=bool(lin_conv),
|
|
643
|
+
iters=int(lin_iters) if lin_iters is not None else 0,
|
|
644
|
+
linear_iters=int(lin_iters) if lin_iters is not None else None,
|
|
645
|
+
linear_converged=bool(lin_conv),
|
|
646
|
+
linear_residual=float(lin_res) if lin_res is not None else None,
|
|
647
|
+
linear_solve_time=float(lin_solve_dt) if lin_solve_dt is not None else None,
|
|
648
|
+
pc_setup_time=float(pc_setup_dt) if pc_setup_dt is not None else None,
|
|
649
|
+
pmat_build_time=float(pmat_dt) if pmat_dt is not None else None,
|
|
650
|
+
pmat_rebuilds=self._petsc_shell_pmat_rebuilds if use_pmat else None,
|
|
651
|
+
pmat_mode=self.config.petsc_shell_pmat_mode if use_pmat else None,
|
|
652
|
+
tol=self.config.tol,
|
|
653
|
+
stop_reason=("converged" if lin_conv else "linfail"),
|
|
654
|
+
nan_detected=bool(np.isnan(lin_res)) if lin_res is not None else False,
|
|
655
|
+
)
|
|
656
|
+
if len(fallback_attempts) > 1:
|
|
657
|
+
info.linear_fallbacks = fallback_attempts
|
|
491
658
|
else:
|
|
492
659
|
raise ValueError(f"Unknown linear solve method: {self.config.method}")
|
|
493
660
|
|
fluxfem/solver/solver.py
CHANGED
|
@@ -5,8 +5,10 @@ import jax.numpy as jnp
|
|
|
5
5
|
|
|
6
6
|
from .cg import cg_solve, cg_solve_jax
|
|
7
7
|
from .newton import newton_solve
|
|
8
|
+
from .petsc import petsc_solve, petsc_shell_solve
|
|
8
9
|
from ..core.solver import spdirect_solve_cpu, spdirect_solve_gpu
|
|
9
10
|
from .dirichlet import (
|
|
11
|
+
DirichletBC,
|
|
10
12
|
condense_dirichlet_dense,
|
|
11
13
|
condense_dirichlet_fluxsparse,
|
|
12
14
|
expand_dirichlet_solution,
|
|
@@ -46,6 +48,12 @@ class LinearSolver:
|
|
|
46
48
|
elif self.method == "spdirect_solve_gpu":
|
|
47
49
|
x = spdirect_solve_gpu(A, b)
|
|
48
50
|
return np.asarray(x), {"iters": 1, "converged": True}
|
|
51
|
+
elif self.method == "petsc":
|
|
52
|
+
x = petsc_solve(A, b)
|
|
53
|
+
return np.asarray(x), {"iters": None, "converged": True}
|
|
54
|
+
elif self.method == "petsc_shell":
|
|
55
|
+
x = petsc_shell_solve(A, b)
|
|
56
|
+
return np.asarray(x), {"iters": None, "converged": True}
|
|
49
57
|
else:
|
|
50
58
|
raise ValueError(f"Unknown linear method: {self.method}")
|
|
51
59
|
|
|
@@ -64,7 +72,10 @@ class LinearSolver:
|
|
|
64
72
|
if dirichlet_mode not in ("condense", "enforce"):
|
|
65
73
|
raise ValueError("dirichlet_mode must be 'condense' or 'enforce'.")
|
|
66
74
|
|
|
67
|
-
|
|
75
|
+
if isinstance(dirichlet, DirichletBC):
|
|
76
|
+
dir_dofs, dir_vals = dirichlet.as_tuple()
|
|
77
|
+
else:
|
|
78
|
+
dir_dofs, dir_vals = dirichlet
|
|
68
79
|
if dirichlet_mode == "enforce":
|
|
69
80
|
if isinstance(A, FluxSparseMatrix):
|
|
70
81
|
A_bc, b_bc = enforce_dirichlet_sparse(A, b, dir_dofs, dir_vals)
|
fluxfem/solver/sparse.py
CHANGED
|
@@ -12,6 +12,79 @@ except Exception: # pragma: no cover
|
|
|
12
12
|
sp = None
|
|
13
13
|
|
|
14
14
|
|
|
15
|
+
def coalesce_coo(rows, cols, data):
    """
    Sum duplicate COO entries by sorting (CPU-friendly).

    Entries that share the same (row, col) pair are merged into a single
    entry whose value is the sum of the duplicates. The output is ordered by
    row, then by column.

    Parameters
    ----------
    rows, cols : array-like of int
        Row and column index of every stored entry (converted to int64).
    data : array-like
        Value of every stored entry; its leading axis must match ``rows``.

    Returns
    -------
    (rows_u, cols_u, data_u) as NumPy arrays.

    Raises
    ------
    ValueError
        If ``rows``, ``cols`` and ``data`` do not describe the same number of
        entries. Without this check a ``data`` array longer than ``rows``
        would be silently truncated by the fancy-indexing step below.
    """
    r = np.asarray(rows, dtype=np.int64)
    c = np.asarray(cols, dtype=np.int64)
    d = np.asarray(data)
    if r.shape != c.shape or d.shape[:1] != r.shape:
        raise ValueError(
            "rows, cols and data must have matching lengths for coalesce_coo; "
            f"got shapes {r.shape}, {c.shape}, {d.shape}."
        )
    if r.size == 0:
        return r, c, d
    # lexsort treats the LAST key as primary: sort by row, then by column.
    order = np.lexsort((c, r))
    r_s = r[order]
    c_s = c[order]
    d_s = d[order]
    # An entry opens a new group when its (row, col) differs from its
    # predecessor; the very first entry always opens one.
    new_group = np.ones(r_s.size, dtype=bool)
    new_group[1:] = (r_s[1:] != r_s[:-1]) | (c_s[1:] != c_s[:-1])
    starts = np.flatnonzero(new_group)
    # reduceat sums each run [starts[i], starts[i + 1]) along the leading axis.
    return r_s[starts], c_s[starts], np.add.reduceat(d_s, starts)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _normalize_flux_mats(mats):
    """Normalize a ``*mats`` argument pack into a non-empty sequence.

    Callers may pass the matrices either as separate positional arguments or
    as one list/tuple; in the latter case the single container is unpacked
    into a tuple. Raises ``ValueError`` when no matrix remains.
    """
    wrapped_in_one_sequence = len(mats) == 1 and isinstance(mats[0], (list, tuple))
    unpacked = tuple(mats[0]) if wrapped_in_one_sequence else mats
    if not unpacked:
        raise ValueError("At least one FluxSparseMatrix is required.")
    return unpacked
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def concat_flux(*mats, n_dofs: int | None = None):
    """
    Concatenate COO entries from multiple FluxSparseMatrix objects.

    The entry lists are appended in argument order; duplicates are NOT
    summed here. All matrices must share the same n_dofs unless n_dofs is
    provided, in which case the given value is used without checking the
    inputs against it.

    Parameters
    ----------
    *mats
        Matrices to concatenate, as separate arguments or one list/tuple.
    n_dofs : int | None
        Size of the result. When None it is taken from the first matrix and
        every other matrix must agree.

    Raises
    ------
    ValueError
        If no matrix is given, or if n_dofs is None and the sizes differ.
    """
    # _normalize_flux_mats raises on empty input, so `mats` is non-empty
    # below and the concatenations need no empty-list fallback.
    mats = _normalize_flux_mats(mats)
    if n_dofs is None:
        n_dofs = int(mats[0].n_dofs)
        if any(int(mat.n_dofs) != n_dofs for mat in mats[1:]):
            raise ValueError("All matrices must share n_dofs for concat_flux.")
    rows = np.concatenate([np.asarray(mat.pattern.rows, dtype=np.int32) for mat in mats])
    cols = np.concatenate([np.asarray(mat.pattern.cols, dtype=np.int32) for mat in mats])
    data = np.concatenate([np.asarray(mat.data) for mat in mats])
    return FluxSparseMatrix(rows, cols, data, int(n_dofs))
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def block_diag_flux(*mats):
    """Block-diagonal concatenation for FluxSparseMatrix objects.

    Each input becomes one diagonal block: its row and column indices are
    shifted by the summed ``n_dofs`` of all preceding inputs. The result has
    ``n_dofs`` equal to the total. Inputs may be given as separate arguments
    or as a single list/tuple.
    """
    blocks = _normalize_flux_mats(mats)
    shifted_rows, shifted_cols, values = [], [], []
    shift = 0
    for block in blocks:
        block_rows = np.asarray(block.pattern.rows, dtype=np.int32)
        block_cols = np.asarray(block.pattern.cols, dtype=np.int32)
        block_data = np.asarray(block.data)
        if block_rows.size:
            shifted_rows.append(block_rows + shift)
            shifted_cols.append(block_cols + shift)
            values.append(block_data)
        # An empty block stores nothing but still occupies its dof range.
        shift += int(block.n_dofs)

    if shifted_rows:
        rows = np.concatenate(shifted_rows)
        cols = np.concatenate(shifted_cols)
        data = np.concatenate(values)
    else:
        # Every block was empty: build an empty matrix of the total size.
        rows = np.asarray([], dtype=np.int32)
        cols = np.asarray([], dtype=np.int32)
        data = np.asarray([], dtype=float)
    return FluxSparseMatrix(rows, cols, data, int(shift))
|
|
86
|
+
|
|
87
|
+
|
|
15
88
|
@jax.tree_util.register_pytree_node_class
|
|
16
89
|
@dataclass(frozen=True)
|
|
17
90
|
class SparsityPattern:
|
|
@@ -80,7 +153,7 @@ class FluxSparseMatrix:
|
|
|
80
153
|
- data stores the numeric values for the current nonlinear iterate
|
|
81
154
|
"""
|
|
82
155
|
|
|
83
|
-
def __init__(self, rows_or_pattern, cols=None, data=None, n_dofs: int | None = None):
|
|
156
|
+
def __init__(self, rows_or_pattern, cols=None, data=None, n_dofs: int | None = None, meta: dict | None = None):
|
|
84
157
|
# New signature: FluxSparseMatrix(pattern, data)
|
|
85
158
|
if isinstance(rows_or_pattern, SparsityPattern):
|
|
86
159
|
pattern = rows_or_pattern
|
|
@@ -88,15 +161,22 @@ class FluxSparseMatrix:
|
|
|
88
161
|
values = jnp.asarray(values)
|
|
89
162
|
else:
|
|
90
163
|
# Legacy signature: FluxSparseMatrix(rows, cols, data, n_dofs)
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
164
|
+
r_j = jnp.asarray(rows_or_pattern, dtype=jnp.int32)
|
|
165
|
+
c_j = jnp.asarray(cols, dtype=jnp.int32)
|
|
166
|
+
is_tracer = isinstance(rows_or_pattern, jax.core.Tracer) or isinstance(cols, jax.core.Tracer)
|
|
167
|
+
diag_idx_j = None
|
|
168
|
+
if not is_tracer:
|
|
169
|
+
diag_idx_j = jnp.nonzero(r_j == c_j)[0].astype(jnp.int32)
|
|
170
|
+
if n_dofs is None:
|
|
171
|
+
if is_tracer:
|
|
172
|
+
raise ValueError("n_dofs must be provided when constructing FluxSparseMatrix under JIT.")
|
|
173
|
+
n_dofs = int(np.asarray(cols).max()) + 1
|
|
94
174
|
pattern = SparsityPattern(
|
|
95
|
-
rows=
|
|
96
|
-
cols=
|
|
97
|
-
n_dofs=int(n_dofs) if n_dofs is not None else int(
|
|
175
|
+
rows=r_j,
|
|
176
|
+
cols=c_j,
|
|
177
|
+
n_dofs=int(n_dofs) if n_dofs is not None else int(np.asarray(cols).max()) + 1,
|
|
98
178
|
idx=None,
|
|
99
|
-
diag_idx=
|
|
179
|
+
diag_idx=diag_idx_j,
|
|
100
180
|
)
|
|
101
181
|
values = jnp.asarray(data)
|
|
102
182
|
|
|
@@ -105,6 +185,7 @@ class FluxSparseMatrix:
|
|
|
105
185
|
self.cols = pattern.cols
|
|
106
186
|
self.n_dofs = int(pattern.n_dofs)
|
|
107
187
|
self.data = values
|
|
188
|
+
self.meta = dict(meta) if meta is not None else None
|
|
108
189
|
|
|
109
190
|
@classmethod
|
|
110
191
|
def from_bilinear(cls, coo_tuple):
|
|
@@ -121,11 +202,25 @@ class FluxSparseMatrix:
|
|
|
121
202
|
|
|
122
203
|
def with_data(self, data):
    """Return a new FluxSparseMatrix that reuses this matrix's pattern and meta with ``data`` as its values."""
    shared_pattern = self.pattern
    return FluxSparseMatrix(shared_pattern, data, meta=self.meta)
|
|
206
|
+
|
|
207
|
+
def add_dense(self, dense):
    """Return a new FluxSparseMatrix with dense entries added on the pattern.

    Only the positions already present in the sparsity pattern are read from
    ``dense``; values of ``dense`` outside the pattern are ignored. The
    pattern is unchanged, so ``meta`` is carried over to the result (matching
    ``with_data``).

    Parameters
    ----------
    dense : array-like
        Two-dimensional array indexed as ``dense[row, col]``.
    """
    pattern = self.pattern
    # Gather one dense value per stored entry.
    # NOTE(review): if the pattern holds duplicate (row, col) entries, each
    # duplicate receives dense[row, col], so the value is added once per
    # duplicate — confirm callers pass a coalesced matrix when that matters.
    dense_vals = jnp.asarray(dense)[pattern.rows, pattern.cols]
    return FluxSparseMatrix(pattern, self.data + dense_vals, meta=self.meta)
|
|
125
211
|
|
|
126
212
|
def to_coo(self):
    """Return the COO description of the matrix as ``(rows, cols, data, n_dofs)``."""
    pattern = self.pattern
    return (pattern.rows, pattern.cols, self.data, pattern.n_dofs)
|
|
128
214
|
|
|
215
|
+
@property
def nnz(self) -> int:
    """Number of stored entries, i.e. the length of the leading axis of ``data``."""
    stored_entries = self.data.shape[0]
    return int(stored_entries)
|
|
218
|
+
|
|
219
|
+
def coalesce(self):
    """Return a new FluxSparseMatrix with duplicate entries summed.

    The merged rows, columns and values come from ``coalesce_coo``; the
    result keeps this matrix's ``n_dofs``.
    """
    pattern = self.pattern
    merged = coalesce_coo(pattern.rows, pattern.cols, self.data)
    merged_rows, merged_cols, merged_data = merged
    # NOTE(review): ``meta`` is not forwarded to the result — presumably
    # because the pattern changes; confirm this is intended.
    return FluxSparseMatrix(merged_rows, merged_cols, merged_data, pattern.n_dofs)
|
|
223
|
+
|
|
129
224
|
def to_csr(self):
|
|
130
225
|
if sp is None:
|
|
131
226
|
raise ImportError("scipy is required for to_csr()")
|
|
@@ -167,6 +262,26 @@ class FluxSparseMatrix:
|
|
|
167
262
|
out = jnp.zeros(self.pattern.n_dofs, dtype=contrib.dtype)
|
|
168
263
|
return out.at[self.pattern.rows].add(contrib)
|
|
169
264
|
|
|
265
|
+
def as_cg_operator(
    self,
    *,
    matvec: str = "flux",
    preconditioner=None,
    solver: str = "cg",
    dof_per_node: int | None = None,
    block_sizes=None,
):
    """Wrap this matrix via ``build_cg_operator`` from the ``.cg`` module.

    All keyword arguments are forwarded unchanged; see ``build_cg_operator``
    for their meaning and for the type of the returned object.
    """
    # Imported at call time rather than at module load.
    # NOTE(review): presumably to avoid a circular import with .cg — confirm.
    from .cg import build_cg_operator

    forwarded = dict(
        matvec=matvec,
        preconditioner=preconditioner,
        solver=solver,
        dof_per_node=dof_per_node,
        block_sizes=block_sizes,
    )
    return build_cg_operator(self, **forwarded)
|
|
284
|
+
|
|
170
285
|
def diag(self):
|
|
171
286
|
"""Diagonal entries aggregated for Jacobi preconditioning."""
|
|
172
287
|
if self.pattern.diag_idx is not None:
|