fluxfem 0.1.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (47)
  1. fluxfem/__init__.py +343 -0
  2. fluxfem/core/__init__.py +316 -0
  3. fluxfem/core/assembly.py +788 -0
  4. fluxfem/core/basis.py +996 -0
  5. fluxfem/core/data.py +64 -0
  6. fluxfem/core/dtypes.py +4 -0
  7. fluxfem/core/forms.py +234 -0
  8. fluxfem/core/interp.py +55 -0
  9. fluxfem/core/solver.py +113 -0
  10. fluxfem/core/space.py +419 -0
  11. fluxfem/core/weakform.py +818 -0
  12. fluxfem/helpers_num.py +11 -0
  13. fluxfem/helpers_wf.py +42 -0
  14. fluxfem/mesh/__init__.py +29 -0
  15. fluxfem/mesh/base.py +244 -0
  16. fluxfem/mesh/hex.py +327 -0
  17. fluxfem/mesh/io.py +87 -0
  18. fluxfem/mesh/predicate.py +45 -0
  19. fluxfem/mesh/surface.py +257 -0
  20. fluxfem/mesh/tet.py +246 -0
  21. fluxfem/physics/__init__.py +53 -0
  22. fluxfem/physics/diffusion.py +18 -0
  23. fluxfem/physics/elasticity/__init__.py +39 -0
  24. fluxfem/physics/elasticity/hyperelastic.py +99 -0
  25. fluxfem/physics/elasticity/linear.py +58 -0
  26. fluxfem/physics/elasticity/materials.py +32 -0
  27. fluxfem/physics/elasticity/stress.py +46 -0
  28. fluxfem/physics/operators.py +109 -0
  29. fluxfem/physics/postprocess.py +113 -0
  30. fluxfem/solver/__init__.py +47 -0
  31. fluxfem/solver/bc.py +439 -0
  32. fluxfem/solver/cg.py +326 -0
  33. fluxfem/solver/dirichlet.py +126 -0
  34. fluxfem/solver/history.py +31 -0
  35. fluxfem/solver/newton.py +400 -0
  36. fluxfem/solver/result.py +62 -0
  37. fluxfem/solver/solve_runner.py +534 -0
  38. fluxfem/solver/solver.py +148 -0
  39. fluxfem/solver/sparse.py +188 -0
  40. fluxfem/tools/__init__.py +7 -0
  41. fluxfem/tools/jit.py +51 -0
  42. fluxfem/tools/timer.py +659 -0
  43. fluxfem/tools/visualizer.py +101 -0
  44. fluxfem-0.1.1a0.dist-info/METADATA +111 -0
  45. fluxfem-0.1.1a0.dist-info/RECORD +47 -0
  46. fluxfem-0.1.1a0.dist-info/WHEEL +4 -0
  47. fluxfem-0.1.1a0.dist-info/licenses/LICENSE +201 -0
fluxfem/solver/cg.py ADDED
@@ -0,0 +1,326 @@
1
+ from __future__ import annotations
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import jax.scipy as jsp
6
+
7
+ try:
8
+ from jax.experimental import sparse as jsparse
9
+ except Exception: # pragma: no cover
10
+ jsparse = None
11
+
12
+ from .sparse import FluxSparseMatrix
13
+
14
+
15
def _matvec_builder(A):
    """
    Return a callable ``x -> A @ x`` for any supported operator representation.

    Dispatch order: jax BCOO, FluxSparseMatrix, any object exposing
    ``matvec``, plain callables, ``(rows, cols, data, n)`` tuples, and finally
    anything ``jnp.asarray`` accepts as a dense matrix.
    """
    if jsparse is not None and isinstance(A, jsparse.BCOO):

        def bcoo_matvec(x):
            return A @ x

        return bcoo_matvec

    # FluxSparseMatrix and duck-typed operators share the same ``matvec`` hook.
    if isinstance(A, FluxSparseMatrix) or hasattr(A, "matvec"):
        return A.matvec

    if callable(A):

        def call_matvec(x):
            return A(x)

        return call_matvec

    if isinstance(A, tuple) and len(A) == 4:
        return FluxSparseMatrix.from_bilinear(A).matvec

    def dense_matvec(x):
        # The dense conversion happens each time the product is applied.
        return jnp.asarray(A) @ x

    return dense_matvec
31
+
32
+
33
def _diag_builder(A, n: int):
    """
    Extract the matrix diagonal used by the Jacobi preconditioner.

    Parameters
    ----------
    A : BCOO | FluxSparseMatrix | object with ``diagonal()`` |
        (rows, cols, data, n) tuple | square dense array
        Operator whose diagonal is requested. Bare callables are rejected.
    n : int
        Number of rows; used as the segment count for BCOO input.

    Raises
    ------
    ValueError
        If the diagonal cannot be obtained from ``A``.
    """
    if jsparse is not None and isinstance(A, jsparse.BCOO):
        row_idx = A.indices[:, 0]
        col_idx = A.indices[:, 1]
        # Zero every off-diagonal entry, then sum per row so duplicate
        # diagonal entries accumulate.
        on_diag = jnp.where(row_idx == col_idx, A.data, jnp.zeros_like(A.data))
        return jax.ops.segment_sum(on_diag, row_idx, n)

    if isinstance(A, FluxSparseMatrix):
        return A.diag()
    if hasattr(A, "diagonal"):
        return jnp.asarray(A.diagonal())
    if isinstance(A, tuple) and len(A) == 4:
        return _diag_builder(FluxSparseMatrix.from_bilinear(A), n)
    if callable(A):
        raise ValueError("Jacobi preconditioner requires access to matrix diagonal")

    dense = jnp.asarray(A)
    is_square = dense.ndim == 2 and dense.shape[0] == dense.shape[1]
    if not is_square:
        raise ValueError("Cannot build Jacobi preconditioner: diagonal unavailable")
    return jnp.diag(dense)
57
+
58
+
59
def _block_jacobi_preconditioner(A, n: int):
    """
    Build ``r -> z`` that applies the inverse 3x3 diagonal blocks of ``A``.

    Assumes 3 DOFs per node numbered consecutively (DOF ``i`` belongs to
    block ``i // 3``). ``A`` must be a jax BCOO, a FluxSparseMatrix, or a
    ``(rows, cols, data, n)`` tuple.
    """
    if n % 3 != 0:
        raise ValueError("block_jacobi requires n_dofs % 3 == 0")
    if isinstance(A, tuple) and len(A) == 4:
        # Same (rows, cols, data, n) convention accepted by the other builders.
        A = FluxSparseMatrix.from_bilinear(A)
    if jsparse is not None and isinstance(A, jsparse.BCOO):
        rows = A.indices[:, 0]
        cols = A.indices[:, 1]
        data = A.data
    elif isinstance(A, FluxSparseMatrix):
        rows = jnp.asarray(A.pattern.rows)
        cols = jnp.asarray(A.pattern.cols)
        data = jnp.asarray(A.data)
    else:
        raise ValueError("block_jacobi requires FluxSparseMatrix or BCOO")

    n_block = n // 3
    block_rows = rows // 3
    in_diag_block = block_rows == cols // 3
    # Off-block entries are scattered as zeros rather than removed with a
    # boolean mask: array shapes stay static, so this also works under jit.
    contrib = jnp.where(in_diag_block, data, jnp.zeros_like(data))
    blocks = jnp.zeros((n_block, 3, 3), dtype=data.dtype)
    blocks = blocks.at[block_rows, rows % 3, cols % 3].add(contrib)
    # Tiny diagonal shift, presumably to keep exactly singular blocks invertible.
    blocks = blocks + 1e-12 * jnp.eye(3)[None, :, :]
    inv_blocks = jnp.linalg.inv(blocks)

    def precon(r):
        zb = jnp.einsum("bij,bj->bi", inv_blocks, r.reshape((n_block, 3)))
        return zb.reshape((-1,))

    return precon


def _cg_solve_single(
    A,
    b,
    *,
    x0=None,
    tol: float = 1e-8,
    maxiter: int | None = None,
    preconditioner=None,
):
    """
    Conjugate gradient (Ax=b) in JAX for a single right-hand side.

    Parameters
    ----------
    A : FluxSparseMatrix / BCOO / (rows, cols, data, n) / dense array / callable
        Operator, expected to be symmetric positive definite.
    b : array-like, shape (n,)
        Right-hand side (jnp or np). Non-floating input is promoted to the
        default float dtype.
    x0 : array-like, shape (n,), optional
        Initial guess; zeros by default.
    tol : float
        Absolute tolerance: iteration stops once ``||r||_2 <= tol`` where
        ``r`` is the recursively updated residual.
    maxiter : int, optional
        Iteration cap; defaults to ``max(10 * n, 1)``.
    preconditioner : None | "jacobi" | "block_jacobi" | callable(r) -> z
        ``"block_jacobi"`` assumes 3 DOFs per node.

    Returns
    -------
    (x, info)
        ``info`` has ``iters`` (iteration count), ``residual_norm``
        (2-norm of the final recursive residual) and ``converged``
        (``residual_norm <= tol``).

    Raises
    ------
    ValueError
        For unknown or unsupported preconditioner requests.
    """
    b = jnp.asarray(b)
    if not jnp.issubdtype(b.dtype, jnp.inexact):
        # An integer RHS would make the loop carry change dtype after the
        # first update (x + alpha * p is floating point); while_loop rejects that.
        b = b.astype(float)
    n = b.shape[0]
    x0 = jnp.zeros_like(b) if x0 is None else jnp.asarray(x0)
    if maxiter is None:
        maxiter = max(10 * n, 1)

    mv = _matvec_builder(A)

    if preconditioner is None:
        precon = None
    elif preconditioner == "jacobi":
        diag = _diag_builder(A, n)
        inv_diag = jnp.where(diag != 0.0, 1.0 / diag, 0.0)

        def precon(r):
            return inv_diag * r

    elif preconditioner == "block_jacobi":
        precon = _block_jacobi_preconditioner(A, n)
    elif callable(preconditioner):
        precon = preconditioner
    else:
        raise ValueError(f"Unknown preconditioner type: {preconditioner}")

    def apply_precon(r):
        return r if precon is None else precon(r)

    r0 = b - mv(x0)
    # Iterate in the dtype the operator produces so every while_loop carry
    # keeps a fixed dtype (e.g. a float32 x0 with a float64 operator).
    x0 = x0.astype(r0.dtype)
    z0 = apply_precon(r0)
    rz0 = jnp.dot(r0, z0)

    def body_fun(state):
        k, x, r, p, rz_old = state
        Ap = mv(p)
        pAp = jnp.dot(p, Ap)
        # Breakdown guard: when p^T A p == 0 take a zero step instead of
        # propagating inf/nan into x.
        safe_pAp = jnp.where(pAp != 0, pAp, 1)
        alpha = jnp.where(pAp != 0, rz_old / safe_pAp, 0)
        x_new = x + alpha * p
        r_new = r - alpha * Ap
        z_new = apply_precon(r_new)
        rz_new = jnp.dot(r_new, z_new)
        # cond_fun only enters the body when rz_old != 0.
        beta = rz_new / rz_old
        p_new = z_new + beta * p
        return k + 1, x_new, r_new, p_new, rz_new

    def cond_fun(state):
        k, _, r, _, rz_old = state
        not_converged = jnp.dot(r, r) > tol * tol
        # rz == 0 with a nonzero residual means the preconditioned residual
        # vanished; another step would divide by zero.
        keep_going = jnp.logical_and(k < maxiter, not_converged)
        return jnp.logical_and(keep_going, rz_old != 0)

    k, x, r, _, _ = jax.lax.while_loop(cond_fun, body_fun, (0, x0, r0, z0, rz0))
    res_norm = jnp.sqrt(jnp.dot(r, r))
    info = {"iters": k, "residual_norm": res_norm, "converged": res_norm <= tol}
    return x, info
159
+
160
+
161
+ def cg_solve(
162
+ A,
163
+ b,
164
+ *,
165
+ x0=None,
166
+ tol: float = 1e-8,
167
+ maxiter: int | None = None,
168
+ preconditioner=None,
169
+ ):
170
+ """
171
+ Conjugate gradient (Ax=b) in JAX.
172
+ Supports single RHS (n,) or multiple RHS (n, n_rhs).
173
+ """
174
+ b_arr = jnp.asarray(b)
175
+ if b_arr.ndim == 1:
176
+ return _cg_solve_single(
177
+ A,
178
+ b_arr,
179
+ x0=x0,
180
+ tol=tol,
181
+ maxiter=maxiter,
182
+ preconditioner=preconditioner,
183
+ )
184
+
185
+ if b_arr.ndim != 2:
186
+ raise ValueError("cg_solve expects b with shape (n,) or (n, n_rhs).")
187
+
188
+ xs = []
189
+ infos = []
190
+ for i in range(b_arr.shape[1]):
191
+ x0_i = None if x0 is None else jnp.asarray(x0)[:, i]
192
+ x_i, info_i = _cg_solve_single(
193
+ A,
194
+ b_arr[:, i],
195
+ x0=x0_i,
196
+ tol=tol,
197
+ maxiter=maxiter,
198
+ preconditioner=preconditioner,
199
+ )
200
+ xs.append(x_i)
201
+ infos.append(info_i)
202
+ x_out = jnp.stack(xs, axis=1)
203
+ info = {
204
+ "iters": [int(i.get("iters", 0)) for i in infos],
205
+ "residual_norm": jnp.asarray([i.get("residual_norm", 0.0) for i in infos]),
206
+ }
207
+ return x_out, info
208
+
209
+
210
def _cg_solve_jax_single(
    A,
    b,
    *,
    x0=None,
    tol: float = 1e-8,
    maxiter: int | None = None,
    preconditioner=None,
):
    """
    Conjugate gradient via jax.scipy.sparse.linalg.cg for a single RHS.

    Parameters
    ----------
    A : FluxSparseMatrix / (rows, cols, data, n) / dense array / callable
    b : array-like, shape (n,); non-floating input is promoted to float.
    x0 : optional initial guess, cast to ``b``'s dtype.
    tol : relative tolerance handed to jax (``atol=0``), i.e. the target is
        ``||r|| <= tol * ||b||``.
    maxiter : iteration cap; defaults to ``max(10 * n, 1)``.
    preconditioner : None | "jacobi" | callable(r) -> z

    Returns
    -------
    (x, info)
        ``info`` has ``iters``, ``residual_norm`` (norm of the true residual
        ``b - A x``), ``converged`` and ``info`` (raw value from jax).
        jax currently returns ``info=None`` and reports no iteration count,
        so ``iters`` is then 0 and ``converged`` is judged on the true
        residual; it can be False when that residual only marginally exceeds
        the target because of rounding.
    """
    b = jnp.asarray(b)
    if not jnp.issubdtype(b.dtype, jnp.inexact):
        # Integer RHS/x0 would make the CG loop carry change dtype.
        b = b.astype(float)
    n = b.shape[0]
    x0 = jnp.zeros_like(b) if x0 is None else jnp.asarray(x0, dtype=b.dtype)
    if maxiter is None:
        maxiter = max(10 * n, 1)

    mv = _matvec_builder(A)

    if preconditioner is None:
        precon = None
    elif preconditioner == "jacobi":
        diag = _diag_builder(A, n)
        inv_diag = jnp.where(diag != 0.0, 1.0 / diag, 0.0)

        def precon(r):
            return inv_diag * r

    elif callable(preconditioner):
        precon = preconditioner
    else:
        raise ValueError(f"Unknown preconditioner type: {preconditioner}")

    x, info_val = jsp.sparse.linalg.cg(
        mv,
        b,
        x0=x0,
        tol=tol,
        atol=0.0,
        maxiter=maxiter,
        M=precon,
    )
    res = b - mv(x)
    res_norm = jnp.sqrt(jnp.dot(res, res))

    if isinstance(info_val, (int, jnp.integer)):
        # SciPy-style convention: 0 means success, >0 the iterations spent
        # without converging.
        iters = int(info_val)
        converged = iters == 0
    else:
        # No usable status from jax: apply its stopping criterion
        # (||r|| <= tol * ||b||, atol = 0) to the true residual.
        iters = 0
        converged = res_norm <= tol * jnp.sqrt(jnp.dot(b, b))
    info = {"iters": iters, "residual_norm": res_norm, "converged": converged, "info": info_val}
    return x, info
277
+
278
+
279
+ def cg_solve_jax(
280
+ A,
281
+ b,
282
+ *,
283
+ x0=None,
284
+ tol: float = 1e-8,
285
+ maxiter: int | None = None,
286
+ preconditioner=None,
287
+ ):
288
+ """
289
+ Conjugate gradient via jax.scipy.sparse.linalg.cg.
290
+ Supports single RHS (n,) or multiple RHS (n, n_rhs).
291
+ """
292
+ b_arr = jnp.asarray(b)
293
+ if b_arr.ndim == 1:
294
+ return _cg_solve_jax_single(
295
+ A,
296
+ b_arr,
297
+ x0=x0,
298
+ tol=tol,
299
+ maxiter=maxiter,
300
+ preconditioner=preconditioner,
301
+ )
302
+
303
+ if b_arr.ndim != 2:
304
+ raise ValueError("cg_solve_jax expects b with shape (n,) or (n, n_rhs).")
305
+
306
+ xs = []
307
+ infos = []
308
+ for i in range(b_arr.shape[1]):
309
+ x0_i = None if x0 is None else jnp.asarray(x0)[:, i]
310
+ x_i, info_i = _cg_solve_jax_single(
311
+ A,
312
+ b_arr[:, i],
313
+ x0=x0_i,
314
+ tol=tol,
315
+ maxiter=maxiter,
316
+ preconditioner=preconditioner,
317
+ )
318
+ xs.append(x_i)
319
+ infos.append(info_i)
320
+ x_out = jnp.stack(xs, axis=1)
321
+ info = {
322
+ "iters": [int(i.get("iters", 0)) for i in infos],
323
+ "residual_norm": jnp.asarray([i.get("residual_norm", 0.0) for i in infos]),
324
+ "converged": [bool(i.get("converged", True)) for i in infos],
325
+ }
326
+ return x_out, info
fluxfem/solver/dirichlet.py ADDED
@@ -0,0 +1,126 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ import jax.numpy as jnp
5
+
6
+ from .sparse import FluxSparseMatrix
7
+
8
+
9
def _normalize_dirichlet(dofs, vals):
    """
    Convert Dirichlet DOF indices and prescribed values to NumPy arrays.

    Parameters
    ----------
    dofs : int or array-like of int
        Constrained DOF indices. Scalars and nested inputs are flattened to 1-D.
    vals : None, scalar, or array-like of float
        Prescribed values. ``None`` means homogeneous (all zeros); a scalar
        is broadcast to every DOF; an array shaped like ``dofs`` is flattened
        alongside it; otherwise the leading dimension must equal the number
        of DOFs.

    Returns
    -------
    (dir_arr, vals_arr) : (np.ndarray of int, np.ndarray of float)

    Raises
    ------
    ValueError
        If the number of values does not match the number of DOFs.
    """
    dofs_arr = np.asarray(dofs, dtype=int)
    # reshape(-1) turns a scalar DOF into a length-1 array (shape[0] on a
    # 0-d array would raise).
    dir_arr = dofs_arr.reshape(-1)
    n_dir = dir_arr.shape[0]
    if vals is None:
        return dir_arr, np.zeros(n_dir, dtype=float)

    vals_arr = np.asarray(vals, dtype=float)
    if vals_arr.ndim == 0:
        # One value prescribed for every constrained DOF.
        return dir_arr, np.full(n_dir, float(vals_arr))
    if vals_arr.shape == dofs_arr.shape:
        # Values laid out exactly like the DOF array: flatten in step.
        vals_arr = vals_arr.reshape(-1)
    if vals_arr.shape[0] != n_dir:
        raise ValueError(
            f"Dirichlet values have length {vals_arr.shape[0]} "
            f"but {n_dir} DOFs were given."
        )
    return dir_arr, vals_arr
14
+
15
+
16
def enforce_dirichlet_dense(K, F, dofs, vals):
    """
    Apply Dirichlet conditions directly to stiffness/load (dense).

    ``K`` and ``F`` are copied as float arrays and left untouched. The known
    contributions ``K[:, dofs] @ vals`` are moved to the right-hand side.
    Then each constrained row and column of the stiffness is cleared, its
    diagonal is set to 1, and the matching load entry (a whole row for a 2-D
    load) is set to the prescribed value.

    Returns
    -------
    (K_bc, F_bc) : tuple of np.ndarray
    """
    K_bc = np.asarray(K, dtype=float).copy()
    F_bc = np.asarray(F, dtype=float).copy()
    dir_dofs, dir_vals = _normalize_dirichlet(dofs, vals)

    lifted = K_bc[:, dir_dofs] @ dir_vals
    F_bc = F_bc - (lifted[:, None] if F_bc.ndim == 2 else lifted)

    for dof, value in zip(dir_dofs, dir_vals):
        K_bc[dof, :] = 0.0
        K_bc[:, dof] = 0.0
        K_bc[dof, dof] = 1.0
        # A single row index on a 2-D load fills that whole row.
        F_bc[dof] = value
    return K_bc, F_bc
34
+
35
+
36
def enforce_dirichlet_sparse(A: FluxSparseMatrix, F, dofs, vals):
    """
    Apply Dirichlet conditions to FluxSparseMatrix + load (CSR).

    Known values are first lifted to the right-hand side
    (``F -= K[:, dofs] @ vals``). Every Dirichlet row and column of the
    matrix is then zeroed, a unit diagonal is placed on the Dirichlet rows,
    and the load entries of those rows are overwritten with the prescribed
    values.

    Works in O(nnz) by masking the CSR data array. The previous per-DOF
    column assignment on a LIL matrix cost O(n_rows) per constrained DOF.
    Stored zeros are pruned from the result (values are unaffected).

    Returns
    -------
    (K_bc, F_bc) : (CSR matrix, np.ndarray)
    """
    # copy=True: the data array is modified in place below and must not
    # alias whatever A.to_csr() returned.
    K = A.to_csr().tocsr(copy=True)
    F_bc = np.asarray(F, dtype=float).copy()
    dir_dofs, dir_vals = _normalize_dirichlet(dofs, vals)

    lift = K[:, dir_dofs] @ dir_vals
    if F_bc.ndim == 2:
        F_bc = F_bc - lift[:, None]
    else:
        F_bc = F_bc - lift

    n_rows, n_cols = K.shape
    dir_row = np.zeros(n_rows, dtype=bool)
    dir_row[dir_dofs] = True
    dir_col = np.zeros(n_cols, dtype=bool)
    dir_col[dir_dofs] = True

    # CSR stores each entry's row implicitly through indptr; expand it.
    entry_rows = np.repeat(np.arange(n_rows), np.diff(K.indptr))
    K.data[dir_row[entry_rows] | dir_col[K.indices]] = 0.0
    K.eliminate_zeros()

    # Unit diagonal on the (deduplicated) Dirichlet rows.
    diag_idx = np.flatnonzero(dir_row)
    unit = type(K)(
        (np.ones(diag_idx.size, dtype=K.dtype), (diag_idx, diag_idx)),
        shape=K.shape,
    )
    K_bc = (K + unit).tocsr()

    for dof, value in zip(dir_dofs, dir_vals):
        F_bc[dof] = value
    return K_bc, F_bc
55
+
56
+
57
def condense_dirichlet_fluxsparse(A: FluxSparseMatrix, F, dofs, vals):
    """
    Condense Dirichlet DOFs for a FluxSparseMatrix.

    The free-free block is ``K[free][:, free]``. When there are Dirichlet
    DOFs, the load is corrected by ``K[free][:, dir] @ dir_vals`` (broadcast
    across columns for a 2-D load).

    Returns: (K_ff, F_free, free_dofs, dir_dofs, dir_vals)
    """
    K_csr = A.to_csr()
    dir_dofs, dir_vals = _normalize_dirichlet(dofs, vals)

    keep = np.ones(K_csr.shape[0], dtype=bool)
    keep[dir_dofs] = False
    free = np.flatnonzero(keep)

    K_free_rows = K_csr[free]
    K_ff = K_free_rows[:, free]
    F_free = np.asarray(F, dtype=float)[free]

    if dir_dofs.size == 0:
        return K_ff, F_free, free, dir_dofs, dir_vals

    lifted = K_free_rows[:, dir_dofs] @ dir_vals
    F_free = F_free - (lifted[:, None] if F_free.ndim == 2 else lifted)
    return K_ff, F_free, free, dir_dofs, dir_vals
77
+
78
+
79
def free_dofs(n_dofs: int, dir_dofs) -> np.ndarray:
    """
    Return free DOF indices given total DOFs and Dirichlet DOFs.

    The result is sorted ascending. Negative entries of ``dir_dofs`` follow
    NumPy indexing (``-1`` is the last DOF).
    """
    constrained = np.asarray(dir_dofs, dtype=int)
    is_free = np.ones(int(n_dofs), dtype=bool)
    is_free[constrained] = False
    return np.flatnonzero(is_free)
87
+
88
+
89
def condense_dirichlet_dense(K, F, dofs, vals):
    """
    Eliminate Dirichlet dofs for dense/CSR matrices and return condensed system.

    Accepts a dense array, a SciPy sparse matrix (anything with ``tocsr``),
    or an object exposing ``to_csr`` such as FluxSparseMatrix. Sparse input
    is sliced in CSR form and ``K_cc`` stays sparse; dense input gives a
    dense ``K_cc``. Previously sparse input failed in ``np.asarray``.

    ``K_cc = K[free, free]`` and ``F_c = F[free] - K[free, dir] @ dir_vals``
    (broadcast across columns for a 2-D load).

    Returns: (K_cc, F_c, free_dofs, dir_dofs, dir_vals)
    """
    if hasattr(K, "to_csr"):
        K = K.to_csr()
    is_sparse = hasattr(K, "tocsr")
    K_mat = K.tocsr() if is_sparse else np.asarray(K, dtype=float)
    F_np = np.asarray(F, dtype=float)
    n = K_mat.shape[0]

    dir_set, dir_vals = _normalize_dirichlet(dofs, vals)
    mask = np.ones(n, dtype=bool)
    mask[dir_set] = False
    free = np.nonzero(mask)[0]

    if is_sparse:
        K_free_rows = K_mat[free]
        K_ff = K_free_rows[:, free]
        K_fd = K_free_rows[:, dir_set]
    else:
        K_ff = K_mat[np.ix_(free, free)]
        K_fd = K_mat[np.ix_(free, dir_set)]

    lift = K_fd @ dir_vals
    F_f = F_np[free]
    F_f = F_f - (lift[:, None] if F_f.ndim == 2 else lift)
    return K_ff, F_f, free, dir_set, dir_vals
112
+
113
+
114
def expand_dirichlet_solution(u_free, free_dofs, dir_dofs, dir_vals, n_total):
    """
    Expand condensed solution back to full vector.

    Parameters
    ----------
    u_free : array-like, shape (n_free,) or (n_free, n_rhs)
        Solution on the free DOFs.
    free_dofs : array-like of int
        Indices of the free DOFs.
    dir_dofs, dir_vals
        Dirichlet DOFs and prescribed values, passed through
        ``_normalize_dirichlet`` (``None`` values mean zeros). For a 2-D
        ``u_free``, a 1-D ``dir_vals`` applies the same value to every
        column; a 2-D ``dir_vals`` of shape (n_dir, n_rhs) is used as-is.
    n_total : int
        Total number of DOFs.

    Returns
    -------
    np.ndarray of shape (n_total,) or (n_total, n_rhs)
    """
    dir_dofs, dir_vals = _normalize_dirichlet(dir_dofs, dir_vals)
    u_free_arr = np.asarray(u_free, dtype=float)
    if u_free_arr.ndim == 2:
        u = np.zeros((n_total, u_free_arr.shape[1]), dtype=float)
        u[free_dofs, :] = u_free_arr
        # A (n_dir,) vector would broadcast along the column axis (error, or
        # transposed values when n_dir == n_rhs); make it a column explicitly.
        u[dir_dofs, :] = dir_vals[:, None] if dir_vals.ndim == 1 else dir_vals
    else:
        u = np.zeros(n_total, dtype=float)
        u[free_dofs] = u_free_arr
        u[dir_dofs] = dir_vals
    return u
fluxfem/solver/history.py ADDED
@@ -0,0 +1,31 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, List, Optional
5
+
6
+
7
@dataclass
class NewtonIterRecord:
    """
    Diagnostics recorded for one Newton iteration.

    The first six fields are required (positionally in this order or by
    keyword). The linear-solve and timing fields default to ``None``, and
    ``nan_detected`` defaults to ``False``.
    """

    # Iteration index.
    iter: int

    # Residual norms (infinity norm, 2-norm, relative infinity norm).
    res_inf: float
    res_two: float
    rel_res_inf: float

    # Step factor and norm of the update; alpha is presumably a line-search
    # factor — confirm with the producer.
    alpha: float
    step_norm: float

    # Linear-solve diagnostics, when the producer reports them.
    lin_iters: Optional[int] = None
    lin_converged: Optional[bool] = None
    lin_residual: Optional[float] = None

    # Set by the producer, presumably when NaNs were encountered.
    nan_detected: bool = False

    # Timings; units not shown here (presumably seconds — confirm).
    assemble_time: Optional[float] = None
    linear_time: Optional[float] = None
21
+
22
+
23
@dataclass
class LoadStepResult:
    """
    Outcome of one load step.

    ``iter_history`` and ``meta`` are built with ``default_factory``, so every
    instance gets its own fresh list/dict.
    """

    # Load factor of this step.
    load_factor: float
    # Solver info object as handed over by the producer (any type).
    info: Any
    # Time spent on the step; units presumably seconds — confirm.
    solve_time: float
    # Solution for this step (type left to the producer).
    u: Any

    # Per-iteration Newton diagnostics.
    iter_history: List[NewtonIterRecord] = field(default_factory=list)
    # Presumably a textual description of a failure; None when not set.
    exception: Optional[str] = None
    # Free-form extra metadata.
    meta: dict[str, Any] = field(default_factory=dict)