fluxfem 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. fluxfem/__init__.py +68 -0
  2. fluxfem/core/__init__.py +115 -10
  3. fluxfem/core/assembly.py +676 -91
  4. fluxfem/core/basis.py +73 -52
  5. fluxfem/core/dtypes.py +9 -1
  6. fluxfem/core/forms.py +10 -0
  7. fluxfem/core/mixed_assembly.py +263 -0
  8. fluxfem/core/mixed_space.py +348 -0
  9. fluxfem/core/mixed_weakform.py +97 -0
  10. fluxfem/core/solver.py +2 -0
  11. fluxfem/core/space.py +262 -17
  12. fluxfem/core/weakform.py +768 -7
  13. fluxfem/helpers_wf.py +49 -0
  14. fluxfem/mesh/__init__.py +54 -2
  15. fluxfem/mesh/base.py +316 -7
  16. fluxfem/mesh/contact.py +825 -0
  17. fluxfem/mesh/dtypes.py +12 -0
  18. fluxfem/mesh/hex.py +17 -16
  19. fluxfem/mesh/io.py +6 -4
  20. fluxfem/mesh/mortar.py +3907 -0
  21. fluxfem/mesh/supermesh.py +316 -0
  22. fluxfem/mesh/surface.py +22 -4
  23. fluxfem/mesh/tet.py +10 -4
  24. fluxfem/physics/diffusion.py +3 -0
  25. fluxfem/physics/elasticity/hyperelastic.py +3 -0
  26. fluxfem/physics/elasticity/linear.py +9 -2
  27. fluxfem/solver/__init__.py +42 -2
  28. fluxfem/solver/bc.py +38 -2
  29. fluxfem/solver/block_matrix.py +132 -0
  30. fluxfem/solver/block_system.py +454 -0
  31. fluxfem/solver/cg.py +115 -33
  32. fluxfem/solver/dirichlet.py +334 -4
  33. fluxfem/solver/newton.py +237 -60
  34. fluxfem/solver/petsc.py +439 -0
  35. fluxfem/solver/preconditioner.py +106 -0
  36. fluxfem/solver/result.py +18 -0
  37. fluxfem/solver/solve_runner.py +168 -1
  38. fluxfem/solver/solver.py +12 -1
  39. fluxfem/solver/sparse.py +124 -9
  40. fluxfem-0.2.0.dist-info/METADATA +303 -0
  41. fluxfem-0.2.0.dist-info/RECORD +59 -0
  42. fluxfem-0.1.4.dist-info/METADATA +0 -127
  43. fluxfem-0.1.4.dist-info/RECORD +0 -48
  44. {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/LICENSE +0 -0
  45. {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,348 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Mapping, Sequence, Callable
5
+
6
+ import numpy as np
7
+ import jax.numpy as jnp
8
+
9
+ from .dtypes import INDEX_DTYPE
10
+ from .forms import MixedFormContext, FieldPair
11
+ from .weakform import MixedWeakForm, compile_mixed_residual, make_mixed_residuals
12
+ from ..solver.dirichlet import DirichletBC, free_dofs
13
+ from ..solver.sparse import FluxSparseMatrix
14
+ from .space import FESpaceClosure
15
+
16
+
17
@dataclass(eq=False)
class MixedFESpace:
    """
    Mixed FE space composed of multiple scalar/vector spaces.

    Field DOFs are concatenated in field order:
    [field0 dofs | field1 dofs | ...]

    Element-local DOFs follow the same order: ``elem_dofs`` is every field's
    element DOF table (shifted into mixed numbering) concatenated along axis 1.
    All fields must share the same mesh object, the same basis class and the
    same element count.

    Attributes derived in ``__post_init__``:
        field_names: Field names in concatenation order (``field_order`` when
            given, otherwise the insertion order of ``fields``).
        field_offsets: First mixed DOF index of each field.
        field_slices: Slice of each field inside a global mixed vector.
        elem_slices: Slice of each field inside an element-local mixed vector.
        elem_dofs_by_field: Per-field element DOF tables in mixed numbering.
        elem_dofs: ``elem_dofs_by_field`` concatenated along axis 1.
        n_dofs: Total number of mixed DOFs.
        n_ldofs: Number of element-local mixed DOFs.
    """
    fields: dict[str, FESpaceClosure]
    field_order: Sequence[str] | None = None
    field_names: tuple[str, ...] = field(init=False)
    field_offsets: dict[str, int] = field(init=False)
    field_slices: dict[str, slice] = field(init=False)
    elem_slices: dict[str, slice] = field(init=False)
    elem_dofs_by_field: dict[str, jnp.ndarray] = field(init=False)
    elem_dofs: jnp.ndarray = field(init=False)
    n_dofs: int = field(init=False)
    n_ldofs: int = field(init=False)

    def __post_init__(self):
        """Validate the fields and build the mixed DOF layout.

        Raises:
            ValueError: no fields, a bad ``field_order`` (duplicates, missing or
                unknown names), or fields that do not share mesh / basis class /
                element count.
        """
        if not self.fields:
            raise ValueError("MixedFESpace requires at least one field.")

        if self.field_order is None:
            self.field_names = tuple(self.fields.keys())
        else:
            self.field_names = tuple(self.field_order)
            # A repeated name passes the set-based missing/extra check below but
            # would lay out that field's DOFs twice and inflate n_dofs.
            duplicates = {
                name for name in self.field_names if self.field_names.count(name) > 1
            }
            if duplicates:
                raise ValueError(
                    f"field_order contains duplicate field names: {duplicates}"
                )
            missing = set(self.fields.keys()) - set(self.field_names)
            extra = set(self.field_names) - set(self.fields.keys())
            if missing or extra:
                raise ValueError(f"field_order mismatch: missing={missing}, extra={extra}")

        # The first field in layout order is the reference all others must match.
        ref_space = self.fields[self.field_names[0]]
        ref_mesh = ref_space.mesh
        ref_basis = ref_space.basis
        n_elems = int(ref_space.elem_dofs.shape[0])

        offsets: dict[str, int] = {}
        slices: dict[str, slice] = {}
        elem_slices: dict[str, slice] = {}
        elem_dofs_by_field: dict[str, jnp.ndarray] = {}
        elem_dofs_list = []

        dof_offset = 0
        ldof_offset = 0
        for name in self.field_names:
            space = self.fields[name]
            if space.mesh is not ref_mesh:
                raise ValueError("All mixed fields must share the same mesh object.")
            if space.basis.__class__ is not ref_basis.__class__:
                raise ValueError("All mixed fields must share the same basis type.")
            if int(space.elem_dofs.shape[0]) != n_elems:
                raise ValueError("All mixed fields must have the same element count.")

            n_dofs = int(space.n_dofs)
            n_ldofs = int(space.n_ldofs)
            offsets[name] = dof_offset
            slices[name] = slice(dof_offset, dof_offset + n_dofs)
            elem_slices[name] = slice(ldof_offset, ldof_offset + n_ldofs)

            # Shift the field's element DOF table into mixed numbering.
            elem_dofs = jnp.asarray(space.elem_dofs, dtype=INDEX_DTYPE) + dof_offset
            elem_dofs_by_field[name] = elem_dofs
            elem_dofs_list.append(elem_dofs)

            dof_offset += n_dofs
            ldof_offset += n_ldofs

        self.field_offsets = offsets
        self.field_slices = slices
        self.elem_slices = elem_slices
        self.elem_dofs_by_field = elem_dofs_by_field
        self.elem_dofs = jnp.concatenate(elem_dofs_list, axis=1)
        self.n_dofs = dof_offset
        self.n_ldofs = ldof_offset

    def pack_fields(self, fields: Mapping[str, jnp.ndarray]) -> jnp.ndarray:
        """Concatenate per-field vectors into a single mixed vector.

        Raises:
            KeyError: a field in ``field_names`` is missing from ``fields``.
        """
        parts = []
        for name in self.field_names:
            if name not in fields:
                raise KeyError(f"Missing field '{name}' in pack_fields.")
            parts.append(jnp.asarray(fields[name]))
        return jnp.concatenate(parts, axis=0)

    def unpack_fields(self, u: jnp.ndarray) -> dict[str, jnp.ndarray]:
        """Split a mixed vector into per-field vectors (via ``field_slices``)."""
        u = jnp.asarray(u)
        return {name: u[self.field_slices[name]] for name in self.field_names}

    def split_element_vector(self, u_elem: jnp.ndarray) -> dict[str, jnp.ndarray]:
        """Split an element-local mixed vector into per-field element vectors."""
        return {name: u_elem[self.elem_slices[name]] for name in self.field_names}

    def build_form_contexts(self, dep: jnp.ndarray | None = None) -> MixedFormContext:
        """Build a MixedFormContext from each field's own form contexts.

        Quadrature points, weights and element ids are taken from the first
        field in ``field_names``. Every per-field mapping is built in
        ``field_names`` order, so iteration order matches the DOF layout.
        """
        ctxs_by_field = {
            name: self.fields[name].build_form_contexts(dep) for name in self.field_names
        }
        ref_ctx = ctxs_by_field[self.field_names[0]]

        fields = {
            name: FieldPair(test=ctx.test, trial=ctx.trial, unknown=None)
            for name, ctx in ctxs_by_field.items()
        }
        trial_fields = {name: ctx.trial for name, ctx in ctxs_by_field.items()}
        test_fields = {name: ctx.test for name, ctx in ctxs_by_field.items()}
        unknown_fields = {name: ctx.trial for name, ctx in ctxs_by_field.items()}

        return MixedFormContext(
            fields=fields,
            x_q=ref_ctx.x_q,
            w=ref_ctx.w,
            elem_id=ref_ctx.elem_id,
            trial_fields=trial_fields,
            test_fields=test_fields,
            unknown_fields=unknown_fields,
        )

    def get_sparsity_pattern(self, *, with_idx: bool = True):
        """Return ``make_sparsity_pattern(self, with_idx=with_idx)``."""
        # Imported lazily (presumably to avoid an import cycle with .assembly).
        from .assembly import make_sparsity_pattern
        return make_sparsity_pattern(self, with_idx=with_idx)

    def assemble_residual(self, res_form, u, params, **kwargs):
        """Assemble the mixed residual via ``assemble_mixed_residual``."""
        from .mixed_assembly import assemble_mixed_residual
        return assemble_mixed_residual(self, res_form, u, params, **kwargs)

    def assemble_jacobian(self, res_form, u, params, **kwargs):
        """Assemble the mixed Jacobian via ``assemble_mixed_jacobian``."""
        from .mixed_assembly import assemble_mixed_jacobian
        return assemble_mixed_jacobian(self, res_form, u, params, **kwargs)

    def make_dirichlet(self, *, merge: str = "check_equal", **fields):
        """
        Build mixed Dirichlet BCs from per-field constraints.

        Each keyword names a field and gives its constraint in that field's
        local DOF numbering: a ``DirichletBC``, a ``(dofs, vals)`` tuple, or bare
        ``dofs`` (normalized through ``DirichletBC(dofs, None)``). Local DOFs are
        shifted by the field offset into mixed numbering.

        merge selects how a DOF constrained more than once is handled:
            "check_equal": keep it only if the values agree (``np.isclose``).
            "error": reject any duplicate.
            "first": keep the first value seen.
            "last": keep the last value seen.

        Usage:
            bc = mixed.make_dirichlet(u=DirichletBC(...), T=(dofs, vals))

        Returns:
            MixedDirichletBC with DOFs sorted ascending.

        Raises:
            ValueError: invalid ``merge``; field DOFs outside the field's range;
                dofs/vals length mismatch; duplicates rejected by ``merge``.
            KeyError: unknown field name.
        """
        if merge not in {"check_equal", "error", "first", "last"}:
            raise ValueError("merge must be one of: check_equal, error, first, last")

        dof_map: dict[int, float] = {}
        for name, spec in fields.items():
            if name not in self.field_offsets:
                raise KeyError(f"Unknown mixed field: {name}")
            offset = int(self.field_offsets[name])
            field_slice = self.field_slices[name]
            n_field = field_slice.stop - field_slice.start

            if isinstance(spec, DirichletBC):
                dofs = spec.dofs
                vals = spec.vals
            elif isinstance(spec, tuple) and len(spec) == 2:
                dofs, vals = spec
            else:
                dofs, vals = spec, None
            bc = DirichletBC(dofs, vals)

            local_dofs = np.asarray(bc.dofs, dtype=int).ravel()
            local_vals = np.asarray(bc.vals, dtype=float).ravel()
            # zip() would silently drop the unmatched tail.
            if local_dofs.size != local_vals.size:
                raise ValueError(
                    f"Dirichlet spec for field '{name}' has {local_dofs.size} dofs "
                    f"but {local_vals.size} values"
                )
            # Out-of-range local DOFs would otherwise be shifted into a
            # neighbouring field's block instead of failing.
            if local_dofs.size and (local_dofs.min() < 0 or local_dofs.max() >= n_field):
                raise ValueError(
                    f"Dirichlet DOFs for field '{name}' must lie in [0, {n_field}); "
                    f"got min={int(local_dofs.min())}, max={int(local_dofs.max())}"
                )

            for d, v in zip((local_dofs + offset).tolist(), local_vals.tolist()):
                if d in dof_map:
                    if merge == "error":
                        raise ValueError(f"Duplicate Dirichlet DOF {d} in mixed BCs")
                    if merge == "check_equal" and not np.isclose(dof_map[d], v):
                        raise ValueError(f"Conflicting Dirichlet value for DOF {d}")
                    if merge == "first":
                        continue
                # "last", "check_equal" (values agree) and new DOFs store v.
                dof_map[d] = v

        ordered = sorted(dof_map)
        return MixedDirichletBC(
            np.array(ordered, dtype=int),
            np.array([dof_map[d] for d in ordered], dtype=float),
        )

    def build_block_system(
        self,
        *,
        diag: Mapping[str, object] | Sequence[object],
        rel: Mapping[tuple[str, str], object] | None = None,
        add_contiguous: object | None = None,
        rhs: Mapping[str, object] | Sequence[object] | np.ndarray | None = None,
        constraints=None,
        merge: str = "check_equal",
        format: str = "auto",
        symmetric: bool = False,
        transpose_rule: str = "T",
    ):
        """
        Build a mixed block system and apply optional constraints.

        Block sizes are taken from each field's ``n_dofs`` in ``field_names``
        order. A MixedDirichletBC passed as ``constraints`` is converted to a
        plain DirichletBC first. Everything else is forwarded to
        ``solver.block_system.build_block_system``.

        Returns:
            MixedBlockSystem wrapping the builder's K, F, free DOFs and Dirichlet data.
        """
        from ..solver.block_system import build_block_system as _build_block_system

        sizes = {name: int(self.fields[name].n_dofs) for name in self.field_names}

        if isinstance(constraints, MixedDirichletBC):
            constraints = constraints.as_dirichlet_bc()

        system = _build_block_system(
            diag=diag,
            rel=rel,
            add_contiguous=add_contiguous,
            rhs=rhs,
            constraints=constraints,
            merge=merge,
            sizes=sizes,
            format=format,
            symmetric=symmetric,
            transpose_rule=transpose_rule,
        )
        bc = MixedDirichletBC(system.dirichlet.dofs, system.dirichlet.vals)
        return MixedBlockSystem(self, system.K, system.F, free_dofs=system.free_dofs, dirichlet=bc)
223
+
224
+
225
@dataclass(eq=False)
class MixedProblem:
    """
    Lightweight wrapper for mixed residual assembly with cached compilation.

    The element kernel is compiled once in ``__post_init__`` and reused by every
    assembly call. The stored ``pattern`` / ``n_chunks`` / ``pad_trace`` settings
    are forwarded to the space's assembly routines unless a call overrides them.
    """
    space: MixedFESpace
    residuals: dict[str, Callable] | MixedWeakForm
    params: object | None = None
    pattern: object | None = None
    n_chunks: int | None = None
    pad_trace: bool = False
    _compiled: object = field(init=False, repr=False)

    def __post_init__(self):
        source = self.residuals
        if isinstance(source, MixedWeakForm):
            # A MixedWeakForm knows how to compile itself.
            self._compiled = source.get_compiled()
            return
        # Otherwise: per-field residual callables to normalize, then compile.
        self._compiled = compile_mixed_residual(make_mixed_residuals(source))

    def _merge_kwargs(self, kwargs):
        """Return a copy of ``kwargs`` completed with this problem's defaults.

        Keys already present in ``kwargs`` always win. ``pattern`` and
        ``n_chunks`` are only supplied when not None, ``pad_trace`` only when truthy.
        """
        defaults = {}
        if self.pattern is not None:
            defaults["pattern"] = self.pattern
        if self.n_chunks is not None:
            defaults["n_chunks"] = self.n_chunks
        if self.pad_trace:
            defaults["pad_trace"] = True

        combined = dict(kwargs)
        for key, value in defaults.items():
            combined.setdefault(key, value)
        return combined

    def _wrap_params(self, params):
        """Return ``(kernel, params)`` ready to hand to the assembler.

        Non-callable params go through unchanged alongside the compiled kernel.
        A callable is instead evaluated per call as ``params(ctx)`` inside a
        wrapper kernel, and ``None`` is passed as the assembler-level params.
        """
        if not callable(params):
            return self._compiled, params

        def _wrapped(ctx, u_elem, _params):
            # The params slot supplied by the assembler is ignored on purpose.
            return self._compiled(ctx, u_elem, params(ctx))

        _wrapped._includes_measure = getattr(self._compiled, "_includes_measure", False)
        return _wrapped, None

    def assemble_residual(self, u, *, params=None, **kwargs):
        """Assemble the mixed residual at ``u`` (``params`` defaults to ``self.params``)."""
        kernel, kernel_params = self._wrap_params(self.params if params is None else params)
        options = self._merge_kwargs(kwargs)
        return self.space.assemble_residual(kernel, u, kernel_params, **options)

    def assemble_jacobian(self, u, *, params=None, **kwargs):
        """Assemble the mixed Jacobian at ``u`` (``params`` defaults to ``self.params``)."""
        kernel, kernel_params = self._wrap_params(self.params if params is None else params)
        options = self._merge_kwargs(kwargs)
        return self.space.assemble_jacobian(kernel, u, kernel_params, **options)

    def with_params(self, params):
        """Return a new MixedProblem identical to this one except for ``params``.

        The new instance runs ``__post_init__`` again, so the kernel is recompiled.
        """
        return MixedProblem(
            space=self.space,
            residuals=self.residuals,
            params=params,
            pattern=self.pattern,
            n_chunks=self.n_chunks,
            pad_trace=self.pad_trace,
        )

    def solve(
        self,
        K,
        F,
        *,
        dirichlet=None,
        dirichlet_mode: str = "condense",
        solver=None,
        n_total: int | None = None,
    ):
        """
        Solve a mixed linear system with optional Dirichlet conditions.

        A MixedDirichletBC is converted to a plain DirichletBC. When ``solver``
        is None, a default ``LinearSolver()`` is used.
        """
        from ..solver import LinearSolver

        backend = LinearSolver() if solver is None else solver
        if isinstance(dirichlet, MixedDirichletBC):
            dirichlet = dirichlet.as_dirichlet_bc()
        return backend.solve(
            K, F, dirichlet=dirichlet, dirichlet_mode=dirichlet_mode, n_total=n_total
        )
308
+
309
@dataclass(frozen=True, eq=False)
class MixedDirichletBC:
    """
    Mixed-system Dirichlet BCs in global mixed DOF numbering.

    ``dir_dofs`` and ``dir_vals`` are paired element-wise. Constraint
    operations are delegated to an equivalent ``DirichletBC`` built from them.

    Equality is value-based, using ``np.array_equal`` on both arrays. A
    dataclass-generated ``__eq__`` would compare the ndarray fields through a
    tuple and raise for arrays with more than one element. Instances are
    unhashable because the fields are mutable arrays.
    """
    dir_dofs: np.ndarray
    dir_vals: np.ndarray

    def __eq__(self, other):
        if not isinstance(other, MixedDirichletBC):
            return NotImplemented
        return bool(
            np.array_equal(self.dir_dofs, other.dir_dofs)
            and np.array_equal(self.dir_vals, other.dir_vals)
        )

    __hash__ = None

    def as_dirichlet_bc(self) -> DirichletBC:
        """Return the equivalent single-system ``DirichletBC``."""
        return DirichletBC(self.dir_dofs, self.dir_vals)

    def condense_system(self, A, F, *, check: bool = True):
        """Delegate to ``DirichletBC.condense_system`` on the equivalent BC."""
        return self.as_dirichlet_bc().condense_system(A, F, check=check)

    def free_dofs(self, n_dofs: int) -> np.ndarray:
        """Return the module-level ``free_dofs(n_dofs, self.dir_dofs)``."""
        # Resolves to the imported function, not this method (method bodies do
        # not see the class namespace).
        return free_dofs(n_dofs, self.dir_dofs)

    def expand_solution(self, u_free, *, free=None, n_total: int | None = None):
        """Delegate to ``DirichletBC.expand_solution`` on the equivalent BC."""
        return self.as_dirichlet_bc().expand_solution(u_free, free=free, n_total=n_total)
328
+
329
+
330
@dataclass(frozen=True, eq=False)
class MixedBlockSystem:
    """
    Mixed block system as returned by ``MixedFESpace.build_block_system``.

    Attributes:
        mixed: The mixed space the system was built on.
        K: System matrix produced by the block-system builder.
        F: Right-hand side produced by the block-system builder.
        free_dofs: Free DOF indices from the builder; passed as ``free=`` when
            expanding a reduced solution.
        dirichlet: Dirichlet constraints in mixed numbering.

    Uses identity equality (``eq=False``), like the other mixed containers. A
    field-wise ``==`` on matrices/arrays does not yield a boolean, so a
    generated ``__eq__`` would raise.
    """
    mixed: MixedFESpace
    K: object
    F: object
    free_dofs: np.ndarray
    dirichlet: MixedDirichletBC

    def expand(self, u_free):
        """Expand a free-DOF solution to the full mixed vector of length ``mixed.n_dofs``."""
        return self.dirichlet.expand_solution(u_free, free=self.free_dofs, n_total=self.mixed.n_dofs)

    def split(self, u_full: jnp.ndarray) -> dict[str, jnp.ndarray]:
        """Split a full mixed vector into per-field vectors."""
        return self.mixed.unpack_fields(u_full)

    def join(self, fields: Mapping[str, jnp.ndarray]) -> jnp.ndarray:
        """Concatenate per-field vectors into a full mixed vector."""
        return self.mixed.pack_fields(fields)
346
+
347
+
348
# Public API of the mixed-space module.
__all__ = [
    "MixedFESpace",
    "MixedProblem",
    "MixedDirichletBC",
    "MixedBlockSystem",
]
@@ -0,0 +1,97 @@
1
+ """
2
+ Mixed weak-form helpers that keep core assembly untouched.
3
+
4
+ This module provides a small convenience API to build and assemble mixed
5
+ weak forms using the existing MixedWeakForm compiler.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Callable, Mapping
11
+
12
+ import jax.numpy as jnp
13
+
14
+ from .mixed_space import MixedFESpace
15
+ from .weakform import MixedWeakForm
16
+
17
# Type alias: field name -> residual callable for that field.
MixedResiduals = Mapping[
    str,
    Callable,
]
18
+
19
class MixedResidualForm:
    """
    Wrapper for mixed residuals to mirror the single-field ResidualForm API.

    ``residuals`` holds a shallow copy of the mapping passed in, so later edits
    to the caller's mapping do not leak in. The compiled kernel is cached and
    rebuilt only when ``self.residuals`` no longer equals the snapshot it was
    compiled from. Repeated assembly calls therefore reuse one kernel object
    instead of recompiling each time.
    """

    def __init__(self, residuals: MixedResiduals):
        self.residuals = dict(residuals)
        self._compiled = None
        # Snapshot of the residuals the cached kernel was compiled from.
        self._compiled_for = None

    def get_compiled(self):
        """Return the compiled element kernel for ``self.residuals`` (cached)."""
        # getattr keeps instances created without __init__ working.
        compiled_for = getattr(self, "_compiled_for", None)
        # Residual callables compare by identity, so this detects added,
        # removed or replaced entries as well as reassignment of self.residuals.
        if compiled_for is None or compiled_for != self.residuals:
            snapshot = dict(self.residuals)
            self._compiled = compile_mixed_weak_form(snapshot)
            self._compiled_for = snapshot
        return self._compiled
27
+
28
+
29
def compile_mixed_weak_form(residuals: MixedResiduals):
    """
    Compile mixed weak-form residuals into an element kernel.

    The residuals must return Expr and include dOmega().
    A shallow copy of ``residuals`` is wrapped in a ``MixedWeakForm``, and that
    form's compiled kernel is returned.
    """
    weak_form = MixedWeakForm(residuals=dict(residuals))
    kernel = weak_form.get_compiled()
    return kernel
36
+
37
+
38
+ def _wrap_params(res_form, params):
39
+ if callable(params):
40
+ def _wrapped(ctx, u_elem, _params):
41
+ return res_form(ctx, u_elem, params(ctx))
42
+
43
+ _wrapped._includes_measure = getattr(res_form, "_includes_measure", False)
44
+ return _wrapped, None
45
+ return res_form, params
46
+
47
+
48
def _resolve_res_form(residuals):
    """Turn any accepted residual spec into an element kernel.

    MixedWeakForm / MixedResidualForm instances are compiled via
    ``get_compiled()``. Mappings of per-field residuals are compiled with
    ``compile_mixed_weak_form``. Anything else is assumed to already be a
    compiled kernel and is returned unchanged.
    """
    if isinstance(residuals, (MixedWeakForm, MixedResidualForm)):
        return residuals.get_compiled()
    if isinstance(residuals, Mapping):
        return compile_mixed_weak_form(residuals)
    return residuals


def assemble_mixed_residual_wf(
    space: MixedFESpace,
    residuals: MixedResiduals | MixedWeakForm | MixedResidualForm | Callable,
    u: jnp.ndarray | Mapping[str, jnp.ndarray],
    params,
    **kwargs,
):
    """
    Assemble mixed residual from weak-form definitions.

    ``residuals`` may be a MixedWeakForm, a MixedResidualForm, a mapping of
    per-field residuals, or an already compiled kernel. A callable ``params``
    is evaluated per element context as ``params(ctx)``. Extra keyword
    arguments are forwarded to ``space.assemble_residual``.
    """
    res_form, params = _wrap_params(_resolve_res_form(residuals), params)
    return space.assemble_residual(res_form, u, params, **kwargs)


def assemble_mixed_jacobian_wf(
    space: MixedFESpace,
    residuals: MixedResiduals | MixedWeakForm | MixedResidualForm | Callable,
    u: jnp.ndarray | Mapping[str, jnp.ndarray],
    params,
    **kwargs,
):
    """
    Assemble mixed Jacobian from weak-form definitions.

    Accepts the same ``residuals`` and ``params`` forms as
    ``assemble_mixed_residual_wf``. Extra keyword arguments are forwarded to
    ``space.assemble_jacobian``.
    """
    res_form, params = _wrap_params(_resolve_res_form(residuals), params)
    return space.assemble_jacobian(res_form, u, params, **kwargs)
90
+
91
+
92
# Public API of the mixed weak-form helper module.
__all__ = [
    "MixedResidualForm", "compile_mixed_weak_form",
    "assemble_mixed_residual_wf", "assemble_mixed_jacobian_wf",
]
+ ]
fluxfem/core/solver.py CHANGED
@@ -24,6 +24,8 @@ def coo_to_csr(rows: Any, cols: Any, data: Any, n_dofs: int):
24
24
  return sp.csr_matrix((d, (r, c)), shape=(n_dofs, n_dofs))
25
25
 
26
26
 
27
+
28
+
27
29
  def spdirect_solve_cpu(K: Any, F: jnp.ndarray, *, use_jax: bool = False) -> np.ndarray:
28
30
  """
29
31
  Convert JAX arrays to NumPy/SciPy and solve K u = F with sparse solver.