fluxfem 0.1.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +343 -0
- fluxfem/core/__init__.py +316 -0
- fluxfem/core/assembly.py +788 -0
- fluxfem/core/basis.py +996 -0
- fluxfem/core/data.py +64 -0
- fluxfem/core/dtypes.py +4 -0
- fluxfem/core/forms.py +234 -0
- fluxfem/core/interp.py +55 -0
- fluxfem/core/solver.py +113 -0
- fluxfem/core/space.py +419 -0
- fluxfem/core/weakform.py +818 -0
- fluxfem/helpers_num.py +11 -0
- fluxfem/helpers_wf.py +42 -0
- fluxfem/mesh/__init__.py +29 -0
- fluxfem/mesh/base.py +244 -0
- fluxfem/mesh/hex.py +327 -0
- fluxfem/mesh/io.py +87 -0
- fluxfem/mesh/predicate.py +45 -0
- fluxfem/mesh/surface.py +257 -0
- fluxfem/mesh/tet.py +246 -0
- fluxfem/physics/__init__.py +53 -0
- fluxfem/physics/diffusion.py +18 -0
- fluxfem/physics/elasticity/__init__.py +39 -0
- fluxfem/physics/elasticity/hyperelastic.py +99 -0
- fluxfem/physics/elasticity/linear.py +58 -0
- fluxfem/physics/elasticity/materials.py +32 -0
- fluxfem/physics/elasticity/stress.py +46 -0
- fluxfem/physics/operators.py +109 -0
- fluxfem/physics/postprocess.py +113 -0
- fluxfem/solver/__init__.py +47 -0
- fluxfem/solver/bc.py +439 -0
- fluxfem/solver/cg.py +326 -0
- fluxfem/solver/dirichlet.py +126 -0
- fluxfem/solver/history.py +31 -0
- fluxfem/solver/newton.py +400 -0
- fluxfem/solver/result.py +62 -0
- fluxfem/solver/solve_runner.py +534 -0
- fluxfem/solver/solver.py +148 -0
- fluxfem/solver/sparse.py +188 -0
- fluxfem/tools/__init__.py +7 -0
- fluxfem/tools/jit.py +51 -0
- fluxfem/tools/timer.py +659 -0
- fluxfem/tools/visualizer.py +101 -0
- fluxfem-0.1.1a0.dist-info/METADATA +111 -0
- fluxfem-0.1.1a0.dist-info/RECORD +47 -0
- fluxfem-0.1.1a0.dist-info/WHEEL +4 -0
- fluxfem-0.1.1a0.dist-info/licenses/LICENSE +201 -0
fluxfem/solver/sparse.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import jax
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
import scipy.sparse as sp
|
|
11
|
+
except Exception: # pragma: no cover
|
|
12
|
+
sp = None
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class SparsityPattern:
    """
    Jacobian sparsity pattern (rows/cols) that is independent of the solution.

    Registered as a JAX pytree: the index arrays are leaves, while ``n_dofs``
    and flags recording which optional arrays are present form the static
    auxiliary data.
    """

    rows: jnp.ndarray  # COO row index of each stored entry
    cols: jnp.ndarray  # COO column index of each stored entry
    n_dofs: int  # matrix dimension; coerced to a Python int in __post_init__
    idx: jnp.ndarray | None = None  # optional index array for dense scatter
    diag_idx: jnp.ndarray | None = None  # optional positions of entries with rows == cols
    perm: jnp.ndarray | None = None  # permutation mapping COO data -> CSR data
    indptr: jnp.ndarray | None = None  # CSR row pointer
    indices: jnp.ndarray | None = None  # CSR column indices

    def __post_init__(self):
        # Ensure n_dofs is always a Python int so JAX treats it as a static aux value.
        object.__setattr__(self, "n_dofs", int(self.n_dofs))

    def tree_flatten(self):
        """Split into array children and static aux data.

        Absent optional arrays are emitted as ``None``, which JAX treats as an
        empty subtree. This avoids allocating placeholder device arrays on
        every flatten (i.e. on every jitted call). It also keeps such
        placeholders from becoming leaves that transforms like ``vmap``
        would try to map over.
        """
        children = (
            self.rows,
            self.cols,
            self.idx,
            self.diag_idx,
            self.perm,
            self.indptr,
            self.indices,
        )
        aux = {
            "n_dofs": self.n_dofs,
            "has_idx": self.idx is not None,
            "has_diag_idx": self.diag_idx is not None,
            "has_perm": self.perm is not None,
            "has_indptr": self.indptr is not None,
            "has_indices": self.indices is not None,
        }
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild a pattern from ``tree_flatten`` output.

        The ``has_*`` flags stay authoritative, so children that carry
        placeholder values for absent fields still unflatten to ``None``.
        """
        rows, cols, idx, diag_idx, perm, indptr, indices = children
        return cls(
            rows=rows,
            cols=cols,
            n_dofs=aux["n_dofs"],
            idx=idx if aux["has_idx"] else None,
            diag_idx=diag_idx if aux["has_diag_idx"] else None,
            perm=perm if aux["has_perm"] else None,
            indptr=indptr if aux["has_indptr"] else None,
            indices=indices if aux["has_indices"] else None,
        )
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@jax.tree_util.register_pytree_node_class
class FluxSparseMatrix:
    """
    Sparse matrix wrapper (COO) with a fixed pattern and mutable values.

    - pattern stores rows/cols/n_dofs (optionally idx for dense scatter,
      diag_idx for diagonal extraction and perm/indptr/indices for CSR export)
    - data stores the numeric values for the current nonlinear iterate

    Accepted constructor signatures:

    - ``FluxSparseMatrix(pattern, data)`` where ``pattern`` is a
      ``SparsityPattern``; ``data`` may also be given as the second positional
      argument.
    - Legacy ``FluxSparseMatrix(rows, cols, data, n_dofs=None)``: a pattern is
      built from the COO indices. When ``n_dofs`` is omitted, the (square)
      size is inferred as ``max(max(rows), max(cols)) + 1``, or 0 for an empty
      pattern.

    Instances are JAX pytrees whose children are ``(pattern, data)``.
    """

    def __init__(self, rows_or_pattern, cols=None, data=None, n_dofs: int | None = None):
        if isinstance(rows_or_pattern, SparsityPattern):
            # New signature: FluxSparseMatrix(pattern, data). Values passed
            # positionally land in the ``cols`` slot.
            pattern = rows_or_pattern
            values = cols if data is None else data
        else:
            # Legacy signature: FluxSparseMatrix(rows, cols, data, n_dofs)
            pattern = self._pattern_from_coo(rows_or_pattern, cols, n_dofs)
            values = data
        self._assign(pattern, jnp.asarray(values))

    @staticmethod
    def _pattern_from_coo(rows, cols, n_dofs):
        """Build a SparsityPattern from raw COO row/column indices."""
        r_np = np.asarray(rows, dtype=np.int32)
        c_np = np.asarray(cols, dtype=np.int32)
        if n_dofs is None:
            # The matrix is square, so its size must cover the largest row
            # index as well as the largest column index. ``initial=-1`` makes
            # an empty pattern give n_dofs == 0 instead of raising.
            n_dofs = max(int(r_np.max(initial=-1)), int(c_np.max(initial=-1))) + 1
        # Positions (into the COO arrays) of entries with row == col; used by diag().
        diag_idx_np = np.nonzero(r_np == c_np)[0].astype(np.int32)
        return SparsityPattern(
            rows=jnp.asarray(r_np),
            cols=jnp.asarray(c_np),
            n_dofs=int(n_dofs),
            idx=None,
            diag_idx=jnp.asarray(diag_idx_np),
        )

    def _assign(self, pattern, data):
        """Populate the instance attributes from a pattern and its values."""
        self.pattern = pattern
        self.rows = pattern.rows
        self.cols = pattern.cols
        self.n_dofs = int(pattern.n_dofs)
        self.data = data

    @classmethod
    def from_bilinear(cls, coo_tuple):
        """Construct from assemble_bilinear_dense(..., sparse=True)."""
        rows, cols, data, n_dofs = coo_tuple
        return cls(rows, cols, data, n_dofs)

    @classmethod
    def from_linear(cls, coo_tuple):
        """Construct from assemble_linear_form(..., sparse=True) (matrix interpretation only).

        Every entry is placed in column 0.
        """
        rows, data, n_dofs = coo_tuple
        cols = jnp.zeros_like(rows)
        return cls(rows, cols, data, n_dofs)

    def with_data(self, data):
        """Return a new FluxSparseMatrix sharing the same pattern with updated data."""
        return FluxSparseMatrix(self.pattern, data)

    def to_coo(self):
        """Return the ``(rows, cols, data, n_dofs)`` COO tuple."""
        return self.pattern.rows, self.pattern.cols, self.data, self.pattern.n_dofs

    def to_csr(self):
        """Return a ``scipy.sparse.csr_matrix`` copy of the matrix.

        Uses the precomputed CSR arrays when the pattern has indptr, indices
        and perm. Otherwise the matrix is built from COO, and scipy sums any
        duplicate entries.

        Raises:
            ImportError: if scipy is not available.
        """
        if sp is None:
            raise ImportError("scipy is required for to_csr()")
        if (
            self.pattern.indptr is not None
            and self.pattern.indices is not None
            and self.pattern.perm is not None
        ):
            indptr = np.array(self.pattern.indptr, dtype=np.int32, copy=True)
            indices = np.array(self.pattern.indices, dtype=np.int32, copy=True)
            data = np.array(self.data, copy=True)[np.asarray(self.pattern.perm, dtype=np.int32)]
            return sp.csr_matrix((data, indices, indptr), shape=(self.pattern.n_dofs, self.pattern.n_dofs))
        r = np.array(self.pattern.rows, dtype=np.int64, copy=True)
        c = np.array(self.pattern.cols, dtype=np.int64, copy=True)
        d = np.array(self.data, copy=True)
        return sp.csr_matrix((d, (r, c)), shape=(self.pattern.n_dofs, self.pattern.n_dofs))

    def to_dense(self):
        """Return a dense ``(n_dofs, n_dofs)`` jax array (small debug helper).

        Duplicate COO entries are accumulated.
        """
        dense = jnp.zeros((self.pattern.n_dofs, self.pattern.n_dofs), dtype=self.data.dtype)
        dense = dense.at[self.pattern.rows, self.pattern.cols].add(self.data)
        return dense

    def to_bcoo(self):
        """Construct jax.experimental.sparse.BCOO (requires jax.experimental.sparse)."""
        try:
            from jax.experimental import sparse as jsparse  # type: ignore
        except Exception as exc:  # pragma: no cover
            raise ImportError("jax.experimental.sparse is required for to_bcoo()") from exc
        idx = jnp.stack([self.pattern.rows, self.pattern.cols], axis=-1)
        return jsparse.BCOO((self.data, idx), shape=(self.pattern.n_dofs, self.pattern.n_dofs))

    def matvec(self, x):
        """Compute y = A x in JAX (iterative solvers)."""
        xj = jnp.asarray(x)
        contrib = self.data * xj[self.pattern.cols]
        # Use scatter_add to avoid tracing a dynamic int(x.max()) in jnp.bincount,
        # which triggers concretization errors under jit/while_loop.
        out = jnp.zeros(self.pattern.n_dofs, dtype=contrib.dtype)
        return out.at[self.pattern.rows].add(contrib)

    def diag(self):
        """Diagonal entries aggregated for Jacobi preconditioning.

        Returns a length-``n_dofs`` vector; duplicate (i, i) entries are summed.
        """
        if self.pattern.diag_idx is not None:
            r = self.pattern.rows[self.pattern.diag_idx]
            d = self.data[self.pattern.diag_idx]
            return jax.ops.segment_sum(d, r, self.pattern.n_dofs)

        # Fallback for patterns without diag_idx (kept for backward compatibility).
        mask = self.pattern.rows == self.pattern.cols
        diag_contrib = jnp.where(mask, self.data, jnp.zeros_like(self.data))
        return jax.ops.segment_sum(diag_contrib, self.pattern.rows, self.pattern.n_dofs)

    def tree_flatten(self):
        """Pytree children are the pattern (itself a pytree) and the values."""
        return (self.pattern, self.data), {}

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild from ``tree_flatten`` output without calling ``__init__``.

        JAX may unflatten with placeholder leaves, e.g. the ``object()``
        sentinels used while ``vmap`` resolves ``in_axes``. ``__init__``
        would reject these in ``jnp.asarray``.
        """
        pattern, data = children
        obj = object.__new__(cls)
        obj._assign(pattern, data)
        return obj
|
fluxfem/tools/jit.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
|
|
3
|
+
from ..core.assembly import assemble_residual, assemble_jacobian
|
|
4
|
+
from ..core.space import FESpace
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def make_jitted_residual(space: FESpace, res_form, params, *, sparse: bool = False):
    """
    Build a jit-compiled residual assembler ``u -> R(u)``.

    ``space``, ``res_form``, ``params`` and ``sparse`` are captured by the
    closure rather than passed as traced arguments. The returned function
    takes only the solution vector ``u`` and forwards to ``assemble_residual``.
    """

    def residual(u):
        return assemble_residual(space, res_form, u, params, sparse=sparse)

    return jax.jit(residual)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def make_jitted_jacobian(
    space: FESpace,
    res_form,
    params,
    *,
    sparse: bool = False,
    return_flux_matrix: bool = False,
):
    """
    Build a jit-compiled Jacobian assembler ``u -> J(u)``.

    ``space``, ``res_form``, ``params`` and the two flags are captured by the
    closure rather than passed as traced arguments. The returned function
    takes only the solution vector ``u`` and forwards to ``assemble_jacobian``.
    """
    assemble_options = {"sparse": sparse, "return_flux_matrix": return_flux_matrix}

    def jacobian(u):
        return assemble_jacobian(space, res_form, u, params, **assemble_options)

    return jax.jit(jacobian)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# Public factory functions exported by this module.
__all__ = [
    "make_jitted_residual",
    "make_jitted_jacobian",
]
|