fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53):
  1. fluxfem/__init__.py +69 -13
  2. fluxfem/core/__init__.py +140 -53
  3. fluxfem/core/assembly.py +691 -97
  4. fluxfem/core/basis.py +75 -54
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/dtypes.py +9 -1
  7. fluxfem/core/forms.py +10 -0
  8. fluxfem/core/mixed_assembly.py +263 -0
  9. fluxfem/core/mixed_space.py +382 -0
  10. fluxfem/core/mixed_weakform.py +97 -0
  11. fluxfem/core/solver.py +2 -0
  12. fluxfem/core/space.py +315 -30
  13. fluxfem/core/weakform.py +821 -42
  14. fluxfem/helpers_wf.py +49 -0
  15. fluxfem/mesh/__init__.py +54 -2
  16. fluxfem/mesh/base.py +318 -9
  17. fluxfem/mesh/contact.py +841 -0
  18. fluxfem/mesh/dtypes.py +12 -0
  19. fluxfem/mesh/hex.py +17 -16
  20. fluxfem/mesh/io.py +9 -6
  21. fluxfem/mesh/mortar.py +3970 -0
  22. fluxfem/mesh/supermesh.py +318 -0
  23. fluxfem/mesh/surface.py +104 -26
  24. fluxfem/mesh/tet.py +16 -7
  25. fluxfem/physics/diffusion.py +3 -0
  26. fluxfem/physics/elasticity/hyperelastic.py +35 -3
  27. fluxfem/physics/elasticity/linear.py +22 -4
  28. fluxfem/physics/elasticity/stress.py +9 -5
  29. fluxfem/physics/operators.py +12 -5
  30. fluxfem/physics/postprocess.py +29 -3
  31. fluxfem/solver/__init__.py +47 -2
  32. fluxfem/solver/bc.py +38 -2
  33. fluxfem/solver/block_matrix.py +284 -0
  34. fluxfem/solver/block_system.py +477 -0
  35. fluxfem/solver/cg.py +150 -55
  36. fluxfem/solver/dirichlet.py +358 -5
  37. fluxfem/solver/history.py +15 -3
  38. fluxfem/solver/newton.py +260 -70
  39. fluxfem/solver/petsc.py +445 -0
  40. fluxfem/solver/preconditioner.py +109 -0
  41. fluxfem/solver/result.py +18 -0
  42. fluxfem/solver/solve_runner.py +208 -23
  43. fluxfem/solver/solver.py +35 -12
  44. fluxfem/solver/sparse.py +149 -15
  45. fluxfem/tools/jit.py +19 -7
  46. fluxfem/tools/timer.py +14 -12
  47. fluxfem/tools/visualizer.py +16 -4
  48. fluxfem-0.2.1.dist-info/METADATA +314 -0
  49. fluxfem-0.2.1.dist-info/RECORD +59 -0
  50. fluxfem-0.1.4.dist-info/METADATA +0 -127
  51. fluxfem-0.1.4.dist-info/RECORD +0 -48
  52. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  53. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/solver/cg.py CHANGED
@@ -1,5 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from typing import Any, Callable, TypeAlias
4
+
3
5
  import jax
4
6
  import jax.numpy as jnp
5
7
  import jax.scipy as jsp
@@ -10,9 +12,16 @@ except Exception: # pragma: no cover
10
12
  jsparse = None
11
13
 
12
14
  from .sparse import FluxSparseMatrix
15
+ from dataclasses import dataclass
16
+
17
+ from .preconditioner import make_block_jacobi_preconditioner
18
+
19
+ ArrayLike: TypeAlias = jnp.ndarray
20
+ MatVec: TypeAlias = Callable[[jnp.ndarray], jnp.ndarray]
21
+ CGInfo: TypeAlias = dict[str, Any]
13
22
 
14
23
 
15
- def _matvec_builder(A):
24
+ def _matvec_builder(A: Any) -> MatVec:
16
25
  if jsparse is not None and isinstance(A, jsparse.BCOO):
17
26
  return lambda x: A @ x
18
27
  if isinstance(A, FluxSparseMatrix):
@@ -30,7 +39,125 @@ def _matvec_builder(A):
30
39
  return mv
31
40
 
32
41
 
33
- def _diag_builder(A, n: int):
def _coo_tuple_from_any(A: Any):
    """
    Convert a supported sparse matrix into a ``(rows, cols, data, n)`` COO tuple.

    Accepted inputs:
      - FluxSparseMatrix: delegated to ``A.to_coo()``.
      - An existing 4-tuple ``(rows, cols, data, n)``: returned unchanged.
      - jax.experimental.sparse.BCOO (unbatched, 2-D): indices/data unpacked.
      - scipy.sparse matrices: converted via ``tocoo()``.

    ``n`` is taken from the row count of the input.

    Returns:
        The COO tuple, or None when ``A`` is not one of the accepted inputs
        (e.g. dense arrays or callables), so callers can pick a fallback.
    """
    if isinstance(A, FluxSparseMatrix):
        return A.to_coo()
    if isinstance(A, tuple) and len(A) == 4:
        return A
    if jsparse is not None and isinstance(A, jsparse.BCOO):
        idx = A.indices
        # Only an unbatched 2-D BCOO without dense dims stores plain (row, col)
        # pairs, i.e. indices of shape (nnz, 2); anything else is not a COO tuple.
        if len(A.shape) == 2 and idx.ndim == 2 and idx.shape[1] == 2:
            return (
                jnp.asarray(idx[:, 0], dtype=jnp.int32),
                jnp.asarray(idx[:, 1], dtype=jnp.int32),
                jnp.asarray(A.data),
                int(A.shape[0]),
            )
        return None
    try:
        import scipy.sparse as sp  # type: ignore
    except Exception:  # pragma: no cover - scipy is optional
        sp = None
    if sp is not None and sp.issparse(A):
        coo = A.tocoo()
        return (
            jnp.asarray(coo.row, dtype=jnp.int32),
            jnp.asarray(coo.col, dtype=jnp.int32),
            jnp.asarray(coo.data),
            int(A.shape[0]),
        )
    return None
60
+
61
+
def _to_flux_matrix(A: Any) -> FluxSparseMatrix:
    """
    Return ``A`` as a FluxSparseMatrix.

    A FluxSparseMatrix is returned as-is. Any other input is converted through
    its COO tuple (see ``_coo_tuple_from_any``) with
    ``FluxSparseMatrix.from_bilinear``.

    Raises:
        ValueError: if ``A`` has no COO representation.
    """
    if not isinstance(A, FluxSparseMatrix):
        triplet = _coo_tuple_from_any(A)
        if triplet is None:
            raise ValueError("Unable to build FluxSparseMatrix from A")
        A = FluxSparseMatrix.from_bilinear(triplet)
    return A
69
+
70
+
def _to_bcoo_matrix(A: Any):
    """
    Return ``A`` as a square ``jax.experimental.sparse.BCOO`` matrix.

    A BCOO input is returned unchanged. Any other input is turned into a COO
    tuple by ``_coo_tuple_from_any`` and assembled with shape ``(n, n)``.

    Raises:
        ImportError: if jax.experimental.sparse is unavailable.
        ValueError: if ``A`` has no COO representation.
    """
    if jsparse is None:
        raise ImportError("jax.experimental.sparse is required for BCOO matvec")
    if isinstance(A, jsparse.BCOO):
        return A
    triplet = _coo_tuple_from_any(A)
    if triplet is None:
        raise ValueError("Unable to build BCOO from A")
    row_idx, col_idx, values, size = triplet
    # BCOO expects an (nnz, 2) index array of (row, col) pairs.
    index_pairs = jnp.stack([row_idx, col_idx], axis=-1)
    return jsparse.BCOO((values, index_pairs), shape=(size, size))
82
+
83
+
def _normalize_matvec_matrix(A: Any, matvec: str):
    """
    Convert ``A`` into the representation used by the requested matvec backend.

    Args:
        A: system matrix (FluxSparseMatrix, BCOO, COO tuple, scipy.sparse
            matrix or dense array).
        matvec: backend name:
            - "flux":  FluxSparseMatrix
            - "bcoo":  jax.experimental.sparse.BCOO
            - "dense": jnp array
            - "auto":  try BCOO (when available), then FluxSparseMatrix, and
              finally fall back to a dense jnp array for 2-D array inputs.

    Raises:
        ValueError: for an unknown backend name, or when no backend accepts ``A``.
    """
    if matvec == "flux":
        return _to_flux_matrix(A)
    if matvec == "bcoo":
        return _to_bcoo_matrix(A)
    if matvec == "dense":
        return jnp.asarray(A)
    if matvec == "auto":
        if jsparse is not None:
            try:
                return _to_bcoo_matrix(A)
            except Exception:
                # Best effort: fall through to the FluxSparseMatrix path.
                pass
        try:
            return _to_flux_matrix(A)
        except ValueError:
            # Dense arrays have no COO form in _coo_tuple_from_any; use them as
            # dense operators instead of rejecting them outright.
            if getattr(A, "ndim", None) == 2:
                return jnp.asarray(A)
            raise
    raise ValueError(f"Unknown matvec backend: {matvec}")
99
+
100
+
101
+ @dataclass(frozen=True)
102
+ class CGOperator:
103
+ """
104
+ Lightweight CG operator wrapper with a consistent solve() entry point.
105
+ """
106
+ A: object
107
+ preconditioner: object | None = None
108
+ solver: str = "cg"
109
+
110
+ def solve(
111
+ self,
112
+ b: jnp.ndarray,
113
+ *,
114
+ x0: jnp.ndarray | None = None,
115
+ tol: float = 1e-8,
116
+ maxiter: int | None = None,
117
+ ):
118
+ if self.solver == "cg":
119
+ return cg_solve(
120
+ self.A,
121
+ b,
122
+ x0=x0,
123
+ tol=tol,
124
+ maxiter=maxiter,
125
+ preconditioner=self.preconditioner,
126
+ )
127
+ if self.solver == "cg_jax":
128
+ return cg_solve_jax(
129
+ self.A,
130
+ b,
131
+ x0=x0,
132
+ tol=tol,
133
+ maxiter=maxiter,
134
+ preconditioner=self.preconditioner,
135
+ )
136
+ raise ValueError(f"Unknown CG solver: {self.solver}")
137
+
138
+
def build_cg_operator(
    A: Any,
    *,
    matvec: str = "flux",
    preconditioner: object | None = None,
    solver: str = "cg",
    dof_per_node: int | None = None,
    block_sizes: object | None = None,
) -> CGOperator:
    """
    Normalize CG inputs into a single operator interface.

    Args:
        A: system matrix; converted with the ``matvec`` backend
            ("flux", "bcoo", "dense" or "auto").
        matvec: backend used to represent ``A``.
        preconditioner: passed through unchanged, except for the string
            "block_jacobi", which is built eagerly from the converted matrix.
        solver: CGOperator backend name ("cg" or "cg_jax").
        dof_per_node, block_sizes: forwarded to
            ``make_block_jacobi_preconditioner`` for "block_jacobi".

    Returns:
        A CGOperator holding the converted matrix and preconditioner.
    """
    A_mat = _normalize_matvec_matrix(A, matvec)
    precon = preconditioner
    # Only compare strings: `==` against an array-valued preconditioner would be
    # element-wise and make this `if` ambiguous (or raise).
    if isinstance(preconditioner, str) and preconditioner == "block_jacobi":
        precon = make_block_jacobi_preconditioner(
            A_mat, dof_per_node=dof_per_node, block_sizes=block_sizes
        )
    return CGOperator(A=A_mat, preconditioner=precon, solver=solver)
158
+
159
+
160
+ def _diag_builder(A: Any, n: int) -> jnp.ndarray:
34
161
  """
35
162
  Build diagonal for a Jacobi preconditioner when available.
36
163
  """
@@ -57,14 +184,14 @@ def _diag_builder(A, n: int):
57
184
 
58
185
 
59
186
  def _cg_solve_single(
60
- A,
61
- b,
187
+ A: Any,
188
+ b: jnp.ndarray,
62
189
  *,
63
- x0=None,
190
+ x0: jnp.ndarray | None = None,
64
191
  tol: float = 1e-8,
65
192
  maxiter: int | None = None,
66
- preconditioner=None,
67
- ):
193
+ preconditioner: object | None = None,
194
+ ) -> tuple[jnp.ndarray, CGInfo]:
68
195
  """
69
196
  Conjugate gradient (Ax=b) in JAX.
70
197
  A: FluxSparseMatrix / (rows, cols, data, n) / dense array
@@ -92,39 +219,7 @@ def _cg_solve_single(
92
219
  return inv_diag * r
93
220
 
94
221
  elif preconditioner == "block_jacobi":
95
- # Expect 3 DOFs per node
96
- if n % 3 != 0:
97
- raise ValueError("block_jacobi requires n_dofs % 3 == 0")
98
- if jsparse is not None and isinstance(A, jsparse.BCOO):
99
- rows = A.indices[:, 0]
100
- cols = A.indices[:, 1]
101
- data = A.data
102
- elif isinstance(A, FluxSparseMatrix):
103
- rows = jnp.asarray(A.pattern.rows)
104
- cols = jnp.asarray(A.pattern.cols)
105
- data = jnp.asarray(A.data)
106
- else:
107
- raise ValueError("block_jacobi requires FluxSparseMatrix or BCOO")
108
-
109
- block_rows = rows // 3
110
- block_cols = cols // 3
111
- lr = rows % 3
112
- lc = cols % 3
113
- mask = block_rows == block_cols
114
- block_rows = block_rows[mask]
115
- lr = lr[mask]
116
- lc = lc[mask]
117
- data = data[mask]
118
- n_block = n // 3
119
- blocks = jnp.zeros((n_block, 3, 3), dtype=data.dtype)
120
- blocks = blocks.at[block_rows, lr, lc].add(data)
121
- blocks = blocks + 1e-12 * jnp.eye(3)[None, :, :]
122
- inv_blocks = jnp.linalg.inv(blocks)
123
-
124
- def precon(r):
125
- rb = r.reshape((n_block, 3))
126
- zb = jnp.einsum("bij,bj->bi", inv_blocks, rb)
127
- return zb.reshape((-1,))
222
+ precon = make_block_jacobi_preconditioner(A)
128
223
 
129
224
  elif callable(preconditioner):
130
225
  precon = preconditioner
@@ -159,14 +254,14 @@ def _cg_solve_single(
159
254
 
160
255
 
161
256
  def cg_solve(
162
- A,
163
- b,
257
+ A: Any,
258
+ b: jnp.ndarray,
164
259
  *,
165
- x0=None,
260
+ x0: jnp.ndarray | None = None,
166
261
  tol: float = 1e-8,
167
262
  maxiter: int | None = None,
168
- preconditioner=None,
169
- ):
263
+ preconditioner: object | None = None,
264
+ ) -> tuple[jnp.ndarray, CGInfo]:
170
265
  """
171
266
  Conjugate gradient (Ax=b) in JAX.
172
267
  Supports single RHS (n,) or multiple RHS (n, n_rhs).
@@ -208,14 +303,14 @@ def cg_solve(
208
303
 
209
304
 
210
305
  def _cg_solve_jax_single(
211
- A,
212
- b,
306
+ A: Any,
307
+ b: jnp.ndarray,
213
308
  *,
214
- x0=None,
309
+ x0: jnp.ndarray | None = None,
215
310
  tol: float = 1e-8,
216
311
  maxiter: int | None = None,
217
- preconditioner=None,
218
- ):
312
+ preconditioner: object | None = None,
313
+ ) -> tuple[jnp.ndarray, CGInfo]:
219
314
  """
220
315
  Conjugate gradient via jax.scipy.sparse.linalg.cg.
221
316
  A: FluxSparseMatrix / (rows, cols, data, n) / dense array / callable
@@ -277,14 +372,14 @@ def _cg_solve_jax_single(
277
372
 
278
373
 
279
374
  def cg_solve_jax(
280
- A,
281
- b,
375
+ A: Any,
376
+ b: jnp.ndarray,
282
377
  *,
283
- x0=None,
378
+ x0: jnp.ndarray | None = None,
284
379
  tol: float = 1e-8,
285
380
  maxiter: int | None = None,
286
- preconditioner=None,
287
- ):
381
+ preconditioner: object | None = None,
382
+ ) -> tuple[jnp.ndarray, CGInfo]:
288
383
  """
289
384
  Conjugate gradient via jax.scipy.sparse.linalg.cg.
290
385
  Supports single RHS (n,) or multiple RHS (n, n_rhs).
@@ -1,16 +1,155 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from dataclasses import dataclass
4
+ from typing import Any, TYPE_CHECKING, TypeAlias
5
+
3
6
  import numpy as np
4
7
  import jax.numpy as jnp
5
8
 
6
- from .sparse import FluxSparseMatrix
9
+ try:
10
+ import scipy.sparse as sp
11
+ except Exception: # pragma: no cover
12
+ sp = None
7
13
 
14
+ from .sparse import FluxSparseMatrix, coalesce_coo
8
15
 
9
- def _normalize_dirichlet(dofs, vals):
10
- dir_arr = np.asarray(dofs, dtype=int)
16
+ if TYPE_CHECKING:
17
+ from jax import Array as JaxArray
18
+
19
+ ArrayLike: TypeAlias = np.ndarray | JaxArray
20
+ else:
21
+ ArrayLike: TypeAlias = np.ndarray
22
+ DirichletLike: TypeAlias = tuple[np.ndarray, np.ndarray]
23
+
24
+
25
+ def _normalize_dirichlet_values(dofs: ArrayLike, vals: ArrayLike | None) -> np.ndarray:
11
26
  if vals is None:
12
- return dir_arr, np.zeros(dir_arr.shape[0], dtype=float)
13
- return dir_arr, np.asarray(vals, dtype=float)
27
+ return np.zeros(np.asarray(dofs).shape[0], dtype=float)
28
+ arr = np.asarray(vals)
29
+ if arr.ndim == 0:
30
+ return np.full(np.asarray(dofs).shape[0], float(arr), dtype=float)
31
+ return arr
32
+
33
+
def _normalize_dirichlet(dofs: ArrayLike, vals: ArrayLike | None) -> DirichletLike:
    """
    Coerce Dirichlet DOFs to an integer array and pair them with their values.

    Values are normalized by ``_normalize_dirichlet_values``.
    """
    dof_idx = np.asarray(dofs, dtype=int)
    dof_vals = _normalize_dirichlet_values(dof_idx, vals)
    return dof_idx, dof_vals
37
+
38
+
@dataclass(frozen=True)
class CondensedSystem:
    """
    Result of eliminating Dirichlet DOFs from a linear system.

    Attributes:
        K: free-free block of the system matrix.
        F: right-hand side restricted to free DOFs, already lifted by the
            prescribed Dirichlet values. May be (n_free,) or (n_free, n_rhs).
        free_dofs: global indices of the free DOFs.
        dir_dofs: global indices of the Dirichlet DOFs.
        dir_vals: prescribed values, one per Dirichlet DOF.
        n_dofs: total number of DOFs in the uncondensed system.
    """

    K: Any
    F: Any
    free_dofs: np.ndarray
    dir_dofs: np.ndarray
    dir_vals: np.ndarray
    n_dofs: int

    def expand(self, u_free: ArrayLike, *, fill_dirichlet: bool = True) -> np.ndarray:
        """
        Scatter a free-DOF solution back into a full-length vector.

        Args:
            u_free: solution on the free DOFs, shape (n_free,) or
                (n_free, n_rhs) for multiple right-hand sides.
            fill_dirichlet: write ``dir_vals`` into the Dirichlet DOFs
                (otherwise they stay zero).

        Returns:
            Array of shape (n_dofs,) or (n_dofs, n_rhs) with the dtype of ``u_free``.
        """
        u_arr = np.asarray(u_free)
        # Keep any trailing RHS axis so multi-RHS solutions expand as well.
        u_full = np.zeros((self.n_dofs,) + u_arr.shape[1:], dtype=u_arr.dtype)
        u_full[self.free_dofs] = u_arr
        if fill_dirichlet and self.dir_dofs.size:
            dir_vals = np.asarray(self.dir_vals, dtype=u_full.dtype)
            if u_full.ndim > 1 and dir_vals.ndim == 1:
                # Same prescribed value for every RHS column: broadcast along
                # the trailing axes instead of NumPy's right-aligned broadcast.
                dir_vals = dir_vals.reshape((-1,) + (1,) * (u_full.ndim - 1))
            u_full[self.dir_dofs] = dir_vals
        return u_full
54
+
55
+
56
+
57
+
@dataclass(frozen=True)
class DirichletBC:
    """
    Dirichlet boundary condition container with helper methods.

    ``__post_init__`` normalizes ``dofs`` to an integer index array and
    ``vals`` to one value per DOF (``None`` -> zeros, scalar -> broadcast).
    ``vals`` defaults to ``None``, so ``DirichletBC(dofs)`` describes a
    homogeneous (zero-valued) condition.
    """

    dofs: np.ndarray
    vals: np.ndarray | None = None

    def __post_init__(self):
        dofs, vals = _normalize_dirichlet(self.dofs, self.vals)
        # The dataclass is frozen, so plain attribute assignment is blocked.
        object.__setattr__(self, "dofs", dofs)
        object.__setattr__(self, "vals", vals)

    @classmethod
    def from_boundary_dofs(
        cls, mesh, predicate, *, values: ArrayLike | None = None, **kwargs
    ) -> "DirichletBC":
        """
        Build from mesh.boundary_dofs_where predicate.

        kwargs are forwarded to mesh.boundary_dofs_where (e.g. components=..., dof_per_node=...).
        ``values`` may be None (zeros), a scalar, or one value per selected DOF.
        """
        dofs = mesh.boundary_dofs_where(predicate, **kwargs)
        vals = _normalize_dirichlet_values(dofs, values)
        return cls(dofs, vals)

    @classmethod
    def from_bbox(
        cls,
        mesh,
        *,
        mins: ArrayLike | None = None,
        maxs: ArrayLike | None = None,
        tol: float = 1e-8,
        values: ArrayLike | None = None,
        **kwargs,
    ) -> "DirichletBC":
        """
        Build from the mesh axis-aligned bounding box.

        mins/maxs default to mesh coordinate extrema. kwargs are forwarded to
        mesh.boundary_dofs_where (e.g. components=..., dof_per_node=...).
        """
        from ..mesh.predicate import bbox_predicate

        coords = np.asarray(mesh.coords)
        if mins is None:
            mins = coords.min(axis=0)
        if maxs is None:
            maxs = coords.max(axis=0)
        pred = bbox_predicate(mins, maxs, tol=tol)
        dofs = mesh.boundary_dofs_where(pred, **kwargs)
        vals = _normalize_dirichlet_values(dofs, values)
        return cls(dofs, vals)

    def as_tuple(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the ``(dofs, vals)`` pair."""
        return self.dofs, self.vals

    def condense_system(self, A: Any, F: ArrayLike, *, check: bool = True) -> CondensedSystem:
        """Eliminate these DOFs from ``A x = F`` (see condense_dirichlet_system)."""
        return condense_dirichlet_system(A, F, self.dofs, self.vals, check=check)

    def enforce_system(self, A: Any, F: ArrayLike):
        """Impose these values on ``A``/``F`` in place of elimination (see enforce_dirichlet_system)."""
        return enforce_dirichlet_system(A, F, self.dofs, self.vals)

    def condense_flux(self, A: FluxSparseMatrix, F: ArrayLike):
        """
        Condense for FluxSparseMatrix and return (K_free, F_free, free_dofs).
        """
        condensed = self.condense_system(A, F)
        free = condensed.free_dofs
        # Keep the matrix as a FluxSparseMatrix rather than the CSR block
        # produced by condense_system.
        return restrict_flux_to_free(A, free), condensed.F, free

    def enforce_flux(self, A: FluxSparseMatrix, F: ArrayLike):
        """Apply these conditions to a FluxSparseMatrix system."""
        return enforce_dirichlet_fluxsparse(A, F, self.dofs, self.vals)

    def split_matrix(self, A: Any, *, n_total: int | None = None):
        """Return (free, dir_dofs, A_ff, A_fd) (see split_dirichlet_matrix)."""
        return split_dirichlet_matrix(A, self.dofs, n_total=n_total)

    def free_dofs(self, n_dofs: int) -> np.ndarray:
        """Return the DOF indices in ``range(n_dofs)`` not constrained here."""
        # Name lookup inside a method skips the class body, so this calls the
        # module-level free_dofs function, not this method.
        return free_dofs(n_dofs, self.dofs)

    def expand_solution(
        self,
        u_free: ArrayLike,
        *,
        free: np.ndarray | None = None,
        n_total: int | None = None,
    ) -> np.ndarray:
        """
        Scatter a free-DOF solution back into a full vector with Dirichlet values.

        Either ``free`` or ``n_total`` must be given. When only ``free`` is
        given, the size is inferred as one past the largest free or Dirichlet index.

        Raises:
            ValueError: if neither ``free`` nor ``n_total`` is provided.
        """
        if free is None:
            if n_total is None:
                raise ValueError("n_total is required when free is not provided.")
            free = free_dofs(n_total, self.dofs)
        if n_total is None:
            max_free = int(np.max(free)) if len(free) else -1
            max_dir = int(np.max(self.dofs)) if len(self.dofs) else -1
            n_total = max(max_free, max_dir) + 1
        return expand_dirichlet_solution(u_free, free, self.dofs, self.vals, n_total=n_total)
14
153
 
15
154
 
16
155
  def enforce_dirichlet_dense(K, F, dofs, vals):
@@ -33,6 +172,135 @@ def enforce_dirichlet_dense(K, F, dofs, vals):
33
172
  return Kc, Fc
34
173
 
35
174
 
def enforce_dirichlet_dense_jax(K, F, dofs, vals):
    """
    Apply Dirichlet conditions directly to a dense JAX stiffness/load pair.

    The prescribed values are first moved to the right-hand side
    (``F -= K[:, dofs] @ vals``). The constrained rows and columns of ``K``
    are then zeroed with a unit diagonal, and ``F[dofs]`` is set to the
    prescribed values. ``F`` may be (n,) or (n, n_rhs); in the 2-D case the
    same values apply to every column.

    Returns:
        (K_mod, F_mod) as new arrays (functional ``.at`` updates).
    """
    idx_np, vals_np = _normalize_dirichlet(dofs, vals)
    idx = jnp.asarray(idx_np, dtype=jnp.int32)
    bc_vals = jnp.asarray(vals_np, dtype=F.dtype)
    multi_rhs = F.ndim == 2

    lift = K[:, idx] @ bc_vals
    F_mod = F - (lift[:, None] if multi_rhs else lift)

    K_mod = K.at[:, idx].set(0.0).at[idx, :].set(0.0).at[idx, idx].set(1.0)

    pinned = bc_vals[:, None] if multi_rhs else bc_vals
    F_mod = F_mod.at[idx].set(pinned)
    return K_mod, F_mod
194
+
195
+
def enforce_dirichlet_system(A, F, dofs, vals):
    """
    Apply Dirichlet conditions directly to stiffness/load.

    Dispatches on the matrix type:
      - FluxSparseMatrix -> enforce_dirichlet_sparse
      - JAX array        -> enforce_dirichlet_dense_jax
      - anything else    -> enforce_dirichlet_dense

    Errors raised by the selected backend propagate to the caller. They are
    no longer swallowed with a silent fallback to the NumPy path.
    """
    if isinstance(A, FluxSparseMatrix):
        return enforce_dirichlet_sparse(A, F, dofs, vals)
    # jnp is a module-level import, so no import guard is needed here.
    if isinstance(A, jnp.ndarray):
        return enforce_dirichlet_dense_jax(A, F, dofs, vals)
    return enforce_dirichlet_dense(A, F, dofs, vals)
211
+
212
+
def split_dirichlet_matrix(A, dir_dofs, *, n_total: int | None = None):
    """
    Split a matrix into free-free and free-dirichlet blocks.

    ``A`` may be a FluxSparseMatrix (converted to CSR), a scipy.sparse matrix,
    a JAX array, or anything NumPy can turn into a 2-D array. When
    ``n_total`` is omitted it comes from ``A.n_dofs`` or, failing that, from
    the (square) shape of ``A``.

    Returns (free, dir_dofs, A_ff, A_fd).

    Raises:
        ValueError: if ``n_total`` must be inferred and ``A`` is not square.
        ImportError: if ``A`` is a FluxSparseMatrix and scipy is unavailable.
    """
    dir_dofs, _ = _normalize_dirichlet(dir_dofs, None)
    if n_total is None:
        if hasattr(A, "n_dofs"):
            n_total = int(A.n_dofs)
        else:
            # Read .shape directly: np.asarray does not densify scipy.sparse
            # matrices (it would yield a non-2-D object array), and it would
            # force a host copy of device arrays.
            shape = getattr(A, "shape", None)
            if shape is None:
                shape = np.asarray(A).shape
            if len(shape) != 2 or shape[0] != shape[1]:
                raise ValueError("A must be square when n_total is not provided.")
            n_total = int(shape[0])
    free = free_dofs(n_total, dir_dofs)

    if isinstance(A, FluxSparseMatrix):
        if sp is None:
            raise ImportError("scipy is required to split FluxSparseMatrix.")
        A = A.to_csr()

    if sp is not None and sp.issparse(A):
        free_rows = A[free]
        return free, dir_dofs, free_rows[:, free], free_rows[:, dir_dofs]

    if isinstance(A, jnp.ndarray):
        free_j = jnp.asarray(free, dtype=jnp.int32)
        dir_j = jnp.asarray(dir_dofs, dtype=jnp.int32)
        return free, dir_dofs, A[jnp.ix_(free_j, free_j)], A[jnp.ix_(free_j, dir_j)]

    arr = np.asarray(A)
    return free, dir_dofs, arr[np.ix_(free, free)], arr[np.ix_(free, dir_dofs)]
245
+
246
+
def condense_dirichlet_system(A, F, dofs, vals, *, check: bool = True) -> CondensedSystem:
    """
    Condense Dirichlet DOFs and return a structured system.

    The free-free block ``K_ff`` is extracted. The prescribed values are
    moved to the right-hand side (``F_free -= K_fd @ dir_vals``), applied to
    every column when ``F`` is 2-D.

    Args:
        A: FluxSparseMatrix, scipy.sparse matrix, object with ``to_csr()``,
            or a dense square array.
        F: load vector (n,) or (n, n_rhs).
        dofs, vals: Dirichlet DOFs and values (vals None -> zeros, scalar -> broadcast).
        check: validate lengths, bounds and uniqueness of the Dirichlet DOFs.

    Raises:
        ValueError: if ``A`` is not square (when it has no ``n_dofs``) or a
            ``check`` validation fails.
    """
    dir_arr, dir_vals_arr = _normalize_dirichlet(dofs, vals)
    F_arr = np.asarray(F)
    if hasattr(A, "n_dofs"):
        n_total = int(A.n_dofs)
    else:
        # Read .shape directly: np.asarray does not densify scipy.sparse
        # matrices, which previously made the sparse branch below unreachable.
        shape = getattr(A, "shape", None)
        if shape is None:
            shape = np.asarray(A).shape
        if len(shape) != 2 or shape[0] != shape[1]:
            raise ValueError("A must be square for Dirichlet condensation.")
        n_total = int(shape[0])

    if check:
        if dir_arr.size != dir_vals_arr.size:
            raise ValueError("dir_dofs and dir_vals must have the same length")
        if dir_arr.size:
            if np.min(dir_arr) < 0 or np.max(dir_arr) >= n_total:
                raise ValueError("dir_dofs out of bounds")
            if np.unique(dir_arr).size != dir_arr.size:
                raise ValueError("dir_dofs contains duplicates")

    mask = np.ones(n_total, dtype=bool)
    mask[dir_arr] = False
    free = np.nonzero(mask)[0]

    if isinstance(A, FluxSparseMatrix):
        K_csr = A.to_csr()
    elif sp is not None and sp.issparse(A):
        K_csr = A.tocsr()
    elif hasattr(A, "to_csr"):
        K_csr = A.to_csr()
    else:
        K_csr = np.asarray(A)

    # Slice the free rows once and reuse them for both column blocks.
    K_free_rows = K_csr[free]
    K_ff = K_free_rows[:, free]
    F_free = F_arr[free]
    if dir_arr.size:
        K_fd = K_free_rows[:, dir_arr]
        lift = K_fd @ dir_vals_arr
        F_free = F_free - (lift[:, None] if F_free.ndim == 2 else lift)

    return CondensedSystem(
        K=K_ff,
        F=F_free,
        free_dofs=free,
        dir_dofs=dir_arr,
        dir_vals=dir_vals_arr,
        n_dofs=n_total,
    )
300
+
301
+
302
+
303
+
36
304
  def enforce_dirichlet_sparse(A: FluxSparseMatrix, F, dofs, vals):
37
305
  """Apply Dirichlet conditions to FluxSparseMatrix + load (CSR)."""
38
306
  K_csr = A.to_csr().tolil()
@@ -54,6 +322,11 @@ def enforce_dirichlet_sparse(A: FluxSparseMatrix, F, dofs, vals):
54
322
  return K_csr.tocsr(), Fc
55
323
 
56
324
 
def enforce_dirichlet_fluxsparse(A: FluxSparseMatrix, F, dofs, vals):
    """
    Apply Dirichlet conditions to a FluxSparseMatrix system.

    Alias of ``enforce_dirichlet_sparse``: the arguments are forwarded
    unchanged and its result is returned as-is.
    """
    return enforce_dirichlet_sparse(A, F, dofs=dofs, vals=vals)
328
+
329
+
57
330
  def condense_dirichlet_fluxsparse(A: FluxSparseMatrix, F, dofs, vals):
58
331
  """
59
332
  Condense Dirichlet DOFs for a FluxSparseMatrix.
@@ -76,6 +349,63 @@ def condense_dirichlet_fluxsparse(A: FluxSparseMatrix, F, dofs, vals):
76
349
  return K_ff, F_free, free, dir_arr, dir_vals_arr
77
350
 
78
351
 
def condense_dirichlet_fluxsparse_coo(
    A: FluxSparseMatrix,
    F,
    dofs,
    vals,
    *,
    coalesce: bool = True,
):
    """
    Condense Dirichlet DOFs for a FluxSparseMatrix using COO filtering.

    Entries whose row and column are both free are renumbered into the
    reduced index space and optionally coalesced. Entries with a free row and
    a Dirichlet column move the prescribed values to the right-hand side. This
    lifting step is skipped when every prescribed value is (close to) zero.
    ``F`` is cast to float.

    Returns: (K_free, F_free, free_dofs, dir_dofs, dir_vals), with F_free as
    a JAX array.
    """
    dir_arr, dir_vals_arr = _normalize_dirichlet(dofs, vals)
    n_total = int(A.n_dofs)
    is_free = np.ones(n_total, dtype=bool)
    is_free[dir_arr] = False
    free = np.nonzero(is_free)[0]

    row_g = np.asarray(A.pattern.rows, dtype=np.int64)
    col_g = np.asarray(A.pattern.cols, dtype=np.int64)
    vals_g = np.asarray(A.data)

    # Global -> reduced numbering; Dirichlet DOFs map to -1.
    to_local = np.full(n_total, -1, dtype=np.int32)
    to_local[free] = np.arange(free.size, dtype=np.int32)
    row_l = to_local[row_g]
    col_l = to_local[col_g]
    both_free = (row_l >= 0) & (col_l >= 0)

    K_rows = row_l[both_free]
    K_cols = col_l[both_free]
    K_data = vals_g[both_free]
    if coalesce:
        K_rows, K_cols, K_data = coalesce_coo(K_rows, K_cols, K_data)
    K_free = FluxSparseMatrix(K_rows, K_cols, K_data, int(free.size))

    F_arr = np.asarray(F, dtype=float)
    F_free = F_arr[free]
    needs_lift = dir_arr.size > 0 and not np.allclose(dir_vals_arr, 0.0)
    if needs_lift:
        prescribed = np.zeros(n_total, dtype=F_arr.dtype)
        prescribed[dir_arr] = dir_vals_arr
        # Couplings from free rows into Dirichlet columns.
        free_to_dir = is_free[row_g] & (~is_free[col_g])
        if np.any(free_to_dir):
            coupling_rows = row_g[free_to_dir]
            coupling_cols = col_g[free_to_dir]
            coupling_vals = vals_g[free_to_dir]
            lift = np.zeros(n_total, dtype=F_arr.dtype)
            # np.add.at accumulates repeated row indices (duplicate COO entries).
            np.add.at(lift, coupling_rows, coupling_vals * prescribed[coupling_cols])
            lift_free = lift[free]
            if F_free.ndim == 2:
                F_free = F_free - lift_free[:, None]
            else:
                F_free = F_free - lift_free

    return K_free, jnp.asarray(F_free), free, dir_arr, dir_vals_arr
407
+
408
+
79
409
  def free_dofs(n_dofs: int, dir_dofs) -> np.ndarray:
80
410
  """
81
411
  Return free DOF indices given total DOFs and Dirichlet DOFs.
@@ -86,6 +416,29 @@ def free_dofs(n_dofs: int, dir_dofs) -> np.ndarray:
86
416
  return np.nonzero(mask)[0]
87
417
 
88
418
 
def restrict_flux_to_free(K: FluxSparseMatrix, free: np.ndarray, *, coalesce: bool = True) -> FluxSparseMatrix:
    """
    Restrict a FluxSparseMatrix to the given free DOFs.

    Entries whose row or column is not in ``free`` are dropped. The remaining
    indices are renumbered to their positions within ``free``. The result is
    coalesced unless ``coalesce=False``.
    """
    free_idx = np.asarray(free, dtype=np.int32)
    n_free = free_idx.size
    # Global -> local lookup table; -1 marks DOFs outside ``free``.
    local_of = np.full(K.n_dofs, -1, dtype=np.int32)
    local_of[free_idx] = np.arange(n_free, dtype=np.int32)

    row_l = local_of[np.asarray(K.pattern.rows)]
    col_l = local_of[np.asarray(K.pattern.cols)]
    data = np.asarray(K.data)
    keep = (row_l >= 0) & (col_l >= 0)

    restricted = FluxSparseMatrix(
        jnp.asarray(row_l[keep]),
        jnp.asarray(col_l[keep]),
        jnp.asarray(data[keep]),
        int(n_free),
    )
    if coalesce:
        return restricted.coalesce()
    return restricted
440
+
441
+
89
442
  def condense_dirichlet_dense(K, F, dofs, vals):
90
443
  """
91
444
  Eliminate Dirichlet dofs for dense/CSR matrices and return condensed system.
fluxfem/solver/history.py CHANGED
@@ -1,7 +1,19 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass, field
4
- from typing import Any, List, Optional
4
+ from typing import Any, List, Optional, TYPE_CHECKING, TypeAlias
5
+
6
+ import jax.numpy as jnp
7
+ import numpy as np
8
+
9
+ from .result import SolverResult
10
+
11
+ if TYPE_CHECKING:
12
+ from jax import Array as JaxArray
13
+
14
+ ArrayLike: TypeAlias = np.ndarray | JaxArray
15
+ else:
16
+ ArrayLike: TypeAlias = np.ndarray
5
17
 
6
18
 
7
19
  @dataclass
@@ -23,9 +35,9 @@ class NewtonIterRecord:
23
35
  @dataclass
24
36
  class LoadStepResult:
25
37
  load_factor: float
26
- info: Any
38
+ info: SolverResult
27
39
  solve_time: float
28
- u: Any
40
+ u: ArrayLike
29
41
  iter_history: List[NewtonIterRecord] = field(default_factory=list)
30
42
  exception: Optional[str] = None
31
43
  meta: dict[str, Any] = field(default_factory=dict)