fluxfem 0.1.3a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fluxfem might be problematic. Click here for more details.
- fluxfem/__init__.py +343 -0
- fluxfem/core/__init__.py +318 -0
- fluxfem/core/assembly.py +788 -0
- fluxfem/core/basis.py +996 -0
- fluxfem/core/data.py +64 -0
- fluxfem/core/dtypes.py +4 -0
- fluxfem/core/forms.py +234 -0
- fluxfem/core/interp.py +55 -0
- fluxfem/core/solver.py +113 -0
- fluxfem/core/space.py +419 -0
- fluxfem/core/weakform.py +828 -0
- fluxfem/helpers_ts.py +11 -0
- fluxfem/helpers_wf.py +44 -0
- fluxfem/mesh/__init__.py +29 -0
- fluxfem/mesh/base.py +244 -0
- fluxfem/mesh/hex.py +327 -0
- fluxfem/mesh/io.py +87 -0
- fluxfem/mesh/predicate.py +45 -0
- fluxfem/mesh/surface.py +257 -0
- fluxfem/mesh/tet.py +246 -0
- fluxfem/physics/__init__.py +53 -0
- fluxfem/physics/diffusion.py +18 -0
- fluxfem/physics/elasticity/__init__.py +39 -0
- fluxfem/physics/elasticity/hyperelastic.py +99 -0
- fluxfem/physics/elasticity/linear.py +58 -0
- fluxfem/physics/elasticity/materials.py +32 -0
- fluxfem/physics/elasticity/stress.py +46 -0
- fluxfem/physics/operators.py +109 -0
- fluxfem/physics/postprocess.py +113 -0
- fluxfem/solver/__init__.py +47 -0
- fluxfem/solver/bc.py +439 -0
- fluxfem/solver/cg.py +326 -0
- fluxfem/solver/dirichlet.py +126 -0
- fluxfem/solver/history.py +31 -0
- fluxfem/solver/newton.py +400 -0
- fluxfem/solver/result.py +62 -0
- fluxfem/solver/solve_runner.py +534 -0
- fluxfem/solver/solver.py +148 -0
- fluxfem/solver/sparse.py +188 -0
- fluxfem/tools/__init__.py +7 -0
- fluxfem/tools/jit.py +51 -0
- fluxfem/tools/timer.py +659 -0
- fluxfem/tools/visualizer.py +101 -0
- fluxfem-0.1.3a0.dist-info/LICENSE +201 -0
- fluxfem-0.1.3a0.dist-info/METADATA +125 -0
- fluxfem-0.1.3a0.dist-info/RECORD +47 -0
- fluxfem-0.1.3a0.dist-info/WHEEL +4 -0
fluxfem/core/data.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass(eq=False)
|
|
10
|
+
class MeshData:
|
|
11
|
+
"""Lightweight mesh data container for JAX-friendly serialization."""
|
|
12
|
+
coords: jnp.ndarray
|
|
13
|
+
conn: jnp.ndarray
|
|
14
|
+
cell_tags: jnp.ndarray | None = None
|
|
15
|
+
node_tags: jnp.ndarray | None = None
|
|
16
|
+
|
|
17
|
+
@classmethod
|
|
18
|
+
def from_mesh(cls, mesh: Any) -> "MeshData":
|
|
19
|
+
return cls(
|
|
20
|
+
coords=jnp.asarray(mesh.coords),
|
|
21
|
+
conn=jnp.asarray(mesh.conn),
|
|
22
|
+
cell_tags=None if mesh.cell_tags is None else jnp.asarray(mesh.cell_tags),
|
|
23
|
+
node_tags=None if mesh.node_tags is None else jnp.asarray(mesh.node_tags),
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass(eq=False)
class BasisData:
    """Quadrature and basis metadata captured for reproducible assembly."""

    quad_points: jnp.ndarray
    quad_weights: jnp.ndarray
    dofs_per_node: int
    kind: str  # class name of the basis object this snapshot was taken from

    @classmethod
    def from_basis(cls, basis: Any) -> "BasisData":
        """Snapshot the quadrature rule, dofs per node and class name of ``basis``."""
        points = jnp.asarray(basis.quad_points)
        weights = jnp.asarray(basis.quad_weights)
        n_dofs_node = int(basis.dofs_per_node)
        basis_kind = type(basis).__name__
        return cls(points, weights, n_dofs_node, basis_kind)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass(eq=False)
class SpaceData:
    """Snapshot of the space-related data (mesh, basis, dof map, sizes) used in assembly."""

    mesh: MeshData
    basis: BasisData
    elem_dofs: jnp.ndarray
    value_dim: int
    n_dofs: int
    n_ldofs: int

    @classmethod
    def from_space(cls, space: Any) -> "SpaceData":
        """Build a SpaceData by snapshotting ``space``'s mesh, basis, dof map and sizes."""
        mesh_data = MeshData.from_mesh(space.mesh)
        basis_data = BasisData.from_basis(space.basis)
        dof_map = jnp.asarray(space.elem_dofs)
        sizes = {
            "value_dim": int(space.value_dim),
            "n_dofs": int(space.n_dofs),
            "n_ldofs": int(space.n_ldofs),
        }
        return cls(mesh=mesh_data, basis=basis_data, elem_dofs=dof_map, **sizes)
|
fluxfem/core/dtypes.py
ADDED
fluxfem/core/forms.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
from .basis import Basis3D
|
|
8
|
+
|
|
9
|
+
# FormContext, ScalarFormField and VectorFormField were originally dataclasses with the default __eq__. JAX ended up
# calling that __eq__ during the vmap over residuals; it compares array fields element-wise and then coerces the
# result to a bool, which raises "The truth value of an array ... is ambiguous." Setting eq=False on these classes
# (and on ElementVector, for consistency) removes the autogenerated __eq__, so vmap no longer evaluates array
# equality and residual/Jacobian assembly succeeds.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(eq=False)
class ElementVector:
    """
    Simple vector-valued element wrapper (scikit-fem ElementVector style).

    dim: dofs per node (e.g., 3 for displacement)
    """

    dim: int

    def dof_map(self, conn: jnp.ndarray) -> jnp.ndarray:
        """
        Expand scalar connectivity (n_elems, n_nodes_per_elem) to vector dofs.

        Node ``a`` owns the contiguous dofs ``a*dim .. a*dim + dim - 1``, so each
        element row lists all components of its first node, then of its second
        node, and so on.

        Parameters
        ----------
        conn : array-like of int
            Node connectivity; converted with ``jnp.asarray``.

        Returns
        -------
        jnp.ndarray
            Shape (n_elems, n_nodes_per_elem * dim). Works for n_elems == 0.
        """
        conn = jnp.asarray(conn)
        base = conn[..., None] * self.dim  # (n_elems, n_nodes, 1)
        offsets = jnp.arange(self.dim, dtype=conn.dtype)  # (dim,)
        dofs = base + offsets  # (n_elems, n_nodes, dim)
        # Compute the row width explicitly: reshape(n, -1) is ambiguous (and
        # rejected) when n_elems == 0, i.e. when the array has size zero.
        width = self.dim
        for extent in conn.shape[1:]:
            width *= extent
        return dofs.reshape(conn.shape[0], width)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@jax.tree_util.register_pytree_node_class
@dataclass(eq=False)
class ScalarFormField:
    """Scalar FE field evaluated on one element.

    ``gradN`` and ``detJ`` are computed lazily via
    ``basis.spatial_grads_and_detJ(elem_coords)`` on first access and cached in
    ``_gradN`` / ``_detJ``.
    """

    N: jnp.ndarray  # (n_q, n_nodes)
    elem_coords: jnp.ndarray  # (n_nodes, 3)
    basis: Basis3D
    _gradN: jnp.ndarray | None = None
    _detJ: jnp.ndarray | None = None

    @property
    def gradN(self):
        if self._gradN is None:
            self._gradN, self._detJ = self.basis.spatial_grads_and_detJ(self.elem_coords)
        return self._gradN

    @property
    def detJ(self):
        if self._detJ is None:
            self._gradN, self._detJ = self.basis.spatial_grads_and_detJ(self.elem_coords)
        return self._detJ

    def eval(self, u_elem: jnp.ndarray) -> jnp.ndarray:
        """Interpolate nodal values u_elem (n_nodes,) to quadrature points -> (n_q,)."""
        return jnp.einsum("qa,a->q", self.N, u_elem)

    def grad(self, u_elem: jnp.ndarray) -> jnp.ndarray:
        """Spatial gradient of the field at quadrature points; returns (n_q, 3)."""
        return jnp.einsum("qaj,a->qj", self.gradN, u_elem)

    def tree_flatten(self):
        children = (self.N, self.elem_coords, self._gradN, self._detJ)
        # JAX requires pytree aux_data to be hashable (it is stored in the
        # treedef), so use a tuple rather than a dict. The basis object itself
        # must therefore be hashable too.
        aux = (self.basis,)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        N, elem_coords, gradN, detJ = children
        (basis,) = aux
        return cls(N, elem_coords, basis, gradN, detJ)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@jax.tree_util.register_pytree_node_class
@dataclass(eq=False)
class VectorFormField:
    """Vector-valued FE field evaluated on one element.

    ``value_dim`` is a static Python int (kept in the pytree aux data).
    ``gradN`` and ``detJ`` are computed lazily via
    ``basis.spatial_grads_and_detJ(elem_coords)`` on first access and cached.
    """

    N: jnp.ndarray
    elem_coords: jnp.ndarray
    basis: Basis3D
    value_dim: int  # Python int (static)
    _gradN: jnp.ndarray | None = None
    _detJ: jnp.ndarray | None = None

    @property
    def gradN(self):
        if self._gradN is None:
            self._gradN, self._detJ = self.basis.spatial_grads_and_detJ(self.elem_coords)
        return self._gradN

    @property
    def detJ(self):
        if self._detJ is None:
            self._gradN, self._detJ = self.basis.spatial_grads_and_detJ(self.elem_coords)
        return self._detJ

    def eval(self, u_elem: jnp.ndarray) -> jnp.ndarray:
        """Interpolate u_elem (value_dim*n_nodes,) to quadrature points -> (n_q, value_dim)."""
        u_nodes = u_elem.reshape((-1, self.value_dim))  # (n_nodes, vd); vd is a Python int
        return jnp.einsum("qa,ai->qi", self.N, u_nodes)  # (n_q, vd)

    def grad(self, u_elem: jnp.ndarray) -> jnp.ndarray:
        """Spatial gradient at quadrature points; returns (n_q, value_dim, 3)."""
        u_nodes = u_elem.reshape((-1, self.value_dim))
        return jnp.einsum("qaj,ai->qij", self.gradN, u_nodes)  # (n_q, vd, 3)

    def tree_flatten(self):
        children = (self.N, self.elem_coords, self._gradN, self._detJ)
        # JAX requires pytree aux_data to be hashable (it is stored in the
        # treedef), so use a tuple rather than a dict. The basis object itself
        # must therefore be hashable too.
        aux = (self.basis, int(self.value_dim))
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        N, elem_coords, gradN, detJ = children
        basis, value_dim = aux
        return cls(N, elem_coords, basis, value_dim, gradN, detJ)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# Either per-element form field type (scalar or vector valued); used as an
# annotation for the helpers and context classes in this module.
FormFieldLike = (
    ScalarFormField
    | VectorFormField
)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def vector_load_form(field: FormFieldLike, load_vec: jnp.ndarray) -> jnp.ndarray:
    """
    Expand a load vector into per-quadrature vector linear-form values.

    Parameters
    ----------
    field : FormFieldLike
        Only ``field.N`` (shape (n_q, n_nodes)) is used.
    load_vec : jnp.ndarray
        Constant load of shape (dim,) or per-quadrature load of shape (n_q, dim).

    Returns
    -------
    jnp.ndarray
        Shape (n_q, n_nodes * dim); column ``a * dim + i`` holds N[q, a] * load[q, i].

    Raises
    ------
    ValueError
        If load_vec has any other shape.
    """
    shape_error = "load_vec must be shape (dim,) or (n_q, dim)"
    loads = jnp.asarray(load_vec)
    if loads.ndim == 1:
        loads = jnp.expand_dims(loads, 0)
    elif loads.ndim != 2:
        raise ValueError(shape_error)

    n_q = field.N.shape[0]
    if loads.shape[0] == 1:
        # A single row is shared by every quadrature point.
        loads = jnp.broadcast_to(loads, (n_q, loads.shape[1]))
    elif loads.shape[0] != n_q:
        raise ValueError(shape_error)

    products = jnp.expand_dims(field.N, -1) * jnp.expand_dims(loads, 1)  # (n_q, n_nodes, dim)
    return products.reshape(products.shape[0], -1)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@jax.tree_util.register_pytree_node_class
@dataclass(eq=False)
class FormContext:
    """Bundle test/trial fields and quadrature data for element assembly.

    ``u`` and ``v`` are read-only aliases for ``trial`` and ``test``. Every
    attribute is a pytree child; there is no static (aux) data.
    """

    test: FormFieldLike
    trial: FormFieldLike
    x_q: jnp.ndarray  # (n_q, 3)
    w: jnp.ndarray  # (n_q,)
    elem_id: jnp.ndarray | int = 0

    @property
    def u(self) -> FormFieldLike:
        return self.trial

    @property
    def v(self) -> FormFieldLike:
        return self.test

    def tree_flatten(self):
        children = (
            self.test,
            self.trial,
            self.x_q,
            self.w,
            self.elem_id,
        )
        # JAX requires aux_data to be hashable; None (not an empty dict)
        # signals "no static data".
        return children, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        test, trial, x_q, w, elem_id = children
        return cls(test, trial, x_q, w, elem_id)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
@jax.tree_util.register_pytree_node_class
@dataclass(eq=False)
class FieldPair:
    """Named test/trial/unknown grouping for mixed formulations.

    Registered as a pytree so that FieldPairs stored in the (pytree) values of
    ``MixedFormContext.fields`` are traversed by JAX transformations instead of
    being treated as opaque, non-array leaves. ``unknown`` may be None.
    """

    test: FormFieldLike
    trial: FormFieldLike
    unknown: FormFieldLike | None = None

    def tree_flatten(self):
        # No static data; None is hashable as JAX requires for aux_data.
        return (self.test, self.trial, self.unknown), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        test, trial, unknown = children
        return cls(test, trial, unknown)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
@jax.tree_util.register_pytree_node_class
@dataclass(eq=False)
class MixedFormContext:
    """FormContext for mixed formulations keyed by field name.

    ``fields`` maps a field name to its FieldPair. The optional
    ``trial_fields`` / ``test_fields`` / ``unknown_fields`` dicts map names to
    individual form fields. Every attribute is a pytree child; there is no
    static (aux) data.
    """

    fields: dict[str, FieldPair]
    x_q: jnp.ndarray  # (n_q, 3)
    w: jnp.ndarray  # (n_q,)
    elem_id: jnp.ndarray | int = 0
    unknown: FormFieldLike | None = None
    trial_fields: dict[str, FormFieldLike] | None = None
    test_fields: dict[str, FormFieldLike] | None = None
    unknown_fields: dict[str, FormFieldLike] | None = None

    def tree_flatten(self):
        children = (
            self.fields,
            self.x_q,
            self.w,
            self.elem_id,
            self.unknown,
            self.trial_fields,
            self.test_fields,
            self.unknown_fields,
        )
        # JAX requires aux_data to be hashable; None (not an empty dict)
        # signals "no static data".
        return children, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (
            fields,
            x_q,
            w,
            elem_id,
            unknown,
            trial_fields,
            test_fields,
            unknown_fields,
        ) = children
        return cls(fields, x_q, w, elem_id, unknown, trial_fields, test_fields, unknown_fields)
|
fluxfem/core/interp.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from .basis import HexTriLinearBasis
|
|
6
|
+
from .space import FESpace
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def eval_shape_functions_hex8(xi_eta_zeta: np.ndarray) -> np.ndarray:
    """
    Evaluate trilinear Hex8 shape functions at reference points in [-1, 1]^3.

    Accepts one point of shape (3,) or a batch of shape (n_q, 3) and returns
    N with shape (n_q, 8). Column k is
    0.125 * (1 + sx*xi) * (1 + sy*eta) * (1 + sz*zeta), where (sx, sy, sz)
    is the reference corner of node k listed in ``corners`` below.
    """
    points = np.atleast_2d(np.asarray(xi_eta_zeta, dtype=float))
    xi = points[:, 0]
    eta = points[:, 1]
    zeta = points[:, 2]
    # Reference-corner signs (xi, eta, zeta) of nodes 0..7.
    corners = (
        (-1, -1, -1),
        (1, -1, -1),
        (1, 1, -1),
        (-1, 1, -1),
        (-1, -1, 1),
        (1, -1, 1),
        (1, 1, 1),
        (-1, 1, 1),
    )
    columns = [
        0.125 * (1 + sx * xi) * (1 + sy * eta) * (1 + sz * zeta)
        for sx, sy, sz in corners
    ]
    return np.stack(columns, axis=1)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def interpolate_field_at_element_points(space: FESpace, u: np.ndarray, xi_eta_zeta: np.ndarray) -> np.ndarray:
    """
    Interpolate vector field u (3 dof/node ordering) at given local points for all elements.

    - u: global dof vector of length 3 * n_nodes; it is reshaped to
      (n_nodes, 3), so dof 3*a + i is component i of node a.
    - xi_eta_zeta: (m,3) local coords in [-1,1]^3 (a single (3,) point is also accepted)
    Returns: (n_elem, m, 3)

    Raises
    ------
    NotImplementedError
        If ``space.basis`` is not a HexTriLinearBasis.
    ValueError
        If u does not hold 3 dofs per node, or ``space.elem_dofs`` has neither
        8 nor 24 columns per element.
    """
    if not isinstance(space.basis, HexTriLinearBasis):
        raise NotImplementedError("interpolate_field_at_element_points currently supports Hex8 (trilinear) only.")
    N = eval_shape_functions_hex8(xi_eta_zeta)  # (m,8)
    n_shape = N.shape[1]
    u_arr = np.asarray(u)
    n_nodes = space.mesh.coords.shape[0]
    if u_arr.shape[0] != 3 * n_nodes:
        raise ValueError(f"Expected 3 dof/node; got {u_arr.shape[0]} for {n_nodes} nodes")
    u_nodes = u_arr.reshape(n_nodes, 3)

    elem_dofs = np.asarray(space.elem_dofs)
    if elem_dofs.ndim == 2 and elem_dofs.shape[1] == 3 * n_shape:
        # Assumes the node-major layout [3a, 3a+1, 3a+2, 3b, ...] (the same layout
        # the u reshape above relies on): keep one dof per node. Dividing all 24
        # dofs by 3 would list every node three times, and the contraction with
        # the 8 shape functions below would fail on the mismatched axis.
        elem_dofs = elem_dofs[:, ::3]
    elif elem_dofs.ndim != 2 or elem_dofs.shape[1] != n_shape:
        raise ValueError(
            f"Expected elem_dofs of shape (n_elem, {n_shape}) or (n_elem, {3 * n_shape}); "
            f"got {elem_dofs.shape}"
        )
    conn = elem_dofs // 3  # node indices, (n_elem, 8)
    elem_u = u_nodes[conn]  # (n_elem,8,3)
    vals = np.einsum("pq,eqr->epr", N, elem_u)  # (n_elem, m, 3)
    return vals
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# Names exported by ``from fluxfem.core.interp import *``.
__all__ = ["eval_shape_functions_hex8", "interpolate_field_at_element_points"]
|
fluxfem/core/solver.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Helper to bridge JAX-assembled matrices back to NumPy/SciPy and solve.
|
|
3
|
+
"""
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
import scipy.sparse as sp
|
|
12
|
+
from scipy.sparse.linalg import spsolve
|
|
13
|
+
except Exception as exc: # pragma: no cover
|
|
14
|
+
raise ImportError("scipy is required for spsolve utilities") from exc
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def coo_to_csr(rows: Any, cols: Any, data: Any, n_dofs: int):
    """
    Build a square (n_dofs x n_dofs) SciPy CSR matrix from COO triplets.

    rows/cols/data may be any array-likes (e.g. NumPy or JAX arrays); the
    indices are cast to int64. When SciPy builds CSR from (data, (row, col))
    input, repeated (row, col) pairs are summed.
    """
    row_idx, col_idx = (np.asarray(idx, dtype=np.int64) for idx in (rows, cols))
    values = np.asarray(data)
    return sp.csr_matrix((values, (row_idx, col_idx)), shape=(n_dofs, n_dofs))
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def spdirect_solve_cpu(K: Any, F: jnp.ndarray, *, use_jax: bool = False) -> np.ndarray:
    """
    Convert JAX arrays to NumPy/SciPy and solve K u = F with SciPy's sparse solver.

    If use_jax=True, spdirect_solve_jax is tried first. If it raises, a
    RuntimeWarning describing the failure is emitted and the SciPy path is
    used instead.

    Parameters
    ----------
    K : Any
        System matrix (n_dofs, n_dofs). Checked in this order: an object with a
        ``to_csr()`` method, a COO tuple ``(rows, cols, data, n_dofs)``, a SciPy
        sparse matrix, or any dense array-like (NumPy/JAX).
    F : jnp.ndarray
        Load vector (n_dofs,) or multiple RHS (n_dofs, n_rhs)
    use_jax : bool, optional
        Try the JAX experimental sparse solver first (default False).

    Returns
    -------
    np.ndarray
        Solution vector u (n_dofs,) or (n_dofs, n_rhs)
    """
    if use_jax:
        try:
            return spdirect_solve_jax(K, F)
        except Exception as exc:
            # Deliberate best-effort fallback to SciPy, but surface why the JAX
            # path failed instead of discarding the exception silently.
            import warnings

            warnings.warn(
                f"spdirect_solve_jax failed ({exc!r}); falling back to SciPy spsolve.",
                RuntimeWarning,
                stacklevel=2,
            )

    if hasattr(K, "to_csr"):
        K_csr = K.to_csr()
    elif isinstance(K, tuple) and len(K) == 4:
        K_csr = coo_to_csr(*K)
    elif sp.issparse(K):
        K_csr = K.tocsr()
    else:
        K_csr = sp.csr_matrix(np.asarray(K))

    F_np = np.asarray(F)
    u = spsolve(K_csr, F_np)
    return np.asarray(u)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def spdirect_solve_jax(K: Any, F: jnp.ndarray) -> np.ndarray:
    """
    Direct sparse solve in JAX via jax.experimental.sparse.linalg.spsolve.

    Accepts FluxSparseMatrix (anything with ``to_bcoo()``),
    jax.experimental.sparse.BCOO, a SciPy sparse matrix (any format), or a COO
    tuple ``(rows, cols, data, n_dofs)``. F may be (n_dofs,) or
    (n_dofs, n_rhs); multiple right-hand sides are solved column by column.

    When JAX's default backend is CPU, the solve is delegated to
    spdirect_solve_cpu (SciPy) instead.

    Raises
    ------
    ImportError
        If jax.experimental.sparse is unavailable.
    TypeError
        If K is of an unsupported type.
    """
    import jax

    # Only the backend query is best-effort; errors from the actual solve must
    # propagate instead of silently re-running the solve through another path.
    try:
        on_cpu = jax.default_backend() == "cpu"
    except Exception:
        on_cpu = False
    if on_cpu:
        # JAX spsolve falls back to SciPy on CPU and can hit read-only buffers.
        return spdirect_solve_cpu(K, F, use_jax=False)

    try:
        from jax.experimental.sparse.linalg import spsolve as jspsolve
        from jax.experimental import sparse as jsparse
    except Exception as exc:  # pragma: no cover
        raise ImportError("jax.experimental.sparse is required for spdirect_solve_jax") from exc

    def _solve_csr(data, indices, indptr) -> np.ndarray:
        # Solve for one RHS, or stack the per-column solutions for several.
        F_arr = jnp.asarray(F)
        if F_arr.ndim == 1:
            return np.asarray(jspsolve(data, indices, indptr, F_arr))
        columns = [jspsolve(data, indices, indptr, F_arr[:, i]) for i in range(F_arr.shape[1])]
        return np.asarray(jnp.stack(columns, axis=1))

    if sp.issparse(K):
        # jspsolve expects CSR arrays: a CSC matrix's indices/indptr would describe
        # K^T and COO has no indptr at all, so convert first. copy=True keeps the
        # in-place sum_duplicates from touching the caller's matrix.
        K_csr = K.tocsr(copy=True)
        K_csr.sum_duplicates()
        return _solve_csr(jnp.asarray(K_csr.data), jnp.asarray(K_csr.indices), jnp.asarray(K_csr.indptr))

    if isinstance(K, tuple) and len(K) == 4:
        rows, cols, data, n_dofs = K
        idx = jnp.stack([jnp.asarray(rows), jnp.asarray(cols)], axis=-1)
        bcoo = jsparse.BCOO((jnp.asarray(data), idx), shape=(int(n_dofs), int(n_dofs)))
    elif isinstance(K, jsparse.BCOO):
        bcoo = K
    elif hasattr(K, "to_bcoo"):
        bcoo = K.to_bcoo()
    else:
        raise TypeError("spdirect_solve_jax expects FluxSparseMatrix, BCOO, CSR, or COO tuple")

    # COO triplets (e.g. from element assembly) may repeat (row, col) pairs;
    # sum them (this also sorts the indices) before building the CSR arrays.
    bcoo = bcoo.sum_duplicates()
    bcsr = jsparse.BCSR.from_bcoo(bcoo)
    return _solve_csr(bcsr.data, bcsr.indices, bcsr.indptr)
|
|
108
|
+
|
|
109
|
+
def spdirect_solve_gpu(K: Any, F: jnp.ndarray) -> np.ndarray:
    """
    Sparse direct solve intended for GPU backends.

    Thin alias of spdirect_solve_jax: it uses the JAX experimental sparse
    solver on accelerators, and when JAX's default backend is CPU that
    function routes the solve to SciPy instead.
    """
    solution = spdirect_solve_jax(K, F)
    return solution
|