fluxfem 0.1.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. fluxfem/__init__.py +343 -0
  2. fluxfem/core/__init__.py +316 -0
  3. fluxfem/core/assembly.py +788 -0
  4. fluxfem/core/basis.py +996 -0
  5. fluxfem/core/data.py +64 -0
  6. fluxfem/core/dtypes.py +4 -0
  7. fluxfem/core/forms.py +234 -0
  8. fluxfem/core/interp.py +55 -0
  9. fluxfem/core/solver.py +113 -0
  10. fluxfem/core/space.py +419 -0
  11. fluxfem/core/weakform.py +818 -0
  12. fluxfem/helpers_num.py +11 -0
  13. fluxfem/helpers_wf.py +42 -0
  14. fluxfem/mesh/__init__.py +29 -0
  15. fluxfem/mesh/base.py +244 -0
  16. fluxfem/mesh/hex.py +327 -0
  17. fluxfem/mesh/io.py +87 -0
  18. fluxfem/mesh/predicate.py +45 -0
  19. fluxfem/mesh/surface.py +257 -0
  20. fluxfem/mesh/tet.py +246 -0
  21. fluxfem/physics/__init__.py +53 -0
  22. fluxfem/physics/diffusion.py +18 -0
  23. fluxfem/physics/elasticity/__init__.py +39 -0
  24. fluxfem/physics/elasticity/hyperelastic.py +99 -0
  25. fluxfem/physics/elasticity/linear.py +58 -0
  26. fluxfem/physics/elasticity/materials.py +32 -0
  27. fluxfem/physics/elasticity/stress.py +46 -0
  28. fluxfem/physics/operators.py +109 -0
  29. fluxfem/physics/postprocess.py +113 -0
  30. fluxfem/solver/__init__.py +47 -0
  31. fluxfem/solver/bc.py +439 -0
  32. fluxfem/solver/cg.py +326 -0
  33. fluxfem/solver/dirichlet.py +126 -0
  34. fluxfem/solver/history.py +31 -0
  35. fluxfem/solver/newton.py +400 -0
  36. fluxfem/solver/result.py +62 -0
  37. fluxfem/solver/solve_runner.py +534 -0
  38. fluxfem/solver/solver.py +148 -0
  39. fluxfem/solver/sparse.py +188 -0
  40. fluxfem/tools/__init__.py +7 -0
  41. fluxfem/tools/jit.py +51 -0
  42. fluxfem/tools/timer.py +659 -0
  43. fluxfem/tools/visualizer.py +101 -0
  44. fluxfem-0.1.1a0.dist-info/METADATA +111 -0
  45. fluxfem-0.1.1a0.dist-info/RECORD +47 -0
  46. fluxfem-0.1.1a0.dist-info/WHEEL +4 -0
  47. fluxfem-0.1.1a0.dist-info/licenses/LICENSE +201 -0
fluxfem/core/data.py ADDED
@@ -0,0 +1,64 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+ import jax.numpy as jnp
7
+
8
+
9
@dataclass(eq=False)
class MeshData:
    """Plain mesh container whose arrays are stored as JAX arrays.

    Built from any mesh-like object via :meth:`from_mesh`; the optional tag
    arrays stay ``None`` when the source mesh has none.
    """

    coords: jnp.ndarray
    conn: jnp.ndarray
    cell_tags: jnp.ndarray | None = None
    node_tags: jnp.ndarray | None = None

    @classmethod
    def from_mesh(cls, mesh: Any) -> "MeshData":
        """Snapshot ``mesh.coords``, ``mesh.conn`` and its optional tag arrays."""

        def optional_array(values):
            # Keep "no tags" as None instead of wrapping it in an array.
            if values is None:
                return None
            return jnp.asarray(values)

        return cls(
            coords=jnp.asarray(mesh.coords),
            conn=jnp.asarray(mesh.conn),
            cell_tags=optional_array(mesh.cell_tags),
            node_tags=optional_array(mesh.node_tags),
        )
25
+
26
+
27
@dataclass(eq=False)
class BasisData:
    """Quadrature rule plus basis metadata, captured for reproducible assembly.

    ``kind`` records the class name of the basis the data was taken from.
    """

    quad_points: jnp.ndarray
    quad_weights: jnp.ndarray
    dofs_per_node: int
    kind: str

    @classmethod
    def from_basis(cls, basis: Any) -> "BasisData":
        """Copy the quadrature arrays and dof count out of ``basis``."""
        points = jnp.asarray(basis.quad_points)
        weights = jnp.asarray(basis.quad_weights)
        n_dofs_per_node = int(basis.dofs_per_node)
        basis_kind = type(basis).__name__
        return cls(points, weights, n_dofs_per_node, basis_kind)
43
+
44
+
45
@dataclass(eq=False)
class SpaceData:
    """Snapshot of the arrays and sizes of a function space used in assembly.

    Mesh and basis are converted with :class:`MeshData` and :class:`BasisData`;
    the integer sizes are stored as plain Python ints.
    """

    mesh: MeshData
    basis: BasisData
    elem_dofs: jnp.ndarray
    value_dim: int
    n_dofs: int
    n_ldofs: int

    @classmethod
    def from_space(cls, space: Any) -> "SpaceData":
        """Capture ``space``'s mesh, basis, element dof table and dof counts."""
        mesh_snapshot = MeshData.from_mesh(space.mesh)
        basis_snapshot = BasisData.from_basis(space.basis)
        dof_table = jnp.asarray(space.elem_dofs)
        return cls(
            mesh=mesh_snapshot,
            basis=basis_snapshot,
            elem_dofs=dof_table,
            value_dim=int(space.value_dim),
            n_dofs=int(space.n_dofs),
            n_ldofs=int(space.n_ldofs),
        )
fluxfem/core/dtypes.py ADDED
@@ -0,0 +1,4 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+
4
# Default floating-point dtype, fixed once when this module is imported:
# float64 if JAX's x64 mode is enabled at that moment, float32 otherwise.
if jax.config.read("jax_enable_x64"):
    DEFAULT_DTYPE = jnp.float64
else:
    DEFAULT_DTYPE = jnp.float32
fluxfem/core/forms.py ADDED
@@ -0,0 +1,234 @@
1
+ from __future__ import annotations
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from dataclasses import dataclass
6
+
7
+ from .basis import Basis3D
8
+
9
+ # FormContext/ScalarFormField/VectorFormField were originally dataclasses with the default __eq__. JAX ended up
10
+ # calling that __eq__ during the vmap over residuals, which compared array fields element-wise and then coerced the
11
+ # result to a bool, triggering "The truth value of an array with more than one element is ambiguous." Declaring
12
+ # them with eq=False (also done for ElementVector, for consistency) removes the autogenerated __eq__, so vmap no
13
+ # longer evaluates array equality and residual/Jacobian assembly succeeds.
14
+
15
+
16
@dataclass(eq=False)
class ElementVector:
    """Vector-valued element wrapper modelled on scikit-fem's ``ElementVector``.

    ``dim`` is the number of dofs attached to every node (e.g. 3 for a
    displacement field).
    """

    dim: int

    def dof_map(self, conn: jnp.ndarray) -> jnp.ndarray:
        """Expand node connectivity into interleaved vector dof indices.

        Node ``a`` owns dofs ``a*dim, a*dim + 1, ..., a*dim + dim - 1``; a
        ``(n_elems, n_nodes_per_elem)`` table becomes
        ``(n_elems, n_nodes_per_elem * dim)``.
        """
        first_dof = conn[..., None] * self.dim  # dof of component 0 per node
        components = jnp.arange(self.dim, dtype=conn.dtype)
        per_node = first_dof + components  # (n_elems, n_nodes, dim)
        n_elems = conn.shape[0]
        return per_node.reshape(n_elems, -1)
33
+
34
+
35
@jax.tree_util.register_pytree_node_class
@dataclass(eq=False)
class ScalarFormField:
    """Scalar FE field evaluated on one element.

    ``gradN`` and ``detJ`` are computed lazily by
    ``basis.spatial_grads_and_detJ(elem_coords)`` on first access and cached
    on the instance; precomputed values may be supplied as ``_gradN`` /
    ``_detJ``. Registered as a JAX pytree: the arrays are children and the
    basis is carried as (hashable) auxiliary data.
    """

    N: jnp.ndarray  # (n_q, n_nodes)
    elem_coords: jnp.ndarray  # (n_nodes, 3)
    basis: Basis3D
    _gradN: jnp.ndarray | None = None
    _detJ: jnp.ndarray | None = None

    def _fill_geometry(self) -> None:
        """Evaluate the basis geometry once and cache whichever value is missing."""
        gradN, detJ = self.basis.spatial_grads_and_detJ(self.elem_coords)
        # Only fill gaps, so a caller-supplied _gradN/_detJ is never discarded.
        if self._gradN is None:
            self._gradN = gradN
        if self._detJ is None:
            self._detJ = detJ

    @property
    def gradN(self):
        """Spatial shape-function gradients (contracted as ``qaj`` in ``grad``)."""
        if self._gradN is None:
            self._fill_geometry()
        return self._gradN

    @property
    def detJ(self):
        """Jacobian determinants as returned by the basis."""
        if self._detJ is None:
            self._fill_geometry()
        return self._detJ

    def eval(self, u_elem: jnp.ndarray) -> jnp.ndarray:
        """Interpolate nodal values ``u_elem`` (n_nodes,) to quadrature points -> (n_q,)."""
        return jnp.einsum("qa,a->q", self.N, u_elem)

    def grad(self, u_elem: jnp.ndarray) -> jnp.ndarray:
        """Spatial gradient of the field at quadrature points -> (n_q, 3)."""
        return jnp.einsum("qaj,a->qj", self.gradN, u_elem)

    def tree_flatten(self):
        children = (self.N, self.elem_coords, self._gradN, self._detJ)
        # JAX requires pytree aux_data to be hashable: use a tuple, not a dict.
        aux = (self.basis,)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        (basis,) = aux
        N, elem_coords, gradN, detJ = children
        return cls(N, elem_coords, basis, gradN, detJ)
74
+
75
+
76
@jax.tree_util.register_pytree_node_class
@dataclass(eq=False)
class VectorFormField:
    """Vector-valued FE field evaluated on one element.

    Element dof vectors are read node-major: the ``value_dim`` components of
    each node are consecutive (``u_elem.reshape(-1, value_dim)``). ``gradN``
    and ``detJ`` are computed lazily by
    ``basis.spatial_grads_and_detJ(elem_coords)`` and cached on the instance.
    Registered as a JAX pytree; ``basis`` and ``value_dim`` are (hashable)
    auxiliary data so ``value_dim`` stays a static Python int.
    """

    N: jnp.ndarray  # indexed (quadrature point, node)
    elem_coords: jnp.ndarray
    basis: Basis3D
    value_dim: int  # Python int (static)
    _gradN: jnp.ndarray | None = None
    _detJ: jnp.ndarray | None = None

    def _fill_geometry(self) -> None:
        """Evaluate the basis geometry once and cache whichever value is missing."""
        gradN, detJ = self.basis.spatial_grads_and_detJ(self.elem_coords)
        # Only fill gaps, so a caller-supplied _gradN/_detJ is never discarded.
        if self._gradN is None:
            self._gradN = gradN
        if self._detJ is None:
            self._detJ = detJ

    @property
    def gradN(self):
        """Spatial shape-function gradients (contracted as ``qaj`` in ``grad``)."""
        if self._gradN is None:
            self._fill_geometry()
        return self._gradN

    @property
    def detJ(self):
        """Jacobian determinants as returned by the basis."""
        if self._detJ is None:
            self._fill_geometry()
        return self._detJ

    def eval(self, u_elem: jnp.ndarray) -> jnp.ndarray:
        """Values at quadrature points; ``u_elem`` is (value_dim*n_nodes,) -> (n_q, value_dim)."""
        u_nodes = u_elem.reshape((-1, self.value_dim))  # (n_nodes, vd); vd is a Python int
        return jnp.einsum("qa,ai->qi", self.N, u_nodes)

    def grad(self, u_elem: jnp.ndarray) -> jnp.ndarray:
        """Spatial gradients at quadrature points -> (n_q, value_dim, 3)."""
        u_nodes = u_elem.reshape((-1, self.value_dim))
        return jnp.einsum("qaj,ai->qij", self.gradN, u_nodes)

    def tree_flatten(self):
        children = (self.N, self.elem_coords, self._gradN, self._detJ)
        # JAX requires pytree aux_data to be hashable: use a tuple, not a dict.
        aux = (self.basis, int(self.value_dim))
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        basis, value_dim = aux
        N, elem_coords, gradN, detJ = children
        return cls(N, elem_coords, basis, value_dim, gradN, detJ)
120
+
121
+
122
# Union of the two element-field flavours (scalar and vector) accepted by the
# form helpers and context classes of this module.
FormFieldLike = (
    ScalarFormField
    | VectorFormField
)
123
+
124
+
125
def vector_load_form(field: FormFieldLike, load_vec: jnp.ndarray) -> jnp.ndarray:
    """
    Build vector linear-form values from a form field and a load vector.

    ``load_vec`` is either one vector ``(dim,)`` applied at every quadrature
    point, or one vector per quadrature point ``(n_q, dim)``. The result has
    shape ``(n_q, n_nodes * dim)``: entry ``(q, a*dim + i)`` is
    ``field.N[q, a] * load[q, i]``.

    Raises ValueError when ``load_vec`` has another rank or a row count that
    is neither 1 nor ``n_q``.
    """
    shape_error = "load_vec must be shape (dim,) or (n_q, dim)"
    loads = jnp.asarray(load_vec)
    if loads.ndim not in (1, 2):
        raise ValueError(shape_error)
    if loads.ndim == 1:
        loads = jnp.expand_dims(loads, 0)

    n_q = field.N.shape[0]
    n_rows = loads.shape[0]
    if n_rows == 1:
        # A single load row is shared by all quadrature points.
        loads = jnp.broadcast_to(loads, (n_q, loads.shape[1]))
    elif n_rows != n_q:
        raise ValueError(shape_error)

    weighted = jnp.expand_dims(field.N, -1) * loads[:, None, :]  # (n_q, n_nodes, dim)
    return weighted.reshape(weighted.shape[0], -1)
140
+
141
+
142
@jax.tree_util.register_pytree_node_class
@dataclass(eq=False)
class FormContext:
    """Bundle test/trial fields and quadrature data for element assembly.

    ``u`` and ``v`` are aliases for ``trial`` and ``test``. All five fields
    are pytree children; there is no static auxiliary data.
    """

    test: FormFieldLike
    trial: FormFieldLike
    x_q: jnp.ndarray  # (n_q, 3)
    w: jnp.ndarray  # (n_q,)
    elem_id: jnp.ndarray | int = 0

    @property
    def u(self) -> FormFieldLike:
        """Trial field."""
        return self.trial

    @property
    def v(self) -> FormFieldLike:
        """Test field."""
        return self.test

    def tree_flatten(self):
        children = (self.test, self.trial, self.x_q, self.w, self.elem_id)
        # No static data. JAX requires aux_data to be hashable, which rules out
        # an (unhashable) empty dict; None is the conventional placeholder.
        return children, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        test, trial, x_q, w, elem_id = children
        return cls(test, trial, x_q, w, elem_id)
186
+
187
+
188
@jax.tree_util.register_pytree_node_class
@dataclass(eq=False)
class FieldPair:
    """Named test/trial/unknown grouping for mixed formulations.

    Registered as a JAX pytree, like the other form containers in this module,
    so that it can sit inside a traced pytree (e.g. a mixed form context's
    ``fields`` dict) instead of being treated as an opaque leaf. All three
    entries are children; ``unknown`` may be ``None``.
    """

    test: FormFieldLike
    trial: FormFieldLike
    unknown: FormFieldLike | None = None

    def tree_flatten(self):
        children = (self.test, self.trial, self.unknown)
        return children, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        test, trial, unknown = children
        return cls(test, trial, unknown)
194
+
195
+
196
@jax.tree_util.register_pytree_node_class
@dataclass(eq=False)
class MixedFormContext:
    """FormContext for mixed formulations keyed by field name.

    Every attribute is a pytree child (``None`` entries contribute no leaves);
    there is no static auxiliary data.
    """

    fields: dict[str, FieldPair]
    x_q: jnp.ndarray  # (n_q, 3)
    w: jnp.ndarray  # (n_q,)
    elem_id: jnp.ndarray | int = 0
    unknown: FormFieldLike | None = None
    trial_fields: dict[str, FormFieldLike] | None = None
    test_fields: dict[str, FormFieldLike] | None = None
    unknown_fields: dict[str, FormFieldLike] | None = None

    def tree_flatten(self):
        children = (
            self.fields,
            self.x_q,
            self.w,
            self.elem_id,
            self.unknown,
            self.trial_fields,
            self.test_fields,
            self.unknown_fields,
        )
        # No static data. JAX requires aux_data to be hashable, which rules out
        # an (unhashable) empty dict; None is the conventional placeholder.
        return children, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (
            fields,
            x_q,
            w,
            elem_id,
            unknown,
            trial_fields,
            test_fields,
            unknown_fields,
        ) = children
        return cls(fields, x_q, w, elem_id, unknown, trial_fields, test_fields, unknown_fields)
fluxfem/core/interp.py ADDED
@@ -0,0 +1,55 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+ from .basis import HexTriLinearBasis
6
+ from .space import FESpace
7
+
8
+
9
def eval_shape_functions_hex8(xi_eta_zeta: np.ndarray) -> np.ndarray:
    """
    Evaluate the eight trilinear Hex8 shape functions at local points.

    Each point is (xi, eta, zeta) in [-1, 1]^3; a single point may be passed
    as a 1-D sequence. Node ``k`` has local corner signs ``(sx, sy, sz)`` and
    ``N_k = 1/8 (1 + sx*xi)(1 + sy*eta)(1 + sz*zeta)``.
    Returns N with shape (n_q, 8).
    """
    local = np.atleast_2d(np.asarray(xi_eta_zeta, dtype=float))
    xi, eta, zeta = local[:, 0], local[:, 1], local[:, 2]
    # Corner sign table in the element's node order.
    corner_signs = (
        (-1, -1, -1),
        (+1, -1, -1),
        (+1, +1, -1),
        (-1, +1, -1),
        (-1, -1, +1),
        (+1, -1, +1),
        (+1, +1, +1),
        (-1, +1, +1),
    )
    columns = [
        0.125 * (1 + sx * xi) * (1 + sy * eta) * (1 + sz * zeta)
        for sx, sy, sz in corner_signs
    ]
    return np.stack(columns, axis=1)
30
+
31
+
32
def interpolate_field_at_element_points(space: FESpace, u: np.ndarray, xi_eta_zeta: np.ndarray) -> np.ndarray:
    """
    Interpolate a 3-component nodal field at local points of every element.

    Parameters
    ----------
    space : FESpace
        Space whose basis must be a ``HexTriLinearBasis`` (Hex8).
    u : np.ndarray
        Global dof vector of length ``3 * n_nodes`` holding the three
        components of each node consecutively (``u.reshape(n_nodes, 3)``).
    xi_eta_zeta : np.ndarray
        Local coordinates, shape (m, 3) (or (3,) for a single point), in [-1, 1]^3.

    Returns
    -------
    np.ndarray
        Interpolated values, shape (n_elem, m, 3).

    Raises
    ------
    NotImplementedError
        If the space's basis is not Hex8.
    ValueError
        If ``u`` does not hold 3 dofs per node, or ``space.elem_dofs`` does not
        list the 3 consecutive dofs (3*a, 3*a+1, 3*a+2) of each of the 8 element nodes.
    """
    if not isinstance(space.basis, HexTriLinearBasis):
        raise NotImplementedError("interpolate_field_at_element_points currently supports Hex8 (trilinear) only.")
    N = eval_shape_functions_hex8(xi_eta_zeta)  # (m, 8)
    n_elem_nodes = N.shape[1]

    u_arr = np.asarray(u)
    n_nodes = space.mesh.coords.shape[0]
    if u_arr.shape[0] != 3 * n_nodes:
        raise ValueError(f"Expected 3 dof/node; got {u_arr.shape[0]} for {n_nodes} nodes")
    u_nodes = u_arr.reshape(n_nodes, 3)

    elem_dofs = np.asarray(space.elem_dofs)
    if elem_dofs.ndim != 2 or elem_dofs.shape[1] != 3 * n_elem_nodes:
        raise ValueError(
            f"Expected {3 * n_elem_nodes} dofs per element (3 dof/node x {n_elem_nodes} nodes); "
            f"got elem_dofs with shape {elem_dofs.shape}"
        )
    # Every third column holds a node's component-0 dof (3*a), so dividing it
    # by 3 gives the node index. Dividing the whole table would repeat each node
    # three times and produce (n_elem, 24, 3), which cannot contract with N.
    conn = elem_dofs[:, ::3] // 3  # (n_elem, 8) node indices
    expected = 3 * conn[:, :, None] + np.arange(3)
    if not np.array_equal(elem_dofs.reshape(-1, n_elem_nodes, 3), expected):
        raise ValueError("elem_dofs must list the 3 dofs of each node consecutively (3*a, 3*a+1, 3*a+2)")

    elem_u = u_nodes[conn]  # (n_elem, 8, 3)
    return np.einsum("pq,eqr->epr", N, elem_u)  # (n_elem, m, 3)
50
+
51
+
52
# Public names exported by this module.
__all__ = ["eval_shape_functions_hex8", "interpolate_field_at_element_points"]
fluxfem/core/solver.py ADDED
@@ -0,0 +1,113 @@
1
+ """
2
+ Helper to bridge JAX-assembled matrices back to NumPy/SciPy and solve.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ import numpy as np
7
+ import jax.numpy as jnp
8
+ from typing import Any
9
+
10
+ try:
11
+ import scipy.sparse as sp
12
+ from scipy.sparse.linalg import spsolve
13
+ except Exception as exc: # pragma: no cover
14
+ raise ImportError("scipy is required for spsolve utilities") from exc
15
+
16
+
17
def coo_to_csr(rows: Any, cols: Any, data: Any, n_dofs: int):
    """
    Build a square (n_dofs x n_dofs) SciPy CSR matrix from COO triplets.

    Indices are cast to int64; repeated (row, col) pairs are summed by SciPy
    during the COO -> CSR conversion.
    """
    row_idx = np.asarray(rows, dtype=np.int64)
    col_idx = np.asarray(cols, dtype=np.int64)
    values = np.asarray(data)
    square_shape = (n_dofs, n_dofs)
    return sp.csr_matrix((values, (row_idx, col_idx)), shape=square_shape)
25
+
26
+
27
def _jax_sparse_to_scipy_csr(K: Any):
    """Return a SciPy CSR copy of a JAX ``BCOO``/``BCSR`` matrix, or ``None`` if ``K`` is neither."""
    try:
        from jax.experimental import sparse as jsparse
    except ImportError:  # pragma: no cover - jax.experimental.sparse unavailable
        return None
    if isinstance(K, jsparse.BCSR):
        K = K.to_bcoo()
    if not isinstance(K, jsparse.BCOO):
        return None
    if K.n_batch or K.n_dense:
        # Batched / dense-block layouts do not map onto 2-D COO indices: densify.
        return sp.csr_matrix(np.asarray(K.todense()))
    data = np.asarray(K.data)
    idx = np.asarray(K.indices)  # (nse, 2) for an unbatched 2-D BCOO
    n_rows, n_cols = K.shape
    # JAX sparse may pad with out-of-range indices; SciPy would reject them.
    keep = (idx[:, 0] < n_rows) & (idx[:, 1] < n_cols)
    return sp.csr_matrix((data[keep], (idx[keep, 0], idx[keep, 1])), shape=(n_rows, n_cols))


def spdirect_solve_cpu(K: Any, F: jnp.ndarray, *, use_jax: bool = False) -> np.ndarray:
    """
    Convert K and F to NumPy/SciPy and solve K u = F with SciPy's sparse direct solver.

    If use_jax=True, first try spdirect_solve_jax; if that raises, a
    RuntimeWarning is emitted and the SciPy path is used instead.

    Parameters
    ----------
    K : Any
        Global matrix (n_dofs, n_dofs). Accepted forms: an object with
        ``to_csr()`` (e.g. FluxSparseMatrix), a COO tuple
        ``(rows, cols, data, n_dofs)``, any SciPy sparse matrix, a JAX
        ``BCOO``/``BCSR`` matrix, or anything ``np.asarray`` turns into a dense matrix.
    F : jnp.ndarray
        Load vector (n_dofs,) or multiple RHS (n_dofs, n_rhs).

    Returns
    -------
    np.ndarray
        Solution vector u (n_dofs,) or (n_dofs, n_rhs).
    """
    if use_jax:
        try:
            return spdirect_solve_jax(K, F)
        except Exception as exc:  # deliberate best-effort: fall back to SciPy below
            import warnings

            warnings.warn(
                f"JAX sparse solve failed ({exc!r}); falling back to SciPy spsolve.",
                RuntimeWarning,
                stacklevel=2,
            )

    if hasattr(K, "to_csr"):
        K_csr = K.to_csr()
    elif isinstance(K, tuple) and len(K) == 4:
        K_csr = coo_to_csr(*K)
    elif sp.issparse(K):
        K_csr = K.tocsr()
    else:
        # np.asarray cannot convert JAX sparse matrices, so handle them explicitly.
        K_csr = _jax_sparse_to_scipy_csr(K)
        if K_csr is None:
            K_csr = sp.csr_matrix(np.asarray(K))

    F_np = np.asarray(F)
    return np.asarray(spsolve(K_csr, F_np))
63
+
64
+
65
def spdirect_solve_jax(K: Any, F: jnp.ndarray) -> np.ndarray:
    """
    Direct sparse solve in JAX via jax.experimental.sparse.linalg.spsolve.

    On the CPU backend the solve is delegated to spdirect_solve_cpu (SciPy),
    and any error from that solve propagates to the caller.

    Accepts any SciPy sparse matrix (converted to CSR), a COO tuple
    ``(rows, cols, data, n_dofs)``, a jax.experimental.sparse.BCOO, or an
    object with ``to_bcoo()`` (e.g. FluxSparseMatrix). COO/BCOO inputs have
    duplicate entries summed and indices sorted before conversion to CSR.

    F is (n_dofs,) or (n_dofs, n_rhs); multiple right-hand sides are solved
    column by column. Returns a NumPy array of the same shape as F.

    Raises
    ------
    ImportError
        If jax.experimental.sparse is unavailable.
    TypeError
        If K is of an unsupported type.
    """
    import jax

    try:
        backend = jax.default_backend()
    except Exception:  # backend query failed: try the JAX path below anyway
        backend = None
    if backend == "cpu":
        # JAX spsolve falls back to SciPy on CPU and can hit read-only buffers.
        # Kept outside any try block so genuine solve errors are not masked.
        return spdirect_solve_cpu(K, F, use_jax=False)

    try:
        from jax.experimental.sparse.linalg import spsolve as jspsolve
        from jax.experimental import sparse as jsparse
    except Exception as exc:  # pragma: no cover
        raise ImportError("jax.experimental.sparse is required for spdirect_solve_jax") from exc

    if sp.issparse(K):
        # CSC indices/indptr are column-based and COO has none; normalise to a
        # private CSR copy with sorted, duplicate-free indices.
        K_csr = K.tocsr(copy=True)
        K_csr.sum_duplicates()
        data = jnp.asarray(K_csr.data)
        indices = jnp.asarray(K_csr.indices)
        indptr = jnp.asarray(K_csr.indptr)
    else:
        if isinstance(K, tuple) and len(K) == 4:
            rows, cols, coo_data, n_dofs = K
            idx = jnp.stack([jnp.asarray(rows), jnp.asarray(cols)], axis=-1)
            bcoo = jsparse.BCOO((jnp.asarray(coo_data), idx), shape=(int(n_dofs), int(n_dofs)))
        elif isinstance(K, jsparse.BCOO):
            bcoo = K
        elif hasattr(K, "to_bcoo"):
            bcoo = K.to_bcoo()
        else:
            raise TypeError("spdirect_solve_jax expects FluxSparseMatrix, BCOO, CSR, or COO tuple")
        # Assembly triplets are unsorted and repeat (row, col) pairs; sum and
        # sort them so the CSR structure built by BCSR.from_bcoo is valid.
        bcsr = jsparse.BCSR.from_bcoo(bcoo.sum_duplicates())
        data, indices, indptr = bcsr.data, bcsr.indices, bcsr.indptr

    F_arr = jnp.asarray(F)
    if F_arr.ndim == 1:
        return np.asarray(jspsolve(data, indices, indptr, F_arr))
    columns = [jspsolve(data, indices, indptr, F_arr[:, j]) for j in range(F_arr.shape[1])]
    return np.asarray(jnp.stack(columns, axis=1))
108
+
109
def spdirect_solve_gpu(K: Any, F: jnp.ndarray) -> np.ndarray:
    """
    Sparse direct solve intended for GPU backends.

    Thin alias for spdirect_solve_jax: the solve goes through JAX's
    experimental sparse solver, except that spdirect_solve_jax hands off to the
    SciPy CPU solver when JAX's default backend is CPU.
    """
    solution = spdirect_solve_jax(K, F)
    return solution