fluxfem 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. fluxfem/__init__.py +68 -0
  2. fluxfem/core/__init__.py +115 -10
  3. fluxfem/core/assembly.py +676 -91
  4. fluxfem/core/basis.py +73 -52
  5. fluxfem/core/dtypes.py +9 -1
  6. fluxfem/core/forms.py +10 -0
  7. fluxfem/core/mixed_assembly.py +263 -0
  8. fluxfem/core/mixed_space.py +348 -0
  9. fluxfem/core/mixed_weakform.py +97 -0
  10. fluxfem/core/solver.py +2 -0
  11. fluxfem/core/space.py +262 -17
  12. fluxfem/core/weakform.py +768 -7
  13. fluxfem/helpers_wf.py +49 -0
  14. fluxfem/mesh/__init__.py +54 -2
  15. fluxfem/mesh/base.py +316 -7
  16. fluxfem/mesh/contact.py +825 -0
  17. fluxfem/mesh/dtypes.py +12 -0
  18. fluxfem/mesh/hex.py +17 -16
  19. fluxfem/mesh/io.py +6 -4
  20. fluxfem/mesh/mortar.py +3907 -0
  21. fluxfem/mesh/supermesh.py +316 -0
  22. fluxfem/mesh/surface.py +22 -4
  23. fluxfem/mesh/tet.py +10 -4
  24. fluxfem/physics/diffusion.py +3 -0
  25. fluxfem/physics/elasticity/hyperelastic.py +3 -0
  26. fluxfem/physics/elasticity/linear.py +9 -2
  27. fluxfem/solver/__init__.py +42 -2
  28. fluxfem/solver/bc.py +38 -2
  29. fluxfem/solver/block_matrix.py +132 -0
  30. fluxfem/solver/block_system.py +454 -0
  31. fluxfem/solver/cg.py +115 -33
  32. fluxfem/solver/dirichlet.py +334 -4
  33. fluxfem/solver/newton.py +237 -60
  34. fluxfem/solver/petsc.py +439 -0
  35. fluxfem/solver/preconditioner.py +106 -0
  36. fluxfem/solver/result.py +18 -0
  37. fluxfem/solver/solve_runner.py +168 -1
  38. fluxfem/solver/solver.py +12 -1
  39. fluxfem/solver/sparse.py +124 -9
  40. fluxfem-0.2.0.dist-info/METADATA +303 -0
  41. fluxfem-0.2.0.dist-info/RECORD +59 -0
  42. fluxfem-0.1.4.dist-info/METADATA +0 -127
  43. fluxfem-0.1.4.dist-info/RECORD +0 -48
  44. {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/LICENSE +0 -0
  45. {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/WHEEL +0 -0
fluxfem/solver/cg.py CHANGED
@@ -10,6 +10,9 @@ except Exception: # pragma: no cover
10
10
  jsparse = None
11
11
 
12
12
  from .sparse import FluxSparseMatrix
13
+ from dataclasses import dataclass
14
+
15
+ from .preconditioner import make_block_jacobi_preconditioner
13
16
 
14
17
 
15
18
  def _matvec_builder(A):
@@ -30,6 +33,117 @@ def _matvec_builder(A):
30
33
  return mv
31
34
 
32
35
 
36
def _coo_tuple_from_any(A):
    """
    Best-effort conversion of ``A`` into a ``(rows, cols, data, n)`` COO tuple.

    Accepted inputs:
    - FluxSparseMatrix -> ``A.to_coo()``
    - a 4-tuple        -> returned unchanged (contents are not validated)
    - SciPy sparse     -> rows/cols as int32 JAX arrays, data as a JAX array,
                          ``n`` taken from ``A.shape[0]``

    Returns ``None`` for anything else (including when SciPy is unavailable),
    leaving the failure policy to the caller.
    """
    if isinstance(A, FluxSparseMatrix):
        return A.to_coo()
    if isinstance(A, tuple) and len(A) == 4:
        return A
    try:
        import scipy.sparse as sp  # type: ignore
    except Exception:  # pragma: no cover
        return None
    if not sp.issparse(A):
        return None
    coo = A.tocoo()
    rows = jnp.asarray(coo.row, dtype=jnp.int32)
    cols = jnp.asarray(coo.col, dtype=jnp.int32)
    values = jnp.asarray(coo.data)
    return rows, cols, values, int(A.shape[0])
54
+
55
+
56
def _to_flux_matrix(A):
    """
    Return ``A`` as a FluxSparseMatrix, converting through COO when needed.

    Raises
    ------
    ValueError
        If ``A`` is neither a FluxSparseMatrix nor convertible to a COO tuple.
    """
    if not isinstance(A, FluxSparseMatrix):
        coo = _coo_tuple_from_any(A)
        if coo is None:
            raise ValueError("Unable to build FluxSparseMatrix from A")
        A = FluxSparseMatrix.from_bilinear(coo)
    return A
63
+
64
+
65
+ def _to_bcoo_matrix(A):
66
+ if jsparse is None:
67
+ raise ImportError("jax.experimental.sparse is required for BCOO matvec")
68
+ if jsparse is not None and isinstance(A, jsparse.BCOO):
69
+ return A
70
+ coo = _coo_tuple_from_any(A)
71
+ if coo is None:
72
+ raise ValueError("Unable to build BCOO from A")
73
+ rows, cols, data, n = coo
74
+ idx = jnp.stack([rows, cols], axis=-1)
75
+ return jsparse.BCOO((data, idx), shape=(n, n))
76
+
77
+
78
+ def _normalize_matvec_matrix(A, matvec: str):
79
+ if matvec == "flux":
80
+ return _to_flux_matrix(A)
81
+ if matvec == "bcoo":
82
+ return _to_bcoo_matrix(A)
83
+ if matvec == "dense":
84
+ return jnp.asarray(A)
85
+ if matvec == "auto":
86
+ if jsparse is not None:
87
+ try:
88
+ return _to_bcoo_matrix(A)
89
+ except Exception:
90
+ return _to_flux_matrix(A)
91
+ return _to_flux_matrix(A)
92
+ raise ValueError(f"Unknown matvec backend: {matvec}")
93
+
94
+
95
+ @dataclass(frozen=True)
96
+ class CGOperator:
97
+ """
98
+ Lightweight CG operator wrapper with a consistent solve() entry point.
99
+ """
100
+ A: object
101
+ preconditioner: object | None = None
102
+ solver: str = "cg"
103
+
104
+ def solve(self, b, *, x0=None, tol: float = 1e-8, maxiter: int | None = None):
105
+ if self.solver == "cg":
106
+ return cg_solve(
107
+ self.A,
108
+ b,
109
+ x0=x0,
110
+ tol=tol,
111
+ maxiter=maxiter,
112
+ preconditioner=self.preconditioner,
113
+ )
114
+ if self.solver == "cg_jax":
115
+ return cg_solve_jax(
116
+ self.A,
117
+ b,
118
+ x0=x0,
119
+ tol=tol,
120
+ maxiter=maxiter,
121
+ preconditioner=self.preconditioner,
122
+ )
123
+ raise ValueError(f"Unknown CG solver: {self.solver}")
124
+
125
+
126
def build_cg_operator(
    A,
    *,
    matvec: str = "flux",
    preconditioner=None,
    solver: str = "cg",
    dof_per_node: int | None = None,
    block_sizes=None,
) -> CGOperator:
    """
    Normalize CG inputs into a single ``CGOperator``.

    ``A`` is converted for the requested ``matvec`` backend.  The string
    ``"block_jacobi"`` is resolved here into a concrete preconditioner built
    from the converted matrix (using ``dof_per_node`` / ``block_sizes``).
    Any other ``preconditioner`` value is stored unchanged and handed to the
    solver by ``CGOperator.solve``.
    """
    operator_matrix = _normalize_matvec_matrix(A, matvec)
    if preconditioner == "block_jacobi":
        preconditioner = make_block_jacobi_preconditioner(
            operator_matrix, dof_per_node=dof_per_node, block_sizes=block_sizes
        )
    return CGOperator(A=operator_matrix, preconditioner=preconditioner, solver=solver)
145
+
146
+
33
147
  def _diag_builder(A, n: int):
34
148
  """
35
149
  Build diagonal for a Jacobi preconditioner when available.
@@ -92,39 +206,7 @@ def _cg_solve_single(
92
206
  return inv_diag * r
93
207
 
94
208
  elif preconditioner == "block_jacobi":
95
- # Expect 3 DOFs per node
96
- if n % 3 != 0:
97
- raise ValueError("block_jacobi requires n_dofs % 3 == 0")
98
- if jsparse is not None and isinstance(A, jsparse.BCOO):
99
- rows = A.indices[:, 0]
100
- cols = A.indices[:, 1]
101
- data = A.data
102
- elif isinstance(A, FluxSparseMatrix):
103
- rows = jnp.asarray(A.pattern.rows)
104
- cols = jnp.asarray(A.pattern.cols)
105
- data = jnp.asarray(A.data)
106
- else:
107
- raise ValueError("block_jacobi requires FluxSparseMatrix or BCOO")
108
-
109
- block_rows = rows // 3
110
- block_cols = cols // 3
111
- lr = rows % 3
112
- lc = cols % 3
113
- mask = block_rows == block_cols
114
- block_rows = block_rows[mask]
115
- lr = lr[mask]
116
- lc = lc[mask]
117
- data = data[mask]
118
- n_block = n // 3
119
- blocks = jnp.zeros((n_block, 3, 3), dtype=data.dtype)
120
- blocks = blocks.at[block_rows, lr, lc].add(data)
121
- blocks = blocks + 1e-12 * jnp.eye(3)[None, :, :]
122
- inv_blocks = jnp.linalg.inv(blocks)
123
-
124
- def precon(r):
125
- rb = r.reshape((n_block, 3))
126
- zb = jnp.einsum("bij,bj->bi", inv_blocks, rb)
127
- return zb.reshape((-1,))
209
+ precon = make_block_jacobi_preconditioner(A)
128
210
 
129
211
  elif callable(preconditioner):
130
212
  precon = preconditioner
fluxfem/solver/dirichlet.py CHANGED
@@ -1,16 +1,132 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+
3
6
  import numpy as np
4
7
  import jax.numpy as jnp
5
8
 
6
- from .sparse import FluxSparseMatrix
9
+ try:
10
+ import scipy.sparse as sp
11
+ except Exception: # pragma: no cover
12
+ sp = None
13
+
14
+ from .sparse import FluxSparseMatrix, coalesce_coo
15
+
16
+
17
+ def _normalize_dirichlet_values(dofs, vals):
18
+ if vals is None:
19
+ return np.zeros(np.asarray(dofs).shape[0], dtype=float)
20
+ arr = np.asarray(vals)
21
+ if arr.ndim == 0:
22
+ return np.full(np.asarray(dofs).shape[0], float(arr), dtype=float)
23
+ return arr
7
24
 
8
25
 
9
26
  def _normalize_dirichlet(dofs, vals):
10
27
  dir_arr = np.asarray(dofs, dtype=int)
11
- if vals is None:
12
- return dir_arr, np.zeros(dir_arr.shape[0], dtype=float)
13
- return dir_arr, np.asarray(vals, dtype=float)
28
+ return dir_arr, _normalize_dirichlet_values(dir_arr, vals)
29
+
30
+
31
@dataclass(frozen=True)
class CondensedSystem:
    """
    Result of eliminating Dirichlet DOFs from a linear system.

    Holds the reduced matrix/load together with the index bookkeeping needed
    to scatter a reduced solution back onto the full DOF vector (``expand``).
    """

    K: Any  # reduced free-free matrix
    F: Any  # reduced load, already lifted by the Dirichlet contribution
    free_dofs: np.ndarray  # global indices of the unconstrained DOFs
    dir_dofs: np.ndarray  # global indices of the Dirichlet DOFs
    dir_vals: np.ndarray  # prescribed value for each entry of dir_dofs
    n_dofs: int  # size of the full (unreduced) system

    def expand(self, u_free, *, fill_dirichlet: bool = True):
        """
        Scatter a reduced solution back onto the full DOF vector.

        ``u_free`` may be 1-D ``(n_free,)`` or carry trailing dimensions such
        as ``(n_free, k)`` (one column per right-hand side, matching a 2-D
        ``F``).  The result has shape ``(n_dofs,) + u_free.shape[1:]`` and the
        dtype of ``u_free``.  When ``fill_dirichlet`` is True the prescribed
        values are written into every trailing column; otherwise constrained
        entries stay zero.
        """
        u_arr = np.asarray(u_free)
        u_full = np.zeros((self.n_dofs,) + u_arr.shape[1:], dtype=u_arr.dtype)
        u_full[self.free_dofs] = u_arr
        if fill_dirichlet and self.dir_dofs.size:
            dir_vals = np.asarray(self.dir_vals, dtype=u_full.dtype)
            if dir_vals.ndim == 1 and u_full.ndim > 1:
                # Broadcast one value per constrained DOF across all columns.
                dir_vals = dir_vals.reshape((-1,) + (1,) * (u_full.ndim - 1))
            u_full[self.dir_dofs] = dir_vals
        return u_full
46
+
47
+
48
+
49
+
50
@dataclass(frozen=True)
class DirichletBC:
    """
    Immutable set of Dirichlet constraints: DOF indices plus prescribed values.

    On construction both fields are normalized through ``_normalize_dirichlet``
    (``dofs`` becomes an integer array; ``vals`` may be given as None or a
    scalar).  The methods are thin wrappers around the module-level Dirichlet
    helpers, pre-filled with this object's ``dofs`` and ``vals``.
    """

    dofs: np.ndarray  # constrained global DOF indices
    vals: np.ndarray  # prescribed value for each entry of ``dofs``

    def __post_init__(self):
        # Frozen dataclass: regular assignment is blocked, so store the
        # normalized arrays through object.__setattr__.
        normalized = _normalize_dirichlet(self.dofs, self.vals)
        for field_name, value in zip(("dofs", "vals"), normalized):
            object.__setattr__(self, field_name, value)

    @classmethod
    def from_boundary_dofs(cls, mesh, predicate, *, values=None, **kwargs):
        """
        Constrain the DOFs returned by ``mesh.boundary_dofs_where(predicate)``.

        Extra keyword arguments go straight to ``mesh.boundary_dofs_where``
        (e.g. ``components=...``, ``dof_per_node=...``).  ``values`` may be
        None (homogeneous), a scalar, or one value per selected DOF.
        """
        selected = mesh.boundary_dofs_where(predicate, **kwargs)
        return cls(selected, _normalize_dirichlet_values(selected, values))

    @classmethod
    def from_bbox(cls, mesh, *, mins=None, maxs=None, tol: float = 1e-8, values=None, **kwargs):
        """
        Constrain the boundary DOFs selected by an axis-aligned bounding box.

        ``mins``/``maxs`` default to the per-axis extrema of ``mesh.coords``;
        the predicate comes from ``bbox_predicate(mins, maxs, tol=tol)``.
        Extra keyword arguments go to ``mesh.boundary_dofs_where``
        (e.g. ``components=...``, ``dof_per_node=...``).
        """
        from ..mesh.predicate import bbox_predicate

        coords = np.asarray(mesh.coords)
        lower = coords.min(axis=0) if mins is None else mins
        upper = coords.max(axis=0) if maxs is None else maxs
        selected = mesh.boundary_dofs_where(bbox_predicate(lower, upper, tol=tol), **kwargs)
        return cls(selected, _normalize_dirichlet_values(selected, values))

    def as_tuple(self) -> tuple[np.ndarray, np.ndarray]:
        """Return ``(dofs, vals)``."""
        return (self.dofs, self.vals)

    def condense_system(self, A, F, *, check: bool = True) -> CondensedSystem:
        """Eliminate these DOFs from ``A x = F`` via ``condense_dirichlet_system``."""
        return condense_dirichlet_system(A, F, self.dofs, self.vals, check=check)

    def enforce_system(self, A, F):
        """Impose these values in place of rows/cols via ``enforce_dirichlet_system``."""
        return enforce_dirichlet_system(A, F, self.dofs, self.vals)

    def condense_flux(self, A: FluxSparseMatrix, F):
        """
        Condense a FluxSparseMatrix system.

        Returns ``(K_free, F_free, free_dofs)`` where ``K_free`` is a
        FluxSparseMatrix restricted to the free DOFs.
        """
        system = self.condense_system(A, F)
        free = system.free_dofs
        return restrict_flux_to_free(A, free), system.F, free

    def enforce_flux(self, A: FluxSparseMatrix, F):
        """Impose these values on a FluxSparseMatrix system."""
        return enforce_dirichlet_fluxsparse(A, F, self.dofs, self.vals)

    def split_matrix(self, A, *, n_total: int | None = None):
        """Return ``(free, dir_dofs, A_ff, A_fd)`` via ``split_dirichlet_matrix``."""
        return split_dirichlet_matrix(A, self.dofs, n_total=n_total)

    def free_dofs(self, n_dofs: int) -> np.ndarray:
        """Return the DOF indices in ``range(n_dofs)`` not constrained here."""
        # Name lookup inside a method resolves to the module-level function,
        # not this method.
        return free_dofs(n_dofs, self.dofs)

    def expand_solution(self, u_free, *, free=None, n_total: int | None = None):
        """
        Rebuild the full solution from free-DOF values.

        Either ``free`` or ``n_total`` must be supplied.  Without ``n_total``
        the size is one past the largest index found in ``free`` or ``dofs``.

        Raises
        ------
        ValueError
            If neither ``free`` nor ``n_total`` is given.
        """
        if free is None:
            if n_total is None:
                raise ValueError("n_total is required when free is not provided.")
            free = free_dofs(n_total, self.dofs)
        if n_total is None:
            largest = [int(np.max(idx)) for idx in (free, self.dofs) if len(idx)]
            n_total = max(largest, default=-1) + 1
        return expand_dirichlet_solution(u_free, free, self.dofs, self.vals, n_total=n_total)
14
130
 
15
131
 
16
132
  def enforce_dirichlet_dense(K, F, dofs, vals):
@@ -33,6 +149,135 @@ def enforce_dirichlet_dense(K, F, dofs, vals):
33
149
  return Kc, Fc
34
150
 
35
151
 
152
def enforce_dirichlet_dense_jax(K, F, dofs, vals):
    """
    Impose Dirichlet values on a dense JAX stiffness matrix and load.

    The known values are first moved to the right-hand side
    (``F - K[:, dofs] @ vals``).  Then the constrained rows and columns of
    ``K`` are zeroed with a unit diagonal, and the constrained entries of
    ``F`` are overwritten with the prescribed values.  ``F`` may be 1-D or
    2-D; for 2-D loads the same values are applied to every column.

    ``K`` must support JAX ``.at`` updates.  Returns the new ``(K, F)`` pair.
    """
    import jax.numpy as jnp

    idx, prescribed = _normalize_dirichlet(dofs, vals)
    idx = jnp.asarray(idx, dtype=jnp.int32)
    prescribed = jnp.asarray(prescribed, dtype=F.dtype)
    two_d = F.ndim == 2

    lift = K[:, idx] @ prescribed
    F_out = F - (lift[:, None] if two_d else lift)

    K_out = K.at[:, idx].set(0.0).at[idx, :].set(0.0).at[idx, idx].set(1.0)

    if two_d:
        F_out = F_out.at[idx, :].set(prescribed[:, None])
    else:
        F_out = F_out.at[idx].set(prescribed)
    return K_out, F_out
171
+
172
+
173
def enforce_dirichlet_system(A, F, dofs, vals):
    """
    Apply Dirichlet conditions directly to stiffness/load.

    Dispatches on the matrix type:
    - FluxSparseMatrix -> ``enforce_dirichlet_sparse``
    - JAX array        -> ``enforce_dirichlet_dense_jax``
    - anything else    -> ``enforce_dirichlet_dense``

    Returns the ``(K, F)`` pair produced by the selected helper.  Errors
    raised by that helper propagate to the caller.
    """
    if isinstance(A, FluxSparseMatrix):
        return enforce_dirichlet_sparse(A, F, dofs, vals)
    try:
        import jax.numpy as jnp
    except ImportError:  # pragma: no cover - without JAX there are no JAX arrays
        jnp = None
    # Only the import is guarded: wrapping the JAX call itself in a broad
    # ``except`` would hide genuine failures and silently re-run the NumPy path.
    if jnp is not None and isinstance(A, jnp.ndarray):
        return enforce_dirichlet_dense_jax(A, F, dofs, vals)
    return enforce_dirichlet_dense(A, F, dofs, vals)
188
+
189
+
190
def split_dirichlet_matrix(A, dir_dofs, *, n_total: int | None = None):
    """
    Split a square matrix into free-free and free-Dirichlet blocks.

    Parameters
    ----------
    A : FluxSparseMatrix, SciPy sparse matrix, JAX array, or NumPy array-like.
    dir_dofs : Dirichlet DOF indices.
    n_total : total DOF count.  When omitted it is read from ``A.n_dofs`` or,
        failing that, from ``A.shape`` (which must be square).

    Returns
    -------
    (free, dir_dofs, A_ff, A_fd)
        ``free`` are the non-Dirichlet DOF indices.  The blocks are SciPy
        sparse for FluxSparseMatrix/SciPy input, JAX arrays for JAX input,
        and NumPy arrays otherwise.

    Raises
    ------
    ValueError
        If ``n_total`` is omitted and ``A`` is not square.
    ImportError
        If ``A`` is a FluxSparseMatrix and SciPy is not installed.
    """
    dir_dofs, _ = _normalize_dirichlet(dir_dofs, None)
    if n_total is None:
        if hasattr(A, "n_dofs"):
            n_total = int(A.n_dofs)
        else:
            # Prefer the ``shape`` attribute: np.asarray() on a SciPy sparse
            # matrix yields a 0-d object array (so sparse input could never
            # pass the square check), and on a JAX array forces a host copy.
            shape = getattr(A, "shape", None)
            if shape is None:
                shape = np.asarray(A).shape
            if len(shape) != 2 or shape[0] != shape[1]:
                raise ValueError("A must be square when n_total is not provided.")
            n_total = int(shape[0])
    free = free_dofs(n_total, dir_dofs)

    if isinstance(A, FluxSparseMatrix):
        if sp is None:
            raise ImportError("scipy is required to split FluxSparseMatrix.")
        A = A.to_csr()

    if sp is not None and sp.issparse(A):
        if A.format in ("coo", "dia", "bsr"):
            # These formats do not (reliably, across SciPy versions) support
            # row/column slicing; CSR does.
            A = A.tocsr()
        A_free_rows = A[free]  # slice rows once, reuse for both column blocks
        return free, dir_dofs, A_free_rows[:, free], A_free_rows[:, dir_dofs]

    if isinstance(A, jnp.ndarray):
        free_j = jnp.asarray(free, dtype=jnp.int32)
        dir_j = jnp.asarray(dir_dofs, dtype=jnp.int32)
        return free, dir_dofs, A[jnp.ix_(free_j, free_j)], A[jnp.ix_(free_j, dir_j)]

    arr = np.asarray(A)
    return free, dir_dofs, arr[np.ix_(free, free)], arr[np.ix_(free, dir_dofs)]
222
+
223
+
224
def condense_dirichlet_system(A, F, dofs, vals, *, check: bool = True) -> CondensedSystem:
    """
    Condense Dirichlet DOFs and return a structured system.

    Builds ``K_ff`` and the lifted load ``F_f - K_fd @ dir_vals``.  Solving the
    returned system for the free DOFs and passing the result to
    ``CondensedSystem.expand`` reconstructs the full solution.

    Parameters
    ----------
    A : FluxSparseMatrix, SciPy sparse matrix, object with ``to_csr()``, or a
        dense square array-like.
    F : load of shape ``(n_dofs,)`` or ``(n_dofs, k)``.
    dofs, vals : Dirichlet DOFs and values (``vals`` may be None or a scalar).
    check : validate lengths, bounds and uniqueness of the Dirichlet data.

    Raises
    ------
    ValueError
        If ``A`` is not square, or (with ``check``) the Dirichlet data is
        inconsistent.
    """
    dir_arr, dir_vals_arr = _normalize_dirichlet(dofs, vals)
    F_arr = np.asarray(F)
    if hasattr(A, "n_dofs"):
        n_total = int(A.n_dofs)
    else:
        # Prefer the ``shape`` attribute: np.asarray() on a SciPy sparse matrix
        # yields a 0-d object array, which would reject every sparse input.
        shape = getattr(A, "shape", None)
        if shape is None:
            shape = np.asarray(A).shape
        if len(shape) != 2 or shape[0] != shape[1]:
            raise ValueError("A must be square for Dirichlet condensation.")
        n_total = int(shape[0])

    if check:
        if dir_arr.size != dir_vals_arr.size:
            raise ValueError("dir_dofs and dir_vals must have the same length")
        if dir_arr.size:
            if np.min(dir_arr) < 0 or np.max(dir_arr) >= n_total:
                raise ValueError("dir_dofs out of bounds")
            if np.unique(dir_arr).size != dir_arr.size:
                raise ValueError("dir_dofs contains duplicates")

    mask = np.ones(n_total, dtype=bool)
    mask[dir_arr] = False
    free = np.nonzero(mask)[0]

    if isinstance(A, FluxSparseMatrix):
        K_mat = A.to_csr()
    elif sp is not None and sp.issparse(A):
        K_mat = A.tocsr()
    elif hasattr(A, "to_csr"):
        K_mat = A.to_csr()
    else:
        K_mat = np.asarray(A)

    K_rows = K_mat[free]  # row-slice once; both column blocks reuse it
    K_ff = K_rows[:, free]
    F_free = F_arr[free]
    if dir_arr.size:
        lift = K_rows[:, dir_arr] @ dir_vals_arr
        # For multi-RHS loads the same Dirichlet lift applies to every column.
        F_free = F_free - (lift[:, None] if F_free.ndim == 2 else lift)

    return CondensedSystem(
        K=K_ff,
        F=F_free,
        free_dofs=free,
        dir_dofs=dir_arr,
        dir_vals=dir_vals_arr,
        n_dofs=n_total,
    )
277
+
278
+
279
+
280
+
36
281
  def enforce_dirichlet_sparse(A: FluxSparseMatrix, F, dofs, vals):
37
282
  """Apply Dirichlet conditions to FluxSparseMatrix + load (CSR)."""
38
283
  K_csr = A.to_csr().tolil()
@@ -54,6 +299,11 @@ def enforce_dirichlet_sparse(A: FluxSparseMatrix, F, dofs, vals):
54
299
  return K_csr.tocsr(), Fc
55
300
 
56
301
 
302
def enforce_dirichlet_fluxsparse(A: FluxSparseMatrix, F, dofs, vals):
    """
    Apply Dirichlet conditions to a FluxSparseMatrix system.

    Alias of ``enforce_dirichlet_sparse``; its ``(K_csr, F)`` result is
    returned unchanged.
    """
    result = enforce_dirichlet_sparse(A, F, dofs, vals)
    return result
305
+
306
+
57
307
  def condense_dirichlet_fluxsparse(A: FluxSparseMatrix, F, dofs, vals):
58
308
  """
59
309
  Condense Dirichlet DOFs for a FluxSparseMatrix.
@@ -76,6 +326,63 @@ def condense_dirichlet_fluxsparse(A: FluxSparseMatrix, F, dofs, vals):
76
326
  return K_ff, F_free, free, dir_arr, dir_vals_arr
77
327
 
78
328
 
329
def condense_dirichlet_fluxsparse_coo(
    A: FluxSparseMatrix,
    F,
    dofs,
    vals,
    *,
    coalesce: bool = True,
):
    """
    Condense Dirichlet DOFs for a FluxSparseMatrix using COO filtering.

    Works directly on the COO triplets (no CSR conversion).  Entries whose
    row and column are both free are kept and renumbered to free-local
    indices.  Entries in a free row but a Dirichlet column contribute the
    lifting term ``-K_fd @ dir_vals`` to the load.

    Parameters
    ----------
    A : FluxSparseMatrix (uses ``pattern.rows``, ``pattern.cols``, ``data``, ``n_dofs``).
    F : load of shape ``(n_dofs,)`` or ``(n_dofs, k)``; converted to float.
    dofs, vals : Dirichlet DOFs and values (``vals`` may be None or a scalar).
    coalesce : pass the kept triplets through ``coalesce_coo`` first.

    Returns: (K_free, F_free, free_dofs, dir_dofs, dir_vals), with ``F_free``
    as a JAX array.
    """
    dir_arr, dir_vals_arr = _normalize_dirichlet(dofs, vals)
    n_total = int(A.n_dofs)
    mask = np.ones(n_total, dtype=bool)
    mask[dir_arr] = False
    free = np.nonzero(mask)[0]

    rows = np.asarray(A.pattern.rows, dtype=np.int64)
    cols = np.asarray(A.pattern.cols, dtype=np.int64)
    data = np.asarray(A.data)

    # Global -> free-local numbering; -1 marks Dirichlet DOFs.
    g2l = -np.ones(n_total, dtype=np.int32)
    g2l[free] = np.arange(free.size, dtype=np.int32)
    r2 = g2l[rows]
    c2 = g2l[cols]
    row_free = r2 >= 0
    col_free = c2 >= 0
    keep = row_free & col_free

    rows_f = r2[keep]
    cols_f = c2[keep]
    data_f = data[keep]
    if coalesce:
        rows_f, cols_f, data_f = coalesce_coo(rows_f, cols_f, data_f)

    K_free = FluxSparseMatrix(rows_f, cols_f, data_f, int(free.size))

    F_arr = np.asarray(F, dtype=float)
    F_free = F_arr[free]
    # Test for exact zeros: a tolerance check such as np.allclose (atol=1e-8)
    # would drop the lifting term for small but nonzero prescribed values.
    if dir_arr.size > 0 and np.any(dir_vals_arr != 0):
        dir_full = np.zeros(n_total, dtype=F_arr.dtype)
        dir_full[dir_arr] = dir_vals_arr
        mask_fd = row_free & ~col_free
        if np.any(mask_fd):
            contrib = data[mask_fd] * dir_full[cols[mask_fd]]
            # Accumulate straight into free-local rows (duplicates summed).
            delta = np.zeros(free.size, dtype=F_arr.dtype)
            np.add.at(delta, r2[mask_fd], contrib)
            if F_free.ndim == 2:
                F_free = F_free - delta[:, None]
            else:
                F_free = F_free - delta

    return K_free, jnp.asarray(F_free), free, dir_arr, dir_vals_arr
384
+
385
+
79
386
  def free_dofs(n_dofs: int, dir_dofs) -> np.ndarray:
80
387
  """
81
388
  Return free DOF indices given total DOFs and Dirichlet DOFs.
@@ -86,6 +393,29 @@ def free_dofs(n_dofs: int, dir_dofs) -> np.ndarray:
86
393
  return np.nonzero(mask)[0]
87
394
 
88
395
 
396
def restrict_flux_to_free(K: FluxSparseMatrix, free: np.ndarray, *, coalesce: bool = True) -> FluxSparseMatrix:
    """
    Keep only the free-free entries of ``K``, renumbered to free-local indices.

    Entry ``i`` of ``free`` becomes local index ``i`` in the result.  Any
    stored entry whose row or column is not in ``free`` is dropped.  The
    returned matrix has ``len(free)`` DOFs and is passed through
    ``K_free.coalesce()`` unless ``coalesce=False``.
    """
    keep_idx = np.asarray(free, dtype=np.int32)
    local_of = np.full(K.n_dofs, -1, dtype=np.int32)
    local_of[keep_idx] = np.arange(keep_idx.size, dtype=np.int32)

    local_rows = local_of[np.asarray(K.pattern.rows)]
    local_cols = local_of[np.asarray(K.pattern.cols)]
    inside = (local_rows >= 0) & (local_cols >= 0)
    values = np.asarray(K.data)[inside]

    restricted = FluxSparseMatrix(
        jnp.asarray(local_rows[inside]),
        jnp.asarray(local_cols[inside]),
        jnp.asarray(values),
        int(keep_idx.size),
    )
    if coalesce:
        return restricted.coalesce()
    return restricted
417
+
418
+
89
419
  def condense_dirichlet_dense(K, F, dofs, vals):
90
420
  """
91
421
  Eliminate Dirichlet dofs for dense/CSR matrices and return condensed system.