fluxfem 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +68 -0
- fluxfem/core/__init__.py +115 -10
- fluxfem/core/assembly.py +676 -91
- fluxfem/core/basis.py +73 -52
- fluxfem/core/dtypes.py +9 -1
- fluxfem/core/forms.py +10 -0
- fluxfem/core/mixed_assembly.py +263 -0
- fluxfem/core/mixed_space.py +348 -0
- fluxfem/core/mixed_weakform.py +97 -0
- fluxfem/core/solver.py +2 -0
- fluxfem/core/space.py +262 -17
- fluxfem/core/weakform.py +768 -7
- fluxfem/helpers_wf.py +49 -0
- fluxfem/mesh/__init__.py +54 -2
- fluxfem/mesh/base.py +316 -7
- fluxfem/mesh/contact.py +825 -0
- fluxfem/mesh/dtypes.py +12 -0
- fluxfem/mesh/hex.py +17 -16
- fluxfem/mesh/io.py +6 -4
- fluxfem/mesh/mortar.py +3907 -0
- fluxfem/mesh/supermesh.py +316 -0
- fluxfem/mesh/surface.py +22 -4
- fluxfem/mesh/tet.py +10 -4
- fluxfem/physics/diffusion.py +3 -0
- fluxfem/physics/elasticity/hyperelastic.py +3 -0
- fluxfem/physics/elasticity/linear.py +9 -2
- fluxfem/solver/__init__.py +42 -2
- fluxfem/solver/bc.py +38 -2
- fluxfem/solver/block_matrix.py +132 -0
- fluxfem/solver/block_system.py +454 -0
- fluxfem/solver/cg.py +115 -33
- fluxfem/solver/dirichlet.py +334 -4
- fluxfem/solver/newton.py +237 -60
- fluxfem/solver/petsc.py +439 -0
- fluxfem/solver/preconditioner.py +106 -0
- fluxfem/solver/result.py +18 -0
- fluxfem/solver/solve_runner.py +168 -1
- fluxfem/solver/solver.py +12 -1
- fluxfem/solver/sparse.py +124 -9
- fluxfem-0.2.0.dist-info/METADATA +303 -0
- fluxfem-0.2.0.dist-info/RECORD +59 -0
- fluxfem-0.1.4.dist-info/METADATA +0 -127
- fluxfem-0.1.4.dist-info/RECORD +0 -48
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/LICENSE +0 -0
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,348 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Mapping, Sequence, Callable
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
|
|
9
|
+
from .dtypes import INDEX_DTYPE
|
|
10
|
+
from .forms import MixedFormContext, FieldPair
|
|
11
|
+
from .weakform import MixedWeakForm, compile_mixed_residual, make_mixed_residuals
|
|
12
|
+
from ..solver.dirichlet import DirichletBC, free_dofs
|
|
13
|
+
from ..solver.sparse import FluxSparseMatrix
|
|
14
|
+
from .space import FESpaceClosure
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass(eq=False)
class MixedFESpace:
    """
    Mixed FE space composed of multiple scalar/vector spaces.

    Field DOFs are concatenated in field order:
      [field0 dofs | field1 dofs | ...]

    Constructor arguments:
      fields      -- mapping from field name to its single-field space.
      field_order -- optional explicit ordering of the field names; when
                     omitted the insertion order of ``fields`` is used.

    Derived attributes (filled in by ``__post_init__``):
      field_names        -- field names in concatenation order.
      field_offsets      -- first global mixed DOF index of each field.
      field_slices       -- global DOF slice of each field in a mixed vector.
      elem_slices        -- element-local DOF slice of each field.
      elem_dofs_by_field -- per-field element connectivity, already shifted
                            by the field offset.
      elem_dofs          -- per-field connectivities concatenated along axis 1.
      n_dofs / n_ldofs   -- total global / element-local DOF counts.
    """
    fields: dict[str, FESpaceClosure]
    field_order: Sequence[str] | None = None
    field_names: tuple[str, ...] = field(init=False)
    field_offsets: dict[str, int] = field(init=False)
    field_slices: dict[str, slice] = field(init=False)
    elem_slices: dict[str, slice] = field(init=False)
    elem_dofs_by_field: dict[str, jnp.ndarray] = field(init=False)
    elem_dofs: jnp.ndarray = field(init=False)
    n_dofs: int = field(init=False)
    n_ldofs: int = field(init=False)

    def __post_init__(self):
        """
        Validate the field set and build the mixed DOF layout.

        Raises:
            ValueError: if no field is given, if ``field_order`` does not name
                exactly the keys of ``fields``, if ``field_order`` repeats a
                name, or if the fields do not share the same mesh object,
                basis class and element count.
        """
        if not self.fields:
            raise ValueError("MixedFESpace requires at least one field.")

        if self.field_order is None:
            self.field_names = tuple(self.fields.keys())
        else:
            self.field_names = tuple(self.field_order)
            missing = set(self.fields.keys()) - set(self.field_names)
            extra = set(self.field_names) - set(self.fields.keys())
            if missing or extra:
                raise ValueError(f"field_order mismatch: missing={missing}, extra={extra}")
            # The set comparison above cannot see repeated names.  A repeated
            # field would be laid out twice while its offset/slice entries are
            # overwritten, leaving n_dofs and elem_dofs inconsistent with
            # field_slices, so reject it explicitly.
            seen: set[str] = set()
            duplicates: list[str] = []
            for name in self.field_names:
                if name in seen and name not in duplicates:
                    duplicates.append(name)
                seen.add(name)
            if duplicates:
                raise ValueError(f"field_order contains duplicate names: {duplicates}")

        # The first field acts as the reference every other field is compared to.
        ref_space = self.fields[self.field_names[0]]
        ref_mesh = ref_space.mesh
        ref_basis = ref_space.basis
        n_elems = int(ref_space.elem_dofs.shape[0])

        offsets: dict[str, int] = {}
        slices: dict[str, slice] = {}
        elem_slices: dict[str, slice] = {}
        elem_dofs_by_field: dict[str, jnp.ndarray] = {}
        elem_dofs_list = []

        dof_offset = 0
        ldof_offset = 0
        for name in self.field_names:
            space = self.fields[name]
            if space.mesh is not ref_mesh:
                raise ValueError("All mixed fields must share the same mesh object.")
            if space.basis.__class__ is not ref_basis.__class__:
                raise ValueError("All mixed fields must share the same basis type.")
            if int(space.elem_dofs.shape[0]) != n_elems:
                raise ValueError("All mixed fields must have the same element count.")

            n_dofs = int(space.n_dofs)
            n_ldofs = int(space.n_ldofs)
            offsets[name] = dof_offset
            slices[name] = slice(dof_offset, dof_offset + n_dofs)
            elem_slices[name] = slice(ldof_offset, ldof_offset + n_ldofs)

            # Shift the field-local connectivity into mixed global numbering.
            elem_dofs = jnp.asarray(space.elem_dofs, dtype=INDEX_DTYPE) + dof_offset
            elem_dofs_by_field[name] = elem_dofs
            elem_dofs_list.append(elem_dofs)

            dof_offset += n_dofs
            ldof_offset += n_ldofs

        self.field_offsets = offsets
        self.field_slices = slices
        self.elem_slices = elem_slices
        self.elem_dofs_by_field = elem_dofs_by_field
        self.elem_dofs = jnp.concatenate(elem_dofs_list, axis=1)
        self.n_dofs = dof_offset
        self.n_ldofs = ldof_offset

    def pack_fields(self, fields: Mapping[str, jnp.ndarray]) -> jnp.ndarray:
        """
        Concatenate per-field vectors into a single mixed vector.

        Fields are taken in ``field_names`` order; keys of ``fields`` that are
        not mixed fields are ignored.

        Raises:
            KeyError: if a mixed field is absent from ``fields``.
        """
        parts = []
        for name in self.field_names:
            if name not in fields:
                raise KeyError(f"Missing field '{name}' in pack_fields.")
            parts.append(jnp.asarray(fields[name]))
        return jnp.concatenate(parts, axis=0)

    def unpack_fields(self, u: jnp.ndarray) -> dict[str, jnp.ndarray]:
        """Split a mixed vector into per-field vectors."""
        u = jnp.asarray(u)
        return {name: u[self.field_slices[name]] for name in self.field_names}

    def split_element_vector(self, u_elem: jnp.ndarray) -> dict[str, jnp.ndarray]:
        """Split an element-local mixed vector into per-field element vectors."""
        return {name: u_elem[self.elem_slices[name]] for name in self.field_names}

    def build_form_contexts(self, dep: jnp.ndarray | None = None) -> MixedFormContext:
        """
        Build a MixedFormContext from the per-field form contexts.

        Every field space builds its own contexts from ``dep``.  The quadrature
        data (``x_q``, ``w``, ``elem_id``) is taken from the first field in
        ``field_names``; the test/trial contexts are collected per field.
        """
        ctxs_by_field = {name: sp.build_form_contexts(dep) for name, sp in self.fields.items()}
        ref_ctx = ctxs_by_field[self.field_names[0]]

        fields = {
            name: FieldPair(test=ctx.test, trial=ctx.trial, unknown=None)
            for name, ctx in ctxs_by_field.items()
        }
        trial_fields = {name: ctx.trial for name, ctx in ctxs_by_field.items()}
        test_fields = {name: ctx.test for name, ctx in ctxs_by_field.items()}
        # The unknown fields reuse the trial contexts.
        unknown_fields = {name: ctx.trial for name, ctx in ctxs_by_field.items()}

        return MixedFormContext(
            fields=fields,
            x_q=ref_ctx.x_q,
            w=ref_ctx.w,
            elem_id=ref_ctx.elem_id,
            trial_fields=trial_fields,
            test_fields=test_fields,
            unknown_fields=unknown_fields,
        )

    def get_sparsity_pattern(self, *, with_idx: bool = True):
        """Return the sparsity pattern built by ``assembly.make_sparsity_pattern``."""
        # Imported lazily, as in the other assembly entry points of this class.
        from .assembly import make_sparsity_pattern
        return make_sparsity_pattern(self, with_idx=with_idx)

    def assemble_residual(self, res_form, u, params, **kwargs):
        """Assemble the mixed residual via ``mixed_assembly.assemble_mixed_residual``."""
        from .mixed_assembly import assemble_mixed_residual
        return assemble_mixed_residual(self, res_form, u, params, **kwargs)

    def assemble_jacobian(self, res_form, u, params, **kwargs):
        """Assemble the mixed Jacobian via ``mixed_assembly.assemble_mixed_jacobian``."""
        from .mixed_assembly import assemble_mixed_jacobian
        return assemble_mixed_jacobian(self, res_form, u, params, **kwargs)

    def make_dirichlet(self, *, merge: str = "check_equal", **fields):
        """
        Build mixed Dirichlet BCs from per-field constraints.

        Usage:
            bc = mixed.make_dirichlet(u=DirichletBC(...), T=(dofs, vals))

        Each keyword names a mixed field; its value is either a DirichletBC,
        a 2-tuple ``(dofs, vals)``, or anything else, which is passed to
        ``DirichletBC`` as the DOF list with ``vals=None``.  Field-local DOF
        indices are shifted by the field offset into mixed numbering.

        ``merge`` controls what happens when the same mixed DOF occurs twice:
            "error"       -- raise ValueError.
            "check_equal" -- raise ValueError unless the values agree
                             (``np.isclose``); otherwise keep the later value.
            "first"       -- keep the first value.
            "last"        -- keep the last value.

        Returns:
            MixedDirichletBC with DOFs sorted in ascending order.

        Raises:
            ValueError: for an unknown ``merge`` mode or a rejected duplicate.
            KeyError: if a keyword is not a field of this space.
        """
        if merge not in {"check_equal", "error", "first", "last"}:
            raise ValueError("merge must be one of: check_equal, error, first, last")

        dof_map: dict[int, float] = {}
        for name, spec in fields.items():
            if name not in self.field_offsets:
                raise KeyError(f"Unknown mixed field: {name}")
            offset = int(self.field_offsets[name])
            if isinstance(spec, DirichletBC):
                dofs = spec.dofs
                vals = spec.vals
            elif isinstance(spec, tuple) and len(spec) == 2:
                dofs, vals = spec
            else:
                dofs, vals = spec, None
            # Round-trip through DirichletBC so every spec form is normalized
            # by the same constructor before the arrays are read back.
            bc = DirichletBC(dofs, vals)
            g_dofs = np.asarray(bc.dofs, dtype=int) + offset
            g_vals = np.asarray(bc.vals, dtype=float)
            for d, v in zip(g_dofs, g_vals):
                if d in dof_map:
                    if merge == "error":
                        raise ValueError(f"Duplicate Dirichlet DOF {d} in mixed BCs")
                    if merge == "check_equal":
                        if not np.isclose(dof_map[d], v):
                            raise ValueError(f"Conflicting Dirichlet value for DOF {d}")
                    if merge == "first":
                        continue
                dof_map[d] = float(v)

        if not dof_map:
            return MixedDirichletBC(np.array([], dtype=int), np.array([], dtype=float))
        dofs_sorted = np.array(sorted(dof_map.keys()), dtype=int)
        vals_sorted = np.array([dof_map[d] for d in dofs_sorted], dtype=float)
        return MixedDirichletBC(dofs_sorted, vals_sorted)

    def build_block_system(
        self,
        *,
        diag: Mapping[str, object] | Sequence[object],
        rel: Mapping[tuple[str, str], object] | None = None,
        add_contiguous: object | None = None,
        rhs: Mapping[str, object] | Sequence[object] | np.ndarray | None = None,
        constraints=None,
        merge: str = "check_equal",
        format: str = "auto",
        symmetric: bool = False,
        transpose_rule: str = "T",
    ):
        """
        Build a mixed block system and apply optional constraints.

        All arguments are forwarded to ``solver.block_system.build_block_system``
        together with ``sizes``, the per-field DOF counts of this space.  A
        ``MixedDirichletBC`` passed as ``constraints`` is converted to a plain
        ``DirichletBC`` first.

        Returns:
            MixedBlockSystem wrapping the assembled ``K``, ``F``, the free DOFs
            and the Dirichlet data reported by the block-system builder.
        """
        from ..solver.block_system import build_block_system as _build_block_system

        sizes = {name: int(self.fields[name].n_dofs) for name in self.field_names}

        if isinstance(constraints, MixedDirichletBC):
            constraints = constraints.as_dirichlet_bc()

        system = _build_block_system(
            diag=diag,
            rel=rel,
            add_contiguous=add_contiguous,
            rhs=rhs,
            constraints=constraints,
            merge=merge,
            sizes=sizes,
            format=format,
            symmetric=symmetric,
            transpose_rule=transpose_rule,
        )
        bc = MixedDirichletBC(system.dirichlet.dofs, system.dirichlet.vals)
        return MixedBlockSystem(self, system.K, system.F, free_dofs=system.free_dofs, dirichlet=bc)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
@dataclass(eq=False)
class MixedProblem:
    """
    Lightweight wrapper for mixed residual assembly with cached compilation.

    The residual definition is compiled once in ``__post_init__`` and stored in
    ``_compiled``; the assemble methods reuse that kernel on every call.

    Attributes:
      space     -- mixed FE space used for assembly.
      residuals -- either a MixedWeakForm or a mapping of per-field residual
                   callables.
      params    -- default parameters; may be a callable ``params(ctx)``.
      pattern, n_chunks, pad_trace -- default assembly keyword arguments,
                   forwarded unless the caller overrides them.
    """
    space: MixedFESpace
    residuals: dict[str, Callable] | MixedWeakForm
    params: object | None = None
    pattern: object | None = None
    n_chunks: int | None = None
    pad_trace: bool = False
    _compiled: object = field(init=False, repr=False)

    def __post_init__(self):
        """Compile the residual definition into the element kernel."""
        if isinstance(self.residuals, MixedWeakForm):
            self._compiled = self.residuals.get_compiled()
        else:
            res = make_mixed_residuals(self.residuals)
            self._compiled = compile_mixed_residual(res)

    def _merge_kwargs(self, kwargs):
        """Return a copy of ``kwargs`` with this problem's defaults filled in."""
        merged = dict(kwargs)
        if self.pattern is not None and "pattern" not in merged:
            merged["pattern"] = self.pattern
        if self.n_chunks is not None and "n_chunks" not in merged:
            merged["n_chunks"] = self.n_chunks
        # pad_trace is only injected when enabled; a False default is left out.
        if self.pad_trace and "pad_trace" not in merged:
            merged["pad_trace"] = True
        return merged

    def _wrap_params(self, params):
        """
        Return ``(res_form, params)`` ready to hand to the space.

        When ``params`` is callable it is evaluated per context inside a
        wrapper kernel and ``None`` is returned as the parameter object;
        otherwise the compiled kernel and ``params`` are returned unchanged.
        """
        if callable(params):
            def _wrapped(ctx, u_elem, _params):
                return self._compiled(ctx, u_elem, params(ctx))

            # Carry over the marker attribute of the compiled kernel.
            _wrapped._includes_measure = getattr(self._compiled, "_includes_measure", False)
            return _wrapped, None
        return self._compiled, params

    def assemble_residual(self, u, *, params=None, **kwargs):
        """Assemble the residual at ``u``; ``params=None`` uses ``self.params``."""
        use_params = self.params if params is None else params
        res_form, use_params = self._wrap_params(use_params)
        return self.space.assemble_residual(
            res_form, u, use_params, **self._merge_kwargs(kwargs)
        )

    def assemble_jacobian(self, u, *, params=None, **kwargs):
        """Assemble the Jacobian at ``u``; ``params=None`` uses ``self.params``."""
        use_params = self.params if params is None else params
        res_form, use_params = self._wrap_params(use_params)
        return self.space.assemble_jacobian(
            res_form, u, use_params, **self._merge_kwargs(kwargs)
        )

    def with_params(self, params):
        """
        Return a copy of this problem that uses ``params`` as its default.

        The copy shares the already compiled kernel instead of recompiling the
        residuals, so swapping parameters keeps the compilation cached.
        """
        import copy

        # A shallow copy keeps space, residuals, assembly defaults and the
        # compiled kernel; only the default parameters differ.
        clone = copy.copy(self)
        clone.params = params
        return clone

    def solve(
        self,
        K,
        F,
        *,
        dirichlet=None,
        dirichlet_mode: str = "condense",
        solver=None,
        n_total: int | None = None,
    ):
        """
        Solve a mixed linear system with optional Dirichlet conditions.

        A ``LinearSolver`` is created when ``solver`` is None, and a
        ``MixedDirichletBC`` is converted to a plain ``DirichletBC`` before
        being passed to ``solver.solve``.
        """
        from ..solver import LinearSolver

        if solver is None:
            solver = LinearSolver()
        if isinstance(dirichlet, MixedDirichletBC):
            dirichlet = dirichlet.as_dirichlet_bc()
        return solver.solve(K, F, dirichlet=dirichlet, dirichlet_mode=dirichlet_mode, n_total=n_total)
|
|
308
|
+
|
|
309
|
+
# eq=False: the fields are arrays, so the field-wise __eq__/__hash__ that a
# frozen dataclass generates by default would raise (array truth value is
# ambiguous, ndarray is unhashable).  Identity semantics keep instances usable
# in comparisons, sets and dict keys.
@dataclass(frozen=True, eq=False)
class MixedDirichletBC:
    """
    Mixed-system Dirichlet BCs in global mixed DOF numbering.

    Attributes:
      dir_dofs -- constrained DOF indices.
      dir_vals -- prescribed values, one per entry of ``dir_dofs``.

    Instances compare and hash by identity.
    """
    dir_dofs: np.ndarray
    dir_vals: np.ndarray

    def as_dirichlet_bc(self) -> DirichletBC:
        """Return the same constraints as a plain ``DirichletBC``."""
        return DirichletBC(self.dir_dofs, self.dir_vals)

    def condense_system(self, A, F, *, check: bool = True):
        """Delegate to ``DirichletBC.condense_system`` for these constraints."""
        return self.as_dirichlet_bc().condense_system(A, F, check=check)

    def free_dofs(self, n_dofs: int) -> np.ndarray:
        """Return the free DOFs computed by the module-level ``free_dofs`` helper."""
        # Inside the method body the name resolves to the imported function,
        # not to this method.
        return free_dofs(n_dofs, self.dir_dofs)

    def expand_solution(self, u_free, *, free=None, n_total: int | None = None):
        """Delegate to ``DirichletBC.expand_solution`` for these constraints."""
        return self.as_dirichlet_bc().expand_solution(u_free, free=free, n_total=n_total)
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
# eq=False: ``free_dofs`` is an array, so the field-wise __eq__/__hash__ that a
# frozen dataclass generates by default would raise on comparison or hashing.
# Identity semantics are used instead.
@dataclass(frozen=True, eq=False)
class MixedBlockSystem:
    """
    Assembled mixed block system together with its constraint data.

    Attributes:
      mixed     -- the mixed space the system was built for.
      K, F      -- system matrix and right-hand side as returned by the
                   block-system builder.
      free_dofs -- indices of the unconstrained DOFs.
      dirichlet -- Dirichlet constraints in mixed numbering.

    Instances compare and hash by identity.
    """
    mixed: MixedFESpace
    K: object
    F: object
    free_dofs: np.ndarray
    dirichlet: MixedDirichletBC

    def expand(self, u_free):
        """Expand a free-DOF solution to the full mixed vector."""
        return self.dirichlet.expand_solution(u_free, free=self.free_dofs, n_total=self.mixed.n_dofs)

    def split(self, u_full: jnp.ndarray) -> dict[str, jnp.ndarray]:
        """Split a full mixed vector into per-field vectors."""
        return self.mixed.unpack_fields(u_full)

    def join(self, fields: Mapping[str, jnp.ndarray]) -> jnp.ndarray:
        """Concatenate per-field vectors into a full mixed vector."""
        return self.mixed.pack_fields(fields)
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
__all__ = ["MixedFESpace", "MixedProblem", "MixedDirichletBC", "MixedBlockSystem"]
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Mixed weak-form helpers that keep core assembly untouched.
|
|
3
|
+
|
|
4
|
+
This module provides a small convenience API to build and assemble mixed
|
|
5
|
+
weak forms using the existing MixedWeakForm compiler.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from typing import Callable, Mapping
|
|
11
|
+
|
|
12
|
+
import jax.numpy as jnp
|
|
13
|
+
|
|
14
|
+
from .mixed_space import MixedFESpace
|
|
15
|
+
from .weakform import MixedWeakForm
|
|
16
|
+
|
|
17
|
+
MixedResiduals = Mapping[str, Callable]
|
|
18
|
+
|
|
19
|
+
class MixedResidualForm:
    """Wrapper for mixed residuals to mirror the single-field ResidualForm API."""

    def __init__(self, residuals: MixedResiduals):
        # Keep a private shallow copy so later changes to the caller's
        # container do not affect this form.
        own_copy = {}
        own_copy.update(residuals)
        self.residuals = own_copy

    def get_compiled(self):
        """Compile the stored residuals into an element kernel."""
        kernel = compile_mixed_weak_form(self.residuals)
        return kernel
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def compile_mixed_weak_form(residuals: MixedResiduals):
    """
    Compile mixed weak-form residuals into an element kernel.

    The residuals must return Expr and include dOmega().
    """
    # The weak form receives its own copy of the mapping.
    weak_form = MixedWeakForm(residuals=dict(residuals))
    return weak_form.get_compiled()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _wrap_params(res_form, params):
|
|
39
|
+
if callable(params):
|
|
40
|
+
def _wrapped(ctx, u_elem, _params):
|
|
41
|
+
return res_form(ctx, u_elem, params(ctx))
|
|
42
|
+
|
|
43
|
+
_wrapped._includes_measure = getattr(res_form, "_includes_measure", False)
|
|
44
|
+
return _wrapped, None
|
|
45
|
+
return res_form, params
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def assemble_mixed_residual_wf(
    space: MixedFESpace,
    residuals: MixedResiduals | MixedWeakForm | MixedResidualForm | Callable,
    u: jnp.ndarray | Mapping[str, jnp.ndarray],
    params,
    **kwargs,
):
    """
    Assemble mixed residual from weak-form definitions.

    ``residuals`` may be a MixedWeakForm or MixedResidualForm (compiled via
    ``get_compiled``), a mapping of residual callables (compiled via
    ``compile_mixed_weak_form``), or any other object, which is used directly
    as the element kernel.
    """
    if isinstance(residuals, (MixedWeakForm, MixedResidualForm)):
        kernel = residuals.get_compiled()
    elif isinstance(residuals, Mapping):
        kernel = compile_mixed_weak_form(residuals)
    else:
        kernel = residuals

    # Callable params are folded into the kernel before assembly.
    kernel, bound_params = _wrap_params(kernel, params)
    return space.assemble_residual(kernel, u, bound_params, **kwargs)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def assemble_mixed_jacobian_wf(
    space: MixedFESpace,
    residuals: MixedResiduals | MixedWeakForm | MixedResidualForm | Callable,
    u: jnp.ndarray | Mapping[str, jnp.ndarray],
    params,
    **kwargs,
):
    """
    Assemble mixed Jacobian from weak-form definitions.

    ``residuals`` may be a MixedWeakForm or MixedResidualForm (compiled via
    ``get_compiled``), a mapping of residual callables (compiled via
    ``compile_mixed_weak_form``), or any other object, which is used directly
    as the element kernel.
    """
    if isinstance(residuals, (MixedWeakForm, MixedResidualForm)):
        kernel = residuals.get_compiled()
    elif isinstance(residuals, Mapping):
        kernel = compile_mixed_weak_form(residuals)
    else:
        kernel = residuals

    # Callable params are folded into the kernel before assembly.
    kernel, bound_params = _wrap_params(kernel, params)
    return space.assemble_jacobian(kernel, u, bound_params, **kwargs)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
__all__ = [
|
|
93
|
+
"MixedResidualForm",
|
|
94
|
+
"compile_mixed_weak_form",
|
|
95
|
+
"assemble_mixed_residual_wf",
|
|
96
|
+
"assemble_mixed_jacobian_wf",
|
|
97
|
+
]
|
fluxfem/core/solver.py
CHANGED
|
@@ -24,6 +24,8 @@ def coo_to_csr(rows: Any, cols: Any, data: Any, n_dofs: int):
|
|
|
24
24
|
return sp.csr_matrix((d, (r, c)), shape=(n_dofs, n_dofs))
|
|
25
25
|
|
|
26
26
|
|
|
27
|
+
|
|
28
|
+
|
|
27
29
|
def spdirect_solve_cpu(K: Any, F: jnp.ndarray, *, use_jax: bool = False) -> np.ndarray:
|
|
28
30
|
"""
|
|
29
31
|
Convert JAX arrays to NumPy/SciPy and solve K u = F with sparse solver.
|