fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +69 -13
- fluxfem/core/__init__.py +140 -53
- fluxfem/core/assembly.py +691 -97
- fluxfem/core/basis.py +75 -54
- fluxfem/core/context_types.py +36 -12
- fluxfem/core/dtypes.py +9 -1
- fluxfem/core/forms.py +10 -0
- fluxfem/core/mixed_assembly.py +263 -0
- fluxfem/core/mixed_space.py +382 -0
- fluxfem/core/mixed_weakform.py +97 -0
- fluxfem/core/solver.py +2 -0
- fluxfem/core/space.py +315 -30
- fluxfem/core/weakform.py +821 -42
- fluxfem/helpers_wf.py +49 -0
- fluxfem/mesh/__init__.py +54 -2
- fluxfem/mesh/base.py +318 -9
- fluxfem/mesh/contact.py +841 -0
- fluxfem/mesh/dtypes.py +12 -0
- fluxfem/mesh/hex.py +17 -16
- fluxfem/mesh/io.py +9 -6
- fluxfem/mesh/mortar.py +3970 -0
- fluxfem/mesh/supermesh.py +318 -0
- fluxfem/mesh/surface.py +104 -26
- fluxfem/mesh/tet.py +16 -7
- fluxfem/physics/diffusion.py +3 -0
- fluxfem/physics/elasticity/hyperelastic.py +35 -3
- fluxfem/physics/elasticity/linear.py +22 -4
- fluxfem/physics/elasticity/stress.py +9 -5
- fluxfem/physics/operators.py +12 -5
- fluxfem/physics/postprocess.py +29 -3
- fluxfem/solver/__init__.py +47 -2
- fluxfem/solver/bc.py +38 -2
- fluxfem/solver/block_matrix.py +284 -0
- fluxfem/solver/block_system.py +477 -0
- fluxfem/solver/cg.py +150 -55
- fluxfem/solver/dirichlet.py +358 -5
- fluxfem/solver/history.py +15 -3
- fluxfem/solver/newton.py +260 -70
- fluxfem/solver/petsc.py +445 -0
- fluxfem/solver/preconditioner.py +109 -0
- fluxfem/solver/result.py +18 -0
- fluxfem/solver/solve_runner.py +208 -23
- fluxfem/solver/solver.py +35 -12
- fluxfem/solver/sparse.py +149 -15
- fluxfem/tools/jit.py +19 -7
- fluxfem/tools/timer.py +14 -12
- fluxfem/tools/visualizer.py +16 -4
- fluxfem-0.2.1.dist-info/METADATA +314 -0
- fluxfem-0.2.1.dist-info/RECORD +59 -0
- fluxfem-0.1.4.dist-info/METADATA +0 -127
- fluxfem-0.1.4.dist-info/RECORD +0 -48
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/solver/cg.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from typing import Any, Callable, TypeAlias
|
|
4
|
+
|
|
3
5
|
import jax
|
|
4
6
|
import jax.numpy as jnp
|
|
5
7
|
import jax.scipy as jsp
|
|
@@ -10,9 +12,16 @@ except Exception: # pragma: no cover
|
|
|
10
12
|
jsparse = None
|
|
11
13
|
|
|
12
14
|
from .sparse import FluxSparseMatrix
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
|
|
17
|
+
from .preconditioner import make_block_jacobi_preconditioner
|
|
18
|
+
|
|
19
|
+
ArrayLike: TypeAlias = jnp.ndarray
|
|
20
|
+
MatVec: TypeAlias = Callable[[jnp.ndarray], jnp.ndarray]
|
|
21
|
+
CGInfo: TypeAlias = dict[str, Any]
|
|
13
22
|
|
|
14
23
|
|
|
15
|
-
def _matvec_builder(A):
|
|
24
|
+
def _matvec_builder(A: Any) -> MatVec:
|
|
16
25
|
if jsparse is not None and isinstance(A, jsparse.BCOO):
|
|
17
26
|
return lambda x: A @ x
|
|
18
27
|
if isinstance(A, FluxSparseMatrix):
|
|
@@ -30,7 +39,125 @@ def _matvec_builder(A):
|
|
|
30
39
|
return mv
|
|
31
40
|
|
|
32
41
|
|
|
33
|
-
def
|
|
42
|
+
def _coo_tuple_from_any(A: Any):
    """
    Return ``(rows, cols, data, n)`` COO components for a supported sparse input.

    Accepted inputs:
      * FluxSparseMatrix -> ``A.to_coo()``
      * a 4-tuple, assumed to already be ``(rows, cols, data, n)`` (returned as-is)
      * jax.experimental.sparse.BCOO -> its index columns, data and ``shape[0]``
      * any scipy.sparse matrix/array -> converted via ``tocoo()``

    Returns None when ``A`` is none of the above (e.g. a dense array) so that
    callers can choose their own fallback or error.
    """
    if isinstance(A, FluxSparseMatrix):
        return A.to_coo()
    if isinstance(A, tuple) and len(A) == 4:
        return A
    if jsparse is not None and isinstance(A, jsparse.BCOO):
        # BCOO stores an (nse, 2) array of (row, col) index pairs.
        # NOTE(review): assumes an unbatched 2-D BCOO without dense dims.
        return (
            jnp.asarray(A.indices[:, 0], dtype=jnp.int32),
            jnp.asarray(A.indices[:, 1], dtype=jnp.int32),
            A.data,
            int(A.shape[0]),
        )
    try:
        import scipy.sparse as sp  # type: ignore
    except Exception:  # pragma: no cover
        sp = None
    if sp is not None and sp.issparse(A):
        coo = A.tocoo()
        return (
            jnp.asarray(coo.row, dtype=jnp.int32),
            jnp.asarray(coo.col, dtype=jnp.int32),
            jnp.asarray(coo.data),
            int(A.shape[0]),
        )
    return None
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _to_flux_matrix(A: Any) -> FluxSparseMatrix:
    """
    Coerce ``A`` into a FluxSparseMatrix.

    FluxSparseMatrix inputs are returned unchanged. Anything that
    ``_coo_tuple_from_any`` understands is rebuilt through
    ``FluxSparseMatrix.from_bilinear``.

    Raises:
        ValueError: if ``A`` has no COO representation.
    """
    if not isinstance(A, FluxSparseMatrix):
        coo_parts = _coo_tuple_from_any(A)
        if coo_parts is None:
            raise ValueError("Unable to build FluxSparseMatrix from A")
        return FluxSparseMatrix.from_bilinear(coo_parts)
    return A
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _to_bcoo_matrix(A: Any):
    """
    Coerce ``A`` into a ``jax.experimental.sparse.BCOO`` matrix.

    BCOO inputs are returned unchanged. Other inputs go through
    ``_coo_tuple_from_any`` and become an ``(n, n)`` BCOO.

    Raises:
        ImportError: if jax.experimental.sparse is unavailable.
        ValueError: if ``A`` has no COO representation.
    """
    if jsparse is None:
        raise ImportError("jax.experimental.sparse is required for BCOO matvec")
    if isinstance(A, jsparse.BCOO):
        return A
    parts = _coo_tuple_from_any(A)
    if parts is None:
        raise ValueError("Unable to build BCOO from A")
    row_idx, col_idx, values, size = parts
    # BCOO expects an (nse, 2) array of (row, col) index pairs.
    index_pairs = jnp.stack([row_idx, col_idx], axis=-1)
    return jsparse.BCOO((values, index_pairs), shape=(size, size))
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _normalize_matvec_matrix(A: Any, matvec: str):
    """
    Convert ``A`` into the storage format used by the chosen matvec backend.

    matvec:
      * "flux"  -> FluxSparseMatrix
      * "bcoo"  -> jax.experimental.sparse.BCOO
      * "dense" -> jnp array
      * "auto"  -> BCOO when jax.experimental.sparse is available and the
                   conversion succeeds, otherwise FluxSparseMatrix. A 2-D
                   array-like with no sparse representation falls back to a
                   dense jnp array instead of being rejected.

    Raises:
        ValueError: for an unknown backend, or when "auto" cannot interpret A.
    """
    if matvec == "flux":
        return _to_flux_matrix(A)
    if matvec == "bcoo":
        return _to_bcoo_matrix(A)
    if matvec == "dense":
        return jnp.asarray(A)
    if matvec == "auto":
        if jsparse is not None:
            try:
                return _to_bcoo_matrix(A)
            except Exception:
                # Best-effort: any BCOO conversion failure falls through to
                # the FluxSparseMatrix path below.
                pass
        try:
            return _to_flux_matrix(A)
        except ValueError:
            # Dense matrices are documented inputs of the CG solvers, so use
            # them as-is. This applies only when A really has no sparse form;
            # a genuine sparse conversion error is re-raised.
            if getattr(A, "ndim", None) == 2 and _coo_tuple_from_any(A) is None:
                return jnp.asarray(A)
            raise
    raise ValueError(f"Unknown matvec backend: {matvec}")
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@dataclass(frozen=True)
class CGOperator:
    """
    Immutable bundle of a system matrix, an optional preconditioner and the
    name of a CG backend, exposing a single ``solve()`` entry point.

    Attributes:
        A: matrix/operator forwarded unchanged to the backend solver.
        preconditioner: forwarded unchanged as ``preconditioner=``.
        solver: "cg" (uses ``cg_solve``) or "cg_jax" (uses ``cg_solve_jax``).
    """

    A: object
    preconditioner: object | None = None
    solver: str = "cg"

    def solve(
        self,
        b: jnp.ndarray,
        *,
        x0: jnp.ndarray | None = None,
        tol: float = 1e-8,
        maxiter: int | None = None,
    ):
        """
        Solve ``A x = b`` with the configured backend and return its result.

        Raises:
            ValueError: if ``solver`` names an unknown backend.
        """
        if self.solver == "cg":
            backend = cg_solve
        elif self.solver == "cg_jax":
            backend = cg_solve_jax
        else:
            raise ValueError(f"Unknown CG solver: {self.solver}")
        return backend(
            self.A,
            b,
            x0=x0,
            tol=tol,
            maxiter=maxiter,
            preconditioner=self.preconditioner,
        )
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def build_cg_operator(
    A: Any,
    *,
    matvec: str = "flux",
    preconditioner: object | None = None,
    solver: str = "cg",
    dof_per_node: int | None = None,
    block_sizes: object | None = None,
) -> CGOperator:
    """
    Build a CGOperator from any supported matrix input.

    ``A`` is converted to the storage selected by ``matvec`` via
    ``_normalize_matvec_matrix``. A ``preconditioner`` equal to
    ``"block_jacobi"`` is resolved eagerly with
    ``make_block_jacobi_preconditioner`` (using ``dof_per_node`` /
    ``block_sizes``). Any other preconditioner value is passed through as-is.
    """
    matrix = _normalize_matvec_matrix(A, matvec)
    resolved_precon = (
        make_block_jacobi_preconditioner(
            matrix, dof_per_node=dof_per_node, block_sizes=block_sizes
        )
        if preconditioner == "block_jacobi"
        else preconditioner
    )
    return CGOperator(A=matrix, preconditioner=resolved_precon, solver=solver)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def _diag_builder(A: Any, n: int) -> jnp.ndarray:
|
|
34
161
|
"""
|
|
35
162
|
Build diagonal for a Jacobi preconditioner when available.
|
|
36
163
|
"""
|
|
@@ -57,14 +184,14 @@ def _diag_builder(A, n: int):
|
|
|
57
184
|
|
|
58
185
|
|
|
59
186
|
def _cg_solve_single(
|
|
60
|
-
A,
|
|
61
|
-
b,
|
|
187
|
+
A: Any,
|
|
188
|
+
b: jnp.ndarray,
|
|
62
189
|
*,
|
|
63
|
-
x0=None,
|
|
190
|
+
x0: jnp.ndarray | None = None,
|
|
64
191
|
tol: float = 1e-8,
|
|
65
192
|
maxiter: int | None = None,
|
|
66
|
-
preconditioner=None,
|
|
67
|
-
):
|
|
193
|
+
preconditioner: object | None = None,
|
|
194
|
+
) -> tuple[jnp.ndarray, CGInfo]:
|
|
68
195
|
"""
|
|
69
196
|
Conjugate gradient (Ax=b) in JAX.
|
|
70
197
|
A: FluxSparseMatrix / (rows, cols, data, n) / dense array
|
|
@@ -92,39 +219,7 @@ def _cg_solve_single(
|
|
|
92
219
|
return inv_diag * r
|
|
93
220
|
|
|
94
221
|
elif preconditioner == "block_jacobi":
|
|
95
|
-
|
|
96
|
-
if n % 3 != 0:
|
|
97
|
-
raise ValueError("block_jacobi requires n_dofs % 3 == 0")
|
|
98
|
-
if jsparse is not None and isinstance(A, jsparse.BCOO):
|
|
99
|
-
rows = A.indices[:, 0]
|
|
100
|
-
cols = A.indices[:, 1]
|
|
101
|
-
data = A.data
|
|
102
|
-
elif isinstance(A, FluxSparseMatrix):
|
|
103
|
-
rows = jnp.asarray(A.pattern.rows)
|
|
104
|
-
cols = jnp.asarray(A.pattern.cols)
|
|
105
|
-
data = jnp.asarray(A.data)
|
|
106
|
-
else:
|
|
107
|
-
raise ValueError("block_jacobi requires FluxSparseMatrix or BCOO")
|
|
108
|
-
|
|
109
|
-
block_rows = rows // 3
|
|
110
|
-
block_cols = cols // 3
|
|
111
|
-
lr = rows % 3
|
|
112
|
-
lc = cols % 3
|
|
113
|
-
mask = block_rows == block_cols
|
|
114
|
-
block_rows = block_rows[mask]
|
|
115
|
-
lr = lr[mask]
|
|
116
|
-
lc = lc[mask]
|
|
117
|
-
data = data[mask]
|
|
118
|
-
n_block = n // 3
|
|
119
|
-
blocks = jnp.zeros((n_block, 3, 3), dtype=data.dtype)
|
|
120
|
-
blocks = blocks.at[block_rows, lr, lc].add(data)
|
|
121
|
-
blocks = blocks + 1e-12 * jnp.eye(3)[None, :, :]
|
|
122
|
-
inv_blocks = jnp.linalg.inv(blocks)
|
|
123
|
-
|
|
124
|
-
def precon(r):
|
|
125
|
-
rb = r.reshape((n_block, 3))
|
|
126
|
-
zb = jnp.einsum("bij,bj->bi", inv_blocks, rb)
|
|
127
|
-
return zb.reshape((-1,))
|
|
222
|
+
precon = make_block_jacobi_preconditioner(A)
|
|
128
223
|
|
|
129
224
|
elif callable(preconditioner):
|
|
130
225
|
precon = preconditioner
|
|
@@ -159,14 +254,14 @@ def _cg_solve_single(
|
|
|
159
254
|
|
|
160
255
|
|
|
161
256
|
def cg_solve(
|
|
162
|
-
A,
|
|
163
|
-
b,
|
|
257
|
+
A: Any,
|
|
258
|
+
b: jnp.ndarray,
|
|
164
259
|
*,
|
|
165
|
-
x0=None,
|
|
260
|
+
x0: jnp.ndarray | None = None,
|
|
166
261
|
tol: float = 1e-8,
|
|
167
262
|
maxiter: int | None = None,
|
|
168
|
-
preconditioner=None,
|
|
169
|
-
):
|
|
263
|
+
preconditioner: object | None = None,
|
|
264
|
+
) -> tuple[jnp.ndarray, CGInfo]:
|
|
170
265
|
"""
|
|
171
266
|
Conjugate gradient (Ax=b) in JAX.
|
|
172
267
|
Supports single RHS (n,) or multiple RHS (n, n_rhs).
|
|
@@ -208,14 +303,14 @@ def cg_solve(
|
|
|
208
303
|
|
|
209
304
|
|
|
210
305
|
def _cg_solve_jax_single(
|
|
211
|
-
A,
|
|
212
|
-
b,
|
|
306
|
+
A: Any,
|
|
307
|
+
b: jnp.ndarray,
|
|
213
308
|
*,
|
|
214
|
-
x0=None,
|
|
309
|
+
x0: jnp.ndarray | None = None,
|
|
215
310
|
tol: float = 1e-8,
|
|
216
311
|
maxiter: int | None = None,
|
|
217
|
-
preconditioner=None,
|
|
218
|
-
):
|
|
312
|
+
preconditioner: object | None = None,
|
|
313
|
+
) -> tuple[jnp.ndarray, CGInfo]:
|
|
219
314
|
"""
|
|
220
315
|
Conjugate gradient via jax.scipy.sparse.linalg.cg.
|
|
221
316
|
A: FluxSparseMatrix / (rows, cols, data, n) / dense array / callable
|
|
@@ -277,14 +372,14 @@ def _cg_solve_jax_single(
|
|
|
277
372
|
|
|
278
373
|
|
|
279
374
|
def cg_solve_jax(
|
|
280
|
-
A,
|
|
281
|
-
b,
|
|
375
|
+
A: Any,
|
|
376
|
+
b: jnp.ndarray,
|
|
282
377
|
*,
|
|
283
|
-
x0=None,
|
|
378
|
+
x0: jnp.ndarray | None = None,
|
|
284
379
|
tol: float = 1e-8,
|
|
285
380
|
maxiter: int | None = None,
|
|
286
|
-
preconditioner=None,
|
|
287
|
-
):
|
|
381
|
+
preconditioner: object | None = None,
|
|
382
|
+
) -> tuple[jnp.ndarray, CGInfo]:
|
|
288
383
|
"""
|
|
289
384
|
Conjugate gradient via jax.scipy.sparse.linalg.cg.
|
|
290
385
|
Supports single RHS (n,) or multiple RHS (n, n_rhs).
|
fluxfem/solver/dirichlet.py
CHANGED
|
@@ -1,16 +1,155 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, TYPE_CHECKING, TypeAlias
|
|
5
|
+
|
|
3
6
|
import numpy as np
|
|
4
7
|
import jax.numpy as jnp
|
|
5
8
|
|
|
6
|
-
|
|
9
|
+
try:
|
|
10
|
+
import scipy.sparse as sp
|
|
11
|
+
except Exception: # pragma: no cover
|
|
12
|
+
sp = None
|
|
7
13
|
|
|
14
|
+
from .sparse import FluxSparseMatrix, coalesce_coo
|
|
8
15
|
|
|
9
|
-
|
|
10
|
-
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from jax import Array as JaxArray
|
|
18
|
+
|
|
19
|
+
ArrayLike: TypeAlias = np.ndarray | JaxArray
|
|
20
|
+
else:
|
|
21
|
+
ArrayLike: TypeAlias = np.ndarray
|
|
22
|
+
DirichletLike: TypeAlias = tuple[np.ndarray, np.ndarray]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _normalize_dirichlet_values(dofs: ArrayLike, vals: ArrayLike | None) -> np.ndarray:
|
|
11
26
|
if vals is None:
|
|
12
|
-
return
|
|
13
|
-
|
|
27
|
+
return np.zeros(np.asarray(dofs).shape[0], dtype=float)
|
|
28
|
+
arr = np.asarray(vals)
|
|
29
|
+
if arr.ndim == 0:
|
|
30
|
+
return np.full(np.asarray(dofs).shape[0], float(arr), dtype=float)
|
|
31
|
+
return arr
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _normalize_dirichlet(dofs: ArrayLike, vals: ArrayLike | None) -> DirichletLike:
    """
    Canonicalize a Dirichlet specification.

    Returns ``(dof_indices, values)``: the indices as ``np.asarray(dofs,
    dtype=int)`` and the values as produced by
    ``_normalize_dirichlet_values`` for those indices.
    """
    dof_idx = np.asarray(dofs, dtype=int)
    dof_vals = _normalize_dirichlet_values(dof_idx, vals)
    return dof_idx, dof_vals
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass(frozen=True)
class CondensedSystem:
    """
    Linear system with Dirichlet DOFs eliminated, plus what is needed to
    expand a free-DOF solution back to the full DOF set.

    Fields (as filled in by ``condense_dirichlet_system``):
        K: free-free block of the system matrix.
        F: right-hand side restricted to free DOFs, with Dirichlet lifting applied.
        free_dofs: global indices of the free DOFs, in the order used by K/F.
        dir_dofs: global indices of the Dirichlet DOFs.
        dir_vals: prescribed values for ``dir_dofs``.
        n_dofs: total number of DOFs of the original system.
    """

    K: Any
    F: Any
    free_dofs: np.ndarray
    dir_dofs: np.ndarray
    dir_vals: np.ndarray
    n_dofs: int

    def expand(self, u_free: ArrayLike, *, fill_dirichlet: bool = True) -> np.ndarray:
        """
        Scatter a free-DOF solution back into a full-length numpy array.

        ``u_free`` may be 1-D ``(n_free,)`` or carry trailing RHS axes, e.g.
        ``(n_free, n_rhs)``. The result has shape
        ``(n_dofs,) + u_free.shape[1:]`` and the dtype of ``u_free``.
        Dirichlet entries are set to ``dir_vals`` unless ``fill_dirichlet`` is
        False, in which case they stay zero. A 1-D ``dir_vals`` is broadcast
        across RHS columns.
        """
        u_free_arr = np.asarray(u_free)
        u_full = np.zeros((self.n_dofs,) + u_free_arr.shape[1:], dtype=u_free_arr.dtype)
        u_full[self.free_dofs] = u_free_arr
        if fill_dirichlet and self.dir_dofs.size:
            vals = np.asarray(self.dir_vals, dtype=u_full.dtype)
            if u_full.ndim > 1 and vals.ndim == 1:
                # One value per Dirichlet DOF, shared by every RHS column.
                vals = vals.reshape((-1,) + (1,) * (u_full.ndim - 1))
            u_full[self.dir_dofs] = vals
        return u_full
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@dataclass(frozen=True)
class DirichletBC:
    """
    Dirichlet boundary condition container with helper methods.

    ``__post_init__`` normalizes ``dofs`` to an integer array and ``vals`` to
    one value per DOF. ``vals`` may be omitted/None (homogeneous, all zeros),
    a scalar (broadcast to every DOF) or a per-DOF array.
    """

    dofs: np.ndarray
    # None -> homogeneous (all-zero) values after normalization.
    vals: np.ndarray | None = None

    def __post_init__(self):
        dofs, vals = _normalize_dirichlet(self.dofs, self.vals)
        # The dataclass is frozen, so bypass __setattr__ to store the
        # normalized arrays.
        object.__setattr__(self, "dofs", dofs)
        object.__setattr__(self, "vals", vals)

    @classmethod
    def from_boundary_dofs(cls, mesh, predicate, *, values: ArrayLike | None = None, **kwargs) -> "DirichletBC":
        """
        Build from mesh.boundary_dofs_where predicate.

        kwargs are forwarded to mesh.boundary_dofs_where (e.g. components=..., dof_per_node=...).
        ``values`` follows the same None / scalar / per-DOF rules as ``vals``.
        """
        dofs = mesh.boundary_dofs_where(predicate, **kwargs)
        # __post_init__ normalizes ``values`` against the selected DOFs.
        return cls(dofs, values)

    @classmethod
    def from_bbox(
        cls,
        mesh,
        *,
        mins: ArrayLike | None = None,
        maxs: ArrayLike | None = None,
        tol: float = 1e-8,
        values: ArrayLike | None = None,
        **kwargs,
    ) -> "DirichletBC":
        """
        Build from the mesh axis-aligned bounding box.

        mins/maxs default to mesh coordinate extrema. kwargs are forwarded to
        mesh.boundary_dofs_where (e.g. components=..., dof_per_node=...).
        """
        from ..mesh.predicate import bbox_predicate

        coords = np.asarray(mesh.coords)
        if mins is None:
            mins = coords.min(axis=0)
        if maxs is None:
            maxs = coords.max(axis=0)
        pred = bbox_predicate(mins, maxs, tol=tol)
        dofs = mesh.boundary_dofs_where(pred, **kwargs)
        # __post_init__ normalizes ``values`` against the selected DOFs.
        return cls(dofs, values)

    def as_tuple(self) -> tuple[np.ndarray, np.ndarray]:
        """Return ``(dofs, vals)``."""
        return self.dofs, self.vals

    def condense_system(self, A: Any, F: ArrayLike, *, check: bool = True) -> CondensedSystem:
        """Eliminate these DOFs from ``A x = F`` via ``condense_dirichlet_system``."""
        return condense_dirichlet_system(A, F, self.dofs, self.vals, check=check)

    def enforce_system(self, A: Any, F: ArrayLike):
        """Apply these conditions to the full system via ``enforce_dirichlet_system``."""
        return enforce_dirichlet_system(A, F, self.dofs, self.vals)

    def condense_flux(self, A: FluxSparseMatrix, F: ArrayLike):
        """
        Condense for FluxSparseMatrix and return (K_free, F_free, free_dofs).

        F_free comes from ``condense_system``. K_free is rebuilt as a
        FluxSparseMatrix by ``restrict_flux_to_free``.
        """
        condensed = self.condense_system(A, F)
        free = condensed.free_dofs
        return restrict_flux_to_free(A, free), condensed.F, free

    def enforce_flux(self, A: FluxSparseMatrix, F: ArrayLike):
        """Apply these conditions to a FluxSparseMatrix system."""
        return enforce_dirichlet_fluxsparse(A, F, self.dofs, self.vals)

    def split_matrix(self, A: Any, *, n_total: int | None = None):
        """Return ``(free, dir_dofs, A_ff, A_fd)`` via ``split_dirichlet_matrix``."""
        return split_dirichlet_matrix(A, self.dofs, n_total=n_total)

    def free_dofs(self, n_dofs: int) -> np.ndarray:
        """Return the indices in ``range(n_dofs)`` that are not Dirichlet DOFs."""
        # Resolves to the module-level free_dofs(); method names are not in
        # scope inside method bodies.
        return free_dofs(n_dofs, self.dofs)

    def expand_solution(
        self,
        u_free: ArrayLike,
        *,
        free: np.ndarray | None = None,
        n_total: int | None = None,
    ) -> np.ndarray:
        """
        Expand a free-DOF solution to the full DOF vector.

        Either ``free`` or ``n_total`` must be given. When only ``free`` is
        given, ``n_total`` is inferred as one past the largest free or
        Dirichlet index.

        Raises:
            ValueError: if neither ``free`` nor ``n_total`` is provided.
        """
        if free is None:
            if n_total is None:
                raise ValueError("n_total is required when free is not provided.")
            free = free_dofs(n_total, self.dofs)
        if n_total is None:
            max_free = int(np.max(free)) if len(free) else -1
            max_dir = int(np.max(self.dofs)) if len(self.dofs) else -1
            n_total = max(max_free, max_dir) + 1
        return expand_dirichlet_solution(u_free, free, self.dofs, self.vals, n_total=n_total)
|
|
14
153
|
|
|
15
154
|
|
|
16
155
|
def enforce_dirichlet_dense(K, F, dofs, vals):
|
|
@@ -33,6 +172,135 @@ def enforce_dirichlet_dense(K, F, dofs, vals):
|
|
|
33
172
|
return Kc, Fc
|
|
34
173
|
|
|
35
174
|
|
|
175
|
+
def enforce_dirichlet_dense_jax(K, F, dofs, vals):
    """
    Apply Dirichlet conditions to a dense JAX stiffness matrix and load.

    All updates use ``.at[]`` so the inputs are not mutated:
      * the load is lifted by ``K[:, dofs] @ vals`` (broadcast over columns
        when ``F`` is 2-D),
      * rows and columns ``dofs`` of K are zeroed and their diagonal set to 1,
      * ``F[dofs]`` is overwritten with ``vals``.

    Returns:
        (K_mod, F_mod)
    """
    dof_idx, dof_vals = _normalize_dirichlet(dofs, vals)
    dof_idx = jnp.asarray(dof_idx, dtype=jnp.int32)
    dof_vals = jnp.asarray(dof_vals, dtype=F.dtype)
    multi_rhs = F.ndim == 2

    # The lifting must use the original K, before the Dirichlet rows and
    # columns are zeroed.
    lift = K[:, dof_idx] @ dof_vals
    F_mod = F - (lift[:, None] if multi_rhs else lift)

    K_mod = (
        K.at[:, dof_idx].set(0.0)
        .at[dof_idx, :].set(0.0)
        .at[dof_idx, dof_idx].set(1.0)
    )
    if multi_rhs:
        F_mod = F_mod.at[dof_idx, :].set(dof_vals[:, None])
    else:
        F_mod = F_mod.at[dof_idx].set(dof_vals)
    return K_mod, F_mod
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def enforce_dirichlet_system(A, F, dofs, vals):
    """
    Apply Dirichlet conditions directly to stiffness/load.

    Dispatches on the matrix type:
      * FluxSparseMatrix -> ``enforce_dirichlet_sparse``
      * JAX array        -> ``enforce_dirichlet_dense_jax``
      * anything else    -> ``enforce_dirichlet_dense`` (numpy path)

    Errors raised by the selected backend propagate to the caller.
    """
    if isinstance(A, FluxSparseMatrix):
        return enforce_dirichlet_sparse(A, F, dofs, vals)
    # jnp is imported at module level, so no guarded import is needed. The JAX
    # call must not sit inside a broad try/except: its errors have to reach
    # the caller rather than trigger a silent retry on the numpy path.
    if isinstance(A, jnp.ndarray):
        return enforce_dirichlet_dense_jax(A, F, dofs, vals)
    return enforce_dirichlet_dense(A, F, dofs, vals)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def split_dirichlet_matrix(A, dir_dofs, *, n_total: int | None = None):
    """
    Split a matrix into free-free and free-dirichlet blocks.

    Supports FluxSparseMatrix (via CSR, needs scipy), scipy.sparse matrices,
    JAX arrays, and anything ``np.asarray`` accepts.

    Args:
        A: square system matrix.
        dir_dofs: Dirichlet DOF indices.
        n_total: total DOF count. When omitted it is inferred from
            ``A.n_dofs`` or from ``A.shape``.

    Returns (free, dir_dofs, A_ff, A_fd).

    Raises:
        ValueError: if ``n_total`` is omitted and ``A`` is not square.
        ImportError: if ``A`` is a FluxSparseMatrix and scipy is missing.
    """
    dir_dofs, _ = _normalize_dirichlet(dir_dofs, None)
    if n_total is None:
        if hasattr(A, "n_dofs"):
            n_total = int(A.n_dofs)
        else:
            # Read the shape directly. np.asarray() on a scipy.sparse matrix
            # gives a 0-d object array, which would be wrongly rejected as
            # non-square, and on a JAX array it copies the whole matrix to
            # host just to read its shape.
            shape = getattr(A, "shape", None)
            if shape is None:
                shape = np.shape(A)
            if len(shape) != 2 or shape[0] != shape[1]:
                raise ValueError("A must be square when n_total is not provided.")
            n_total = int(shape[0])
    free = free_dofs(n_total, dir_dofs)

    if isinstance(A, FluxSparseMatrix):
        if sp is None:
            raise ImportError("scipy is required to split FluxSparseMatrix.")
        A = A.to_csr()

    if sp is not None and sp.issparse(A):
        # Formats such as COO/DIA/BSR do not support row indexing.
        if A.format not in ("csr", "csc", "lil", "dok"):
            A = A.tocsr()
        # Slice the free rows once and reuse them for both blocks.
        A_rows = A[free]
        return free, dir_dofs, A_rows[:, free], A_rows[:, dir_dofs]

    if isinstance(A, jnp.ndarray):
        free_j = jnp.asarray(free, dtype=jnp.int32)
        dir_j = jnp.asarray(dir_dofs, dtype=jnp.int32)
        return free, dir_dofs, A[jnp.ix_(free_j, free_j)], A[jnp.ix_(free_j, dir_j)]

    arr = np.asarray(A)
    return free, dir_dofs, arr[np.ix_(free, free)], arr[np.ix_(free, dir_dofs)]
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def condense_dirichlet_system(A, F, dofs, vals, *, check: bool = True) -> CondensedSystem:
    """
    Condense Dirichlet DOFs and return a structured system.

    Eliminates ``dofs`` from ``A x = F``. The result holds
    ``K = A[free, free]`` and ``F = F[free] - A[free, dofs] @ vals``; the
    lifting is broadcast over columns when F is 2-D.

    A may be a FluxSparseMatrix, a scipy.sparse matrix, any object with
    ``to_csr()``, or a dense array.

    Raises:
        ValueError: if ``A`` is not square, or (with ``check=True``) when the
            Dirichlet data or the length of ``F`` are inconsistent with ``A``.
    """
    dir_arr, dir_vals_arr = _normalize_dirichlet(dofs, vals)
    F_arr = np.asarray(F)
    if hasattr(A, "n_dofs"):
        n_total = int(A.n_dofs)
    else:
        # Read the shape attribute when present. np.asarray() on a
        # scipy.sparse matrix gives a 0-d object array, which would be
        # rejected here even though the sparse branch below handles it.
        shape = getattr(A, "shape", None)
        if shape is None:
            shape = np.shape(A)
        if len(shape) != 2 or shape[0] != shape[1]:
            raise ValueError("A must be square for Dirichlet condensation.")
        n_total = int(shape[0])

    if check:
        if dir_arr.size != dir_vals_arr.size:
            raise ValueError("dir_dofs and dir_vals must have the same length")
        if dir_arr.size:
            if np.min(dir_arr) < 0 or np.max(dir_arr) >= n_total:
                raise ValueError("dir_dofs out of bounds")
            if np.unique(dir_arr).size != dir_arr.size:
                raise ValueError("dir_dofs contains duplicates")
        if F_arr.ndim == 0 or F_arr.shape[0] != n_total:
            raise ValueError("F must have one row per DOF of A")

    mask = np.ones(n_total, dtype=bool)
    mask[dir_arr] = False
    free = np.nonzero(mask)[0]

    if isinstance(A, FluxSparseMatrix):
        K_csr = A.to_csr()
    elif sp is not None and sp.issparse(A):
        K_csr = A.tocsr()
    elif hasattr(A, "to_csr"):
        K_csr = A.to_csr()
    else:
        K_csr = np.asarray(A)

    # Slice the free rows once and reuse them for both blocks.
    K_free_rows = K_csr[free]
    K_ff = K_free_rows[:, free]
    F_free = F_arr[free]
    if dir_arr.size:
        K_fd = K_free_rows[:, dir_arr]
        lifting = K_fd @ dir_vals_arr
        F_free = F_free - (lifting[:, None] if F_free.ndim == 2 else lifting)

    return CondensedSystem(
        K=K_ff,
        F=F_free,
        free_dofs=free,
        dir_dofs=dir_arr,
        dir_vals=dir_vals_arr,
        n_dofs=n_total,
    )
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
|
|
36
304
|
def enforce_dirichlet_sparse(A: FluxSparseMatrix, F, dofs, vals):
|
|
37
305
|
"""Apply Dirichlet conditions to FluxSparseMatrix + load (CSR)."""
|
|
38
306
|
K_csr = A.to_csr().tolil()
|
|
@@ -54,6 +322,11 @@ def enforce_dirichlet_sparse(A: FluxSparseMatrix, F, dofs, vals):
|
|
|
54
322
|
return K_csr.tocsr(), Fc
|
|
55
323
|
|
|
56
324
|
|
|
325
|
+
def enforce_dirichlet_fluxsparse(A: FluxSparseMatrix, F, dofs, vals):
    """
    Apply Dirichlet conditions to a FluxSparseMatrix system.

    Thin alias kept for naming symmetry with the other ``*_fluxsparse``
    helpers. It forwards its arguments unchanged to
    ``enforce_dirichlet_sparse`` and returns that result.
    """
    result = enforce_dirichlet_sparse(A, F, dofs, vals)
    return result
|
|
328
|
+
|
|
329
|
+
|
|
57
330
|
def condense_dirichlet_fluxsparse(A: FluxSparseMatrix, F, dofs, vals):
|
|
58
331
|
"""
|
|
59
332
|
Condense Dirichlet DOFs for a FluxSparseMatrix.
|
|
@@ -76,6 +349,63 @@ def condense_dirichlet_fluxsparse(A: FluxSparseMatrix, F, dofs, vals):
|
|
|
76
349
|
return K_ff, F_free, free, dir_arr, dir_vals_arr
|
|
77
350
|
|
|
78
351
|
|
|
352
|
+
def condense_dirichlet_fluxsparse_coo(
    A: FluxSparseMatrix,
    F,
    dofs,
    vals,
    *,
    coalesce: bool = True,
):
    """
    Condense Dirichlet DOFs for a FluxSparseMatrix using COO filtering.

    Works on the COO triplets directly, with no CSR conversion:
      * entries whose row and column are both free are kept, renumbered to
        the free-DOF ordering, and optionally coalesced,
      * the load is lifted by the (free row, Dirichlet column) entries times
        the prescribed values. This step is skipped when every value is
        close to zero.

    Returns: (K_free, F_free, free_dofs, dir_dofs, dir_vals)
    """
    dir_arr, dir_vals_arr = _normalize_dirichlet(dofs, vals)
    n_total = int(A.n_dofs)

    is_free = np.ones(n_total, dtype=bool)
    is_free[dir_arr] = False
    free = np.nonzero(is_free)[0]

    rows = np.asarray(A.pattern.rows, dtype=np.int64)
    cols = np.asarray(A.pattern.cols, dtype=np.int64)
    data = np.asarray(A.data)

    # Global -> local (free) numbering; -1 marks Dirichlet DOFs.
    local_of = np.full(n_total, -1, dtype=np.int32)
    local_of[free] = np.arange(free.size, dtype=np.int32)
    local_rows = local_of[rows]
    local_cols = local_of[cols]
    keep_ff = (local_rows >= 0) & (local_cols >= 0)

    rows_ff = local_rows[keep_ff]
    cols_ff = local_cols[keep_ff]
    data_ff = data[keep_ff]
    if coalesce:
        rows_ff, cols_ff, data_ff = coalesce_coo(rows_ff, cols_ff, data_ff)
    K_free = FluxSparseMatrix(rows_ff, cols_ff, data_ff, int(free.size))

    F_arr = np.asarray(F, dtype=float)
    F_free = F_arr[free]
    if dir_arr.size > 0 and not np.allclose(dir_vals_arr, 0.0):
        prescribed = np.zeros(n_total, dtype=F_arr.dtype)
        prescribed[dir_arr] = dir_vals_arr
        fd_entries = is_free[rows] & ~is_free[cols]
        if np.any(fd_entries):
            fd_rows = rows[fd_entries]
            fd_contrib = data[fd_entries] * prescribed[cols[fd_entries]]
            lifting = np.zeros(n_total, dtype=F_arr.dtype)
            # np.add.at accumulates repeated row indices (duplicate COO entries).
            np.add.at(lifting, fd_rows, fd_contrib)
            lifting_free = lifting[free]
            F_free = F_free - (lifting_free[:, None] if F_free.ndim == 2 else lifting_free)

    return K_free, jnp.asarray(F_free), free, dir_arr, dir_vals_arr
|
|
407
|
+
|
|
408
|
+
|
|
79
409
|
def free_dofs(n_dofs: int, dir_dofs) -> np.ndarray:
|
|
80
410
|
"""
|
|
81
411
|
Return free DOF indices given total DOFs and Dirichlet DOFs.
|
|
@@ -86,6 +416,29 @@ def free_dofs(n_dofs: int, dir_dofs) -> np.ndarray:
|
|
|
86
416
|
return np.nonzero(mask)[0]
|
|
87
417
|
|
|
88
418
|
|
|
419
|
+
def restrict_flux_to_free(K: FluxSparseMatrix, free: np.ndarray, *, coalesce: bool = True) -> FluxSparseMatrix:
    """
    Restrict a FluxSparseMatrix to free DOFs and return the condensed matrix.

    Only entries whose row and column both appear in ``free`` are kept. Their
    indices are renumbered to positions within ``free``. The result has
    ``len(free)`` DOFs and is coalesced unless ``coalesce=False``.
    """
    free_idx = np.asarray(free, dtype=np.int32)
    n_free = int(free_idx.size)
    # Global -> local map; -1 marks DOFs outside ``free``.
    to_local = np.full(K.n_dofs, -1, dtype=np.int32)
    to_local[free_idx] = np.arange(n_free, dtype=np.int32)

    local_r = to_local[np.asarray(K.pattern.rows)]
    local_c = to_local[np.asarray(K.pattern.cols)]
    entries = np.asarray(K.data)
    keep = (local_r >= 0) & (local_c >= 0)

    restricted = FluxSparseMatrix(
        jnp.asarray(local_r[keep]),
        jnp.asarray(local_c[keep]),
        jnp.asarray(entries[keep]),
        n_free,
    )
    if coalesce:
        return restricted.coalesce()
    return restricted
|
|
440
|
+
|
|
441
|
+
|
|
89
442
|
def condense_dirichlet_dense(K, F, dofs, vals):
|
|
90
443
|
"""
|
|
91
444
|
Eliminate Dirichlet dofs for dense/CSR matrices and return condensed system.
|
fluxfem/solver/history.py
CHANGED
|
@@ -1,7 +1,19 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass, field
|
|
4
|
-
from typing import Any, List, Optional
|
|
4
|
+
from typing import Any, List, Optional, TYPE_CHECKING, TypeAlias
|
|
5
|
+
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from .result import SolverResult
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from jax import Array as JaxArray
|
|
13
|
+
|
|
14
|
+
ArrayLike: TypeAlias = np.ndarray | JaxArray
|
|
15
|
+
else:
|
|
16
|
+
ArrayLike: TypeAlias = np.ndarray
|
|
5
17
|
|
|
6
18
|
|
|
7
19
|
@dataclass
|
|
@@ -23,9 +35,9 @@ class NewtonIterRecord:
|
|
|
23
35
|
@dataclass
|
|
24
36
|
class LoadStepResult:
|
|
25
37
|
load_factor: float
|
|
26
|
-
info:
|
|
38
|
+
info: SolverResult
|
|
27
39
|
solve_time: float
|
|
28
|
-
u:
|
|
40
|
+
u: ArrayLike
|
|
29
41
|
iter_history: List[NewtonIterRecord] = field(default_factory=list)
|
|
30
42
|
exception: Optional[str] = None
|
|
31
43
|
meta: dict[str, Any] = field(default_factory=dict)
|