fluxfem 0.1.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. fluxfem/__init__.py +343 -0
  2. fluxfem/core/__init__.py +316 -0
  3. fluxfem/core/assembly.py +788 -0
  4. fluxfem/core/basis.py +996 -0
  5. fluxfem/core/data.py +64 -0
  6. fluxfem/core/dtypes.py +4 -0
  7. fluxfem/core/forms.py +234 -0
  8. fluxfem/core/interp.py +55 -0
  9. fluxfem/core/solver.py +113 -0
  10. fluxfem/core/space.py +419 -0
  11. fluxfem/core/weakform.py +818 -0
  12. fluxfem/helpers_num.py +11 -0
  13. fluxfem/helpers_wf.py +42 -0
  14. fluxfem/mesh/__init__.py +29 -0
  15. fluxfem/mesh/base.py +244 -0
  16. fluxfem/mesh/hex.py +327 -0
  17. fluxfem/mesh/io.py +87 -0
  18. fluxfem/mesh/predicate.py +45 -0
  19. fluxfem/mesh/surface.py +257 -0
  20. fluxfem/mesh/tet.py +246 -0
  21. fluxfem/physics/__init__.py +53 -0
  22. fluxfem/physics/diffusion.py +18 -0
  23. fluxfem/physics/elasticity/__init__.py +39 -0
  24. fluxfem/physics/elasticity/hyperelastic.py +99 -0
  25. fluxfem/physics/elasticity/linear.py +58 -0
  26. fluxfem/physics/elasticity/materials.py +32 -0
  27. fluxfem/physics/elasticity/stress.py +46 -0
  28. fluxfem/physics/operators.py +109 -0
  29. fluxfem/physics/postprocess.py +113 -0
  30. fluxfem/solver/__init__.py +47 -0
  31. fluxfem/solver/bc.py +439 -0
  32. fluxfem/solver/cg.py +326 -0
  33. fluxfem/solver/dirichlet.py +126 -0
  34. fluxfem/solver/history.py +31 -0
  35. fluxfem/solver/newton.py +400 -0
  36. fluxfem/solver/result.py +62 -0
  37. fluxfem/solver/solve_runner.py +534 -0
  38. fluxfem/solver/solver.py +148 -0
  39. fluxfem/solver/sparse.py +188 -0
  40. fluxfem/tools/__init__.py +7 -0
  41. fluxfem/tools/jit.py +51 -0
  42. fluxfem/tools/timer.py +659 -0
  43. fluxfem/tools/visualizer.py +101 -0
  44. fluxfem-0.1.1a0.dist-info/METADATA +111 -0
  45. fluxfem-0.1.1a0.dist-info/RECORD +47 -0
  46. fluxfem-0.1.1a0.dist-info/WHEEL +4 -0
  47. fluxfem-0.1.1a0.dist-info/licenses/LICENSE +201 -0
fluxfem/solver/sparse.py ADDED
@@ -0,0 +1,188 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import numpy as np
6
+ import jax
7
+ import jax.numpy as jnp
8
+
9
+ try:
10
+ import scipy.sparse as sp
11
+ except Exception: # pragma: no cover
12
+ sp = None
13
+
14
+
15
@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class SparsityPattern:
    """
    Jacobian sparsity pattern (rows/cols) that is independent of the solution.

    Fields
    ------
    rows, cols
        COO row / column index of every stored entry.
    n_dofs
        Matrix dimension. Coerced to a Python ``int`` in ``__post_init__`` so
        that JAX treats it as static auxiliary data.
    idx
        Optional index array for dense scatter.
    diag_idx
        Optional positions in ``rows``/``cols`` whose entries lie on the diagonal.
    perm
        Optional permutation mapping COO data -> CSR data.
    indptr, indices
        Optional CSR row pointer / column indices.

    When flattened as a pytree, absent (``None``) optional arrays are replaced
    by empty int32 placeholders and restored to ``None`` on unflatten using
    presence flags kept in the auxiliary data.
    """

    rows: jnp.ndarray
    cols: jnp.ndarray
    n_dofs: int
    idx: jnp.ndarray | None = None
    diag_idx: jnp.ndarray | None = None
    perm: jnp.ndarray | None = None  # permutation mapping COO data -> CSR data
    indptr: jnp.ndarray | None = None  # CSR row pointer
    indices: jnp.ndarray | None = None  # CSR column indices

    # Optional array fields, in the order they follow rows/cols among the
    # pytree children. Not annotated, so dataclass does not make it a field.
    _OPTIONAL_FIELDS = ("idx", "diag_idx", "perm", "indptr", "indices")

    def __post_init__(self):
        # Ensure n_dofs is always a Python int so JAX treats it as a static aux value.
        object.__setattr__(self, "n_dofs", int(self.n_dofs))

    def tree_flatten(self):
        """
        Split the pattern into array children and static auxiliary data.

        Returns
        -------
        children : tuple
            ``(rows, cols, idx, diag_idx, perm, indptr, indices)``; every optional
            array that is ``None`` is replaced by an empty int32 array.
        aux : tuple
            ``(n_dofs, has_idx, has_diag_idx, has_perm, has_indptr, has_indices)``.
            JAX stores aux data in the treedef and requires it to be hashable,
            hence a tuple rather than a dict.
        """
        optional = [getattr(self, name) for name in self._OPTIONAL_FIELDS]
        children = (self.rows, self.cols) + tuple(
            jnp.array([], jnp.int32) if value is None else value for value in optional
        )
        aux = (self.n_dofs,) + tuple(value is not None for value in optional)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """
        Rebuild a pattern from the output of :meth:`tree_flatten`.

        Children whose presence flag is ``False`` become ``None`` again. The
        dict-shaped ``aux`` (keys ``"n_dofs"`` and ``"has_<field>"``) used by
        earlier versions of ``tree_flatten`` is still accepted.
        """
        if isinstance(aux, dict):
            n_dofs = aux["n_dofs"]
            present = [aux[f"has_{name}"] for name in cls._OPTIONAL_FIELDS]
        else:
            n_dofs, *present = aux
        rows, cols, *optional = children
        kwargs = {
            name: (value if has else None)
            for name, value, has in zip(cls._OPTIONAL_FIELDS, optional, present)
        }
        return cls(rows=rows, cols=cols, n_dofs=n_dofs, **kwargs)
73
+
74
+
75
@jax.tree_util.register_pytree_node_class
class FluxSparseMatrix:
    """
    Sparse matrix wrapper (COO) with a fixed pattern and mutable values.

    - ``pattern`` stores rows/cols/n_dofs (optionally ``idx`` for dense scatter).
    - ``data`` stores the numeric values for the current nonlinear iterate.

    ``rows``, ``cols`` and ``n_dofs`` are mirrored from ``pattern`` for
    convenience. Duplicate (row, col) entries are accumulated by summation in
    :meth:`to_dense`, :meth:`matvec` and :meth:`diag`.

    Registered as a JAX pytree whose children are ``(pattern, data)``.
    """

    def __init__(self, rows_or_pattern, cols=None, data=None, n_dofs: int | None = None):
        """
        Build a matrix from an existing pattern or from raw COO triplets.

        Two call forms are accepted:

        * ``FluxSparseMatrix(pattern, data)`` -- reuse a :class:`SparsityPattern`;
          the values may be given positionally (second argument) or as ``data=``.
        * ``FluxSparseMatrix(rows, cols, data, n_dofs=None)`` (legacy) -- build a
          new pattern (with ``diag_idx`` precomputed) from COO index arrays.
          When ``n_dofs`` is omitted it is inferred as
          ``max(rows.max(), cols.max()) + 1``, or 0 for an empty matrix.
        """
        if isinstance(rows_or_pattern, SparsityPattern):
            pattern = rows_or_pattern
            values = cols if data is None else data
        else:
            r_np = np.asarray(rows_or_pattern, dtype=np.int32)
            c_np = np.asarray(cols, dtype=np.int32)
            if n_dofs is None:
                # The matrix is square, so both row and column indices bound its
                # size; skip empty arrays, whose .max() would raise.
                n_dofs = max((int(a.max()) for a in (r_np, c_np) if a.size), default=-1) + 1
            diag_idx_np = np.nonzero(r_np == c_np)[0].astype(np.int32)
            pattern = SparsityPattern(
                rows=jnp.asarray(r_np),
                cols=jnp.asarray(c_np),
                n_dofs=int(n_dofs),
                idx=None,
                diag_idx=jnp.asarray(diag_idx_np),
            )
            values = data

        self.pattern = pattern
        self.rows = pattern.rows
        self.cols = pattern.cols
        self.n_dofs = int(pattern.n_dofs)
        self.data = jnp.asarray(values)

    @classmethod
    def from_bilinear(cls, coo_tuple):
        """Construct from assemble_bilinear_dense(..., sparse=True), i.e. ``(rows, cols, data, n_dofs)``."""
        rows, cols, data, n_dofs = coo_tuple
        return cls(rows, cols, data, n_dofs)

    @classmethod
    def from_linear(cls, coo_tuple):
        """
        Construct from assemble_linear_form(..., sparse=True) (matrix interpretation only).

        ``coo_tuple`` is ``(rows, data, n_dofs)``; every entry is placed in column 0
        of an ``n_dofs x n_dofs`` matrix.
        """
        rows, data, n_dofs = coo_tuple
        cols = jnp.zeros_like(rows)
        return cls(rows, cols, data, n_dofs)

    def with_data(self, data):
        """Return a new FluxSparseMatrix sharing the same pattern with updated data."""
        return FluxSparseMatrix(self.pattern, data)

    def to_coo(self):
        """Return the COO tuple ``(rows, cols, data, n_dofs)``."""
        return self.pattern.rows, self.pattern.cols, self.data, self.pattern.n_dofs

    def to_csr(self):
        """
        Convert to a host-side ``scipy.sparse.csr_matrix``.

        If the pattern carries a precomputed CSR layout (``indptr``, ``indices``
        and ``perm`` all set), the data is permuted into that layout directly.
        Otherwise scipy builds the CSR from the COO triplets (summing duplicates).

        Raises
        ------
        ImportError
            If scipy is not installed.
        """
        if sp is None:
            raise ImportError("scipy is required for to_csr()")
        shape = (self.pattern.n_dofs, self.pattern.n_dofs)
        if (
            self.pattern.indptr is not None
            and self.pattern.indices is not None
            and self.pattern.perm is not None
        ):
            indptr = np.array(self.pattern.indptr, dtype=np.int32, copy=True)
            indices = np.array(self.pattern.indices, dtype=np.int32, copy=True)
            data = np.array(self.data, copy=True)[np.asarray(self.pattern.perm, dtype=np.int32)]
            return sp.csr_matrix((data, indices, indptr), shape=shape)
        r = np.array(self.pattern.rows, dtype=np.int64, copy=True)
        c = np.array(self.pattern.cols, dtype=np.int64, copy=True)
        d = np.array(self.data, copy=True)
        return sp.csr_matrix((d, (r, c)), shape=shape)

    def to_dense(self):
        """Materialize as a dense ``(n_dofs, n_dofs)`` JAX array (small debug helper)."""
        dense = jnp.zeros((self.pattern.n_dofs, self.pattern.n_dofs), dtype=self.data.dtype)
        return dense.at[self.pattern.rows, self.pattern.cols].add(self.data)

    def to_bcoo(self):
        """Construct jax.experimental.sparse.BCOO (requires jax.experimental.sparse)."""
        try:
            from jax.experimental import sparse as jsparse  # type: ignore
        except Exception as exc:  # pragma: no cover
            raise ImportError("jax.experimental.sparse is required for to_bcoo()") from exc
        idx = jnp.stack([self.pattern.rows, self.pattern.cols], axis=-1)
        return jsparse.BCOO((self.data, idx), shape=(self.pattern.n_dofs, self.pattern.n_dofs))

    def matvec(self, x):
        """
        Compute ``y = A x`` in JAX (iterative solvers).

        ``x`` may be a vector of shape ``(n_dofs,)`` or a stack of vectors of
        shape ``(n_dofs, k, ...)``; trailing axes are carried through, so the
        result has shape ``(n_dofs,) + x.shape[1:]``.
        """
        xj = jnp.asarray(x)
        gathered = xj[self.pattern.cols]
        # Broadcast the 1-D values over any trailing (right-hand-side) axes of x.
        weights = jnp.reshape(self.data, self.data.shape + (1,) * (xj.ndim - 1))
        contrib = weights * gathered
        # Use scatter_add to avoid tracing a dynamic int(x.max()) in jnp.bincount,
        # which triggers concretization errors under jit/while_loop.
        out = jnp.zeros((self.pattern.n_dofs,) + xj.shape[1:], dtype=contrib.dtype)
        return out.at[self.pattern.rows].add(contrib)

    def diag(self):
        """Diagonal entries (duplicates summed) as a length-``n_dofs`` vector, for Jacobi preconditioning."""
        if self.pattern.diag_idx is not None:
            r = self.pattern.rows[self.pattern.diag_idx]
            d = self.data[self.pattern.diag_idx]
            return jax.ops.segment_sum(d, r, self.pattern.n_dofs)

        # Fallback for patterns without diag_idx (kept for backward compatibility).
        mask = self.pattern.rows == self.pattern.cols
        diag_contrib = jnp.where(mask, self.data, jnp.zeros_like(self.data))
        return jax.ops.segment_sum(diag_contrib, self.pattern.rows, self.pattern.n_dofs)

    def tree_flatten(self):
        """Children are ``(pattern, data)``; aux is ``None`` (JAX requires hashable aux data)."""
        return (self.pattern, self.data), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        """
        Rebuild from ``(pattern, data)`` without running ``__init__``.

        JAX may unflatten with placeholder leaves (``object()`` sentinels,
        ``None``, ``ShapeDtypeStruct``), so no array conversion is done here.
        """
        pattern, data = children
        obj = object.__new__(cls)
        obj.pattern = pattern
        obj.rows = getattr(pattern, "rows", None)
        obj.cols = getattr(pattern, "cols", None)
        obj.n_dofs = getattr(pattern, "n_dofs", None)
        obj.data = data
        return obj
fluxfem/tools/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+
2
+ from .timer import SectionTimer, NullTimer
3
+
4
# Public API of fluxfem.tools: the timers re-exported from .timer.
__all__ = ["SectionTimer", "NullTimer"]
fluxfem/tools/jit.py ADDED
@@ -0,0 +1,51 @@
1
+ import jax
2
+
3
+ from ..core.assembly import assemble_residual, assemble_jacobian
4
+ from ..core.space import FESpace
5
+
6
+
7
def make_jitted_residual(space: FESpace, res_form, params, *, sparse: bool = False):
    """
    Build a jit-compiled residual assembler ``u -> R(u)``.

    ``space``, ``res_form``, ``params`` and ``sparse`` are captured by the
    closure instead of being passed as traced arguments, so ``u`` is the only
    input of the compiled function. Captured values are read when JAX traces
    the function, not on every call.

    Returns
    -------
    callable
        ``jax.jit``-wrapped function evaluating
        ``assemble_residual(space, res_form, u, params, sparse=sparse)``.
    """

    def residual(u):
        return assemble_residual(space, res_form, u, params, sparse=sparse)

    return jax.jit(residual)
20
+
21
+
22
def make_jitted_jacobian(
    space: FESpace,
    res_form,
    params,
    *,
    sparse: bool = False,
    return_flux_matrix: bool = False,
):
    """
    Build a jit-compiled Jacobian assembler ``u -> J(u)``.

    ``space``, ``res_form``, ``params`` and the two flags are captured by the
    closure, leaving ``u`` as the only input of the compiled function.
    Captured values are read when JAX traces the function, not on every call.

    Returns
    -------
    callable
        ``jax.jit``-wrapped function evaluating
        ``assemble_jacobian(space, res_form, u, params, sparse=sparse,
        return_flux_matrix=return_flux_matrix)``.
    """

    def jacobian(u):
        return assemble_jacobian(
            space,
            res_form,
            u,
            params,
            sparse=sparse,
            return_flux_matrix=return_flux_matrix,
        )

    return jax.jit(jacobian)
49
+
50
+
51
# Public API of fluxfem.tools.jit.
__all__ = [
    "make_jitted_residual",
    "make_jitted_jacobian",
]