fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. fluxfem/__init__.py +69 -13
  2. fluxfem/core/__init__.py +140 -53
  3. fluxfem/core/assembly.py +691 -97
  4. fluxfem/core/basis.py +75 -54
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/dtypes.py +9 -1
  7. fluxfem/core/forms.py +10 -0
  8. fluxfem/core/mixed_assembly.py +263 -0
  9. fluxfem/core/mixed_space.py +382 -0
  10. fluxfem/core/mixed_weakform.py +97 -0
  11. fluxfem/core/solver.py +2 -0
  12. fluxfem/core/space.py +315 -30
  13. fluxfem/core/weakform.py +821 -42
  14. fluxfem/helpers_wf.py +49 -0
  15. fluxfem/mesh/__init__.py +54 -2
  16. fluxfem/mesh/base.py +318 -9
  17. fluxfem/mesh/contact.py +841 -0
  18. fluxfem/mesh/dtypes.py +12 -0
  19. fluxfem/mesh/hex.py +17 -16
  20. fluxfem/mesh/io.py +9 -6
  21. fluxfem/mesh/mortar.py +3970 -0
  22. fluxfem/mesh/supermesh.py +318 -0
  23. fluxfem/mesh/surface.py +104 -26
  24. fluxfem/mesh/tet.py +16 -7
  25. fluxfem/physics/diffusion.py +3 -0
  26. fluxfem/physics/elasticity/hyperelastic.py +35 -3
  27. fluxfem/physics/elasticity/linear.py +22 -4
  28. fluxfem/physics/elasticity/stress.py +9 -5
  29. fluxfem/physics/operators.py +12 -5
  30. fluxfem/physics/postprocess.py +29 -3
  31. fluxfem/solver/__init__.py +47 -2
  32. fluxfem/solver/bc.py +38 -2
  33. fluxfem/solver/block_matrix.py +284 -0
  34. fluxfem/solver/block_system.py +477 -0
  35. fluxfem/solver/cg.py +150 -55
  36. fluxfem/solver/dirichlet.py +358 -5
  37. fluxfem/solver/history.py +15 -3
  38. fluxfem/solver/newton.py +260 -70
  39. fluxfem/solver/petsc.py +445 -0
  40. fluxfem/solver/preconditioner.py +109 -0
  41. fluxfem/solver/result.py +18 -0
  42. fluxfem/solver/solve_runner.py +208 -23
  43. fluxfem/solver/solver.py +35 -12
  44. fluxfem/solver/sparse.py +149 -15
  45. fluxfem/tools/jit.py +19 -7
  46. fluxfem/tools/timer.py +14 -12
  47. fluxfem/tools/visualizer.py +16 -4
  48. fluxfem-0.2.1.dist-info/METADATA +314 -0
  49. fluxfem-0.2.1.dist-info/RECORD +59 -0
  50. fluxfem-0.1.4.dist-info/METADATA +0 -127
  51. fluxfem-0.1.4.dist-info/RECORD +0 -48
  52. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  53. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,382 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING, TypeAlias, TypeVar, cast
5
+
6
+ import numpy as np
7
+ import jax.numpy as jnp
8
+
9
+ from .dtypes import INDEX_DTYPE
10
+ from .forms import MixedFormContext, FieldPair
11
+ from .weakform import MixedWeakForm, compile_mixed_residual, make_mixed_residuals
12
+ from ..solver.dirichlet import DirichletBC, free_dofs
13
+ from ..solver.sparse import FluxSparseMatrix
14
+ from .space import FESpaceClosure
15
+
16
+ P = TypeVar("P")
17
+
18
+ if TYPE_CHECKING:
19
+ from .assembly import JacobianReturn, LinearReturn
20
+
21
# Mapping from field name to a JAX array; used for both the per-field input
# and the per-field output of a mixed residual kernel.
_FieldArrays: TypeAlias = Mapping[str, jnp.ndarray]

# Signature of a mixed residual kernel:
#     (mixed form context, per-field arrays, params of type P) -> per-field arrays
MixedResidualForm: TypeAlias = Callable[[MixedFormContext, _FieldArrays, P], _FieldArrays]
25
+
26
+
27
@dataclass(eq=False)
class MixedFESpace:
    """
    Mixed FE space composed of multiple scalar/vector spaces.

    Global DOFs are concatenated in field order::

        [field0 dofs | field1 dofs | ...]

    Element-local DOFs follow the same order, so ``elem_dofs`` is the
    column-wise concatenation of every field's element DOF table, each
    shifted by that field's global offset (one row per element).

    All fields must share the same mesh object, the same basis class and the
    same element count.

    Derived attributes (set in ``__post_init__``):
        field_names: field names in DOF-layout order.
        field_offsets: first global DOF index of each field.
        field_slices: global DOF slice of each field.
        elem_slices: element-local DOF slice of each field.
        elem_dofs_by_field: per-field element DOF tables in global numbering.
        elem_dofs: concatenated element DOF table.
        n_dofs: total number of global DOFs.
        n_ldofs: total number of element-local DOFs.
    """

    fields: dict[str, FESpaceClosure]
    field_order: Sequence[str] | None = None
    field_names: tuple[str, ...] = field(init=False)
    field_offsets: dict[str, int] = field(init=False)
    field_slices: dict[str, slice] = field(init=False)
    elem_slices: dict[str, slice] = field(init=False)
    elem_dofs_by_field: dict[str, jnp.ndarray] = field(init=False)
    elem_dofs: jnp.ndarray = field(init=False)
    n_dofs: int = field(init=False)
    n_ldofs: int = field(init=False)

    def __post_init__(self):
        """
        Validate the fields and compute the mixed DOF layout.

        Raises:
            ValueError: if no fields are given, ``field_order`` does not name
                exactly the given fields (or repeats a name), or the fields
                disagree on mesh object, basis class or element count.
        """
        if not self.fields:
            raise ValueError("MixedFESpace requires at least one field.")

        if self.field_order is None:
            self.field_names = tuple(self.fields.keys())
        else:
            self.field_names = tuple(self.field_order)
            missing = set(self.fields.keys()) - set(self.field_names)
            extra = set(self.field_names) - set(self.fields.keys())
            if missing or extra:
                raise ValueError(f"field_order mismatch: missing={missing}, extra={extra}")
            # The set comparison above cannot see repeats; a repeated name
            # would be laid out twice and overwrite its own offsets/slices.
            if len(set(self.field_names)) != len(self.field_names):
                dupes = {n for n in self.field_names if self.field_names.count(n) > 1}
                raise ValueError(f"field_order contains duplicate names: {dupes}")

        ref_space = self.fields[self.field_names[0]]
        ref_mesh = ref_space.mesh
        ref_basis = ref_space.basis
        n_elems = int(ref_space.elem_dofs.shape[0])

        offsets: dict[str, int] = {}
        slices: dict[str, slice] = {}
        elem_slices: dict[str, slice] = {}
        elem_dofs_by_field: dict[str, jnp.ndarray] = {}
        elem_dofs_list = []

        dof_offset = 0
        ldof_offset = 0
        for name in self.field_names:
            space = self.fields[name]
            if space.mesh is not ref_mesh:
                raise ValueError("All mixed fields must share the same mesh object.")
            if space.basis.__class__ is not ref_basis.__class__:
                raise ValueError("All mixed fields must share the same basis type.")
            if int(space.elem_dofs.shape[0]) != n_elems:
                raise ValueError("All mixed fields must have the same element count.")

            n_dofs = int(space.n_dofs)
            n_ldofs = int(space.n_ldofs)
            offsets[name] = dof_offset
            slices[name] = slice(dof_offset, dof_offset + n_dofs)
            elem_slices[name] = slice(ldof_offset, ldof_offset + n_ldofs)

            # Shift the field's element DOFs into the global mixed numbering.
            elem_dofs = jnp.asarray(space.elem_dofs, dtype=INDEX_DTYPE) + dof_offset
            elem_dofs_by_field[name] = elem_dofs
            elem_dofs_list.append(elem_dofs)

            dof_offset += n_dofs
            ldof_offset += n_ldofs

        self.field_offsets = offsets
        self.field_slices = slices
        self.elem_slices = elem_slices
        self.elem_dofs_by_field = elem_dofs_by_field
        self.elem_dofs = jnp.concatenate(elem_dofs_list, axis=1)
        self.n_dofs = dof_offset
        self.n_ldofs = ldof_offset

    def pack_fields(self, fields: Mapping[str, jnp.ndarray]) -> jnp.ndarray:
        """
        Concatenate per-field vectors into a single mixed vector.

        Raises:
            KeyError: if a field of this space is missing from ``fields``.
        """
        parts = []
        for name in self.field_names:
            if name not in fields:
                raise KeyError(f"Missing field '{name}' in pack_fields.")
            parts.append(jnp.asarray(fields[name]))
        return jnp.concatenate(parts, axis=0)

    def unpack_fields(self, u: jnp.ndarray) -> dict[str, jnp.ndarray]:
        """Split a mixed vector into per-field vectors using ``field_slices``."""
        u = jnp.asarray(u)
        return {name: u[self.field_slices[name]] for name in self.field_names}

    def split_element_vector(self, u_elem: jnp.ndarray) -> dict[str, jnp.ndarray]:
        """Split an element-local mixed vector into per-field element vectors."""
        return {name: u_elem[self.elem_slices[name]] for name in self.field_names}

    def build_form_contexts(self, dep: jnp.ndarray | None = None) -> MixedFormContext:
        """
        Build a ``MixedFormContext`` from each field's own form contexts.

        ``dep`` is forwarded to every field's ``build_form_contexts``.
        Quadrature points, weights and element ids come from the first field
        in ``field_names`` (all fields share the mesh). Per-field dicts are
        ordered like the DOF layout. ``unknown_fields`` holds the trial
        contexts, and each ``FieldPair.unknown`` is left ``None``.
        """
        ctxs_by_field = {
            name: self.fields[name].build_form_contexts(dep) for name in self.field_names
        }
        ref_ctx = ctxs_by_field[self.field_names[0]]

        fields = {
            name: FieldPair(test=ctx.test, trial=ctx.trial, unknown=None)
            for name, ctx in ctxs_by_field.items()
        }
        trial_fields = {name: ctx.trial for name, ctx in ctxs_by_field.items()}
        test_fields = {name: ctx.test for name, ctx in ctxs_by_field.items()}
        unknown_fields = {name: ctx.trial for name, ctx in ctxs_by_field.items()}

        return MixedFormContext(
            fields=fields,
            x_q=ref_ctx.x_q,
            w=ref_ctx.w,
            elem_id=ref_ctx.elem_id,
            trial_fields=trial_fields,
            test_fields=test_fields,
            unknown_fields=unknown_fields,
        )

    def get_sparsity_pattern(self, *, with_idx: bool = True):
        """Return the sparsity pattern from ``assembly.make_sparsity_pattern`` for this space."""
        from .assembly import make_sparsity_pattern

        # The mixed space is passed where a single-field space is typed, hence the cast.
        return make_sparsity_pattern(cast(Any, self), with_idx=with_idx)

    def assemble_residual(
        self,
        res_form: MixedResidualForm[P],
        u: Mapping[str, jnp.ndarray] | Sequence[jnp.ndarray] | jnp.ndarray,
        params: P,
        **kwargs,
    ) -> "LinearReturn":
        """Assemble the mixed residual via ``mixed_assembly.assemble_mixed_residual``."""
        from .mixed_assembly import assemble_mixed_residual

        return assemble_mixed_residual(self, res_form, u, params, **kwargs)

    def assemble_jacobian(
        self,
        res_form: MixedResidualForm[P],
        u: Mapping[str, jnp.ndarray] | Sequence[jnp.ndarray] | jnp.ndarray,
        params: P,
        **kwargs,
    ) -> "JacobianReturn":
        """Assemble the mixed Jacobian via ``mixed_assembly.assemble_mixed_jacobian``."""
        from .mixed_assembly import assemble_mixed_jacobian

        return assemble_mixed_jacobian(self, res_form, u, params, **kwargs)

    def make_dirichlet(self, *, merge: str = "check_equal", **fields):
        """
        Build mixed Dirichlet BCs from per-field constraints.

        Usage:
            bc = mixed.make_dirichlet(u=DirichletBC(...), T=(dofs, vals))

        Each keyword names a field. Its value is a ``DirichletBC``, a
        ``(dofs, vals)`` tuple, or bare ``dofs`` (values left to
        ``DirichletBC``'s default). DOFs are field-local and are shifted by
        the field offset. A single value is applied to every listed DOF.

        Args:
            merge: policy for a global DOF constrained more than once:
                ``"check_equal"`` (error unless values are close),
                ``"error"``, ``"first"`` (keep first) or ``"last"`` (keep last).

        Returns:
            ``MixedDirichletBC`` with DOFs sorted ascending.

        Raises:
            ValueError: on an unknown ``merge``, a DOF outside the field's
                range, mismatched dofs/vals lengths, or a merge conflict.
            KeyError: on an unknown field name.
        """
        if merge not in {"check_equal", "error", "first", "last"}:
            raise ValueError("merge must be one of: check_equal, error, first, last")

        dof_map: dict[int, float] = {}
        for name, spec in fields.items():
            if name not in self.field_offsets:
                raise KeyError(f"Unknown mixed field: {name}")
            offset = int(self.field_offsets[name])
            n_field = int(self.field_slices[name].stop) - offset

            if isinstance(spec, DirichletBC):
                dofs, vals = spec.dofs, spec.vals
            elif isinstance(spec, tuple) and len(spec) == 2:
                dofs, vals = spec
            else:
                dofs, vals = spec, None
            bc = DirichletBC(dofs, vals)

            local_dofs = np.asarray(bc.dofs, dtype=int).ravel()
            local_vals = np.asarray(bc.vals, dtype=float).ravel()
            if local_vals.size == 1 and local_dofs.size != 1:
                # One value for many DOFs: apply it to all of them (zip() would
                # otherwise silently constrain only the first DOF).
                local_vals = np.full(local_dofs.shape, local_vals[0])
            if local_vals.size != local_dofs.size:
                raise ValueError(
                    f"Dirichlet spec for field '{name}' has {local_dofs.size} DOFs "
                    f"but {local_vals.size} values."
                )
            # A local DOF outside [0, n_field) would silently land in another
            # field's block once the offset is added.
            if local_dofs.size and (local_dofs.min() < 0 or local_dofs.max() >= n_field):
                raise ValueError(
                    f"Dirichlet DOF out of range for field '{name}': "
                    f"expected 0 <= dof < {n_field}."
                )

            global_dofs = local_dofs + offset
            for d, v in zip(global_dofs.tolist(), local_vals.tolist()):
                if d in dof_map:
                    if merge == "error":
                        raise ValueError(f"Duplicate Dirichlet DOF {d} in mixed BCs")
                    if merge == "check_equal" and not np.isclose(dof_map[d], v):
                        raise ValueError(f"Conflicting Dirichlet value for DOF {d}")
                    if merge == "first":
                        continue
                dof_map[d] = float(v)

        if not dof_map:
            return MixedDirichletBC(np.array([], dtype=int), np.array([], dtype=float))
        dofs_sorted = sorted(dof_map)
        return MixedDirichletBC(
            np.array(dofs_sorted, dtype=int),
            np.array([dof_map[d] for d in dofs_sorted], dtype=float),
        )

    def build_block_system(
        self,
        *,
        diag: Mapping[str, object] | Sequence[object],
        rel: Mapping[tuple[str, str], object] | None = None,
        add_contiguous: object | None = None,
        rhs: Mapping[str, object] | Sequence[object] | np.ndarray | None = None,
        constraints=None,
        merge: str = "check_equal",
        format: str = "auto",
        symmetric: bool = False,
        transpose_rule: str = "T",
    ):
        """
        Build a mixed block system and apply optional constraints.

        Block sizes are taken from each field's ``n_dofs`` in ``field_names``
        order. A ``MixedDirichletBC`` in ``constraints`` is converted to a
        plain ``DirichletBC`` before delegating to
        ``solver.block_system.build_block_system``.

        Returns:
            ``MixedBlockSystem`` wrapping the resulting ``K``/``F``, the free
            DOFs and the applied Dirichlet data.
        """
        from ..solver.block_system import build_block_system as _build_block_system

        sizes = {name: int(self.fields[name].n_dofs) for name in self.field_names}

        if isinstance(constraints, MixedDirichletBC):
            constraints = constraints.as_dirichlet_bc()

        system = _build_block_system(
            diag=diag,
            rel=rel,
            add_contiguous=add_contiguous,
            rhs=rhs,
            constraints=constraints,
            merge=merge,
            sizes=sizes,
            format=format,
            symmetric=symmetric,
            transpose_rule=transpose_rule,
        )
        bc = MixedDirichletBC(system.dirichlet.dofs, system.dirichlet.vals)
        return MixedBlockSystem(self, system.K, system.F, free_dofs=system.free_dofs, dirichlet=bc)
245
+
246
+
247
@dataclass(eq=False)
class MixedProblem:
    """
    Convenience front-end for mixed residual/Jacobian assembly.

    The residual kernel is compiled once, in ``__post_init__``. A
    ``MixedWeakForm`` supplies its own compiled kernel; a plain mapping of
    per-field residuals goes through ``make_mixed_residuals`` and then
    ``compile_mixed_residual``.

    ``pattern``, ``n_chunks`` and ``pad_trace`` act as per-problem defaults
    for the assembly keyword arguments. They are only injected when the
    caller does not pass the same keyword explicitly.
    """

    space: MixedFESpace
    residuals: dict[str, Callable] | MixedWeakForm
    params: object | None = None
    pattern: object | None = None
    n_chunks: int | None = None
    pad_trace: bool = False
    _compiled: Callable[..., Any] = field(init=False, repr=False)

    def __post_init__(self):
        residuals = self.residuals
        if not isinstance(residuals, MixedWeakForm):
            self._compiled = compile_mixed_residual(make_mixed_residuals(residuals))
            return
        self._compiled = residuals.get_compiled()

    def _merge_kwargs(self, kwargs):
        """Return a copy of ``kwargs`` with this problem's assembly defaults filled in."""
        merged = dict(kwargs)
        defaults = (
            ("pattern", self.pattern, self.pattern is not None),
            ("n_chunks", self.n_chunks, self.n_chunks is not None),
            ("pad_trace", True, bool(self.pad_trace)),
        )
        for key, value, active in defaults:
            if active:
                merged.setdefault(key, value)
        return merged

    def _wrap_params(self, params):
        """
        Adapt ``params`` to the kernel calling convention.

        Returns ``(kernel, params)``. A callable ``params`` is treated as a
        per-context provider: the returned wrapper evaluates ``params(ctx)``
        itself and ignores its own params argument, so ``None`` is returned
        as the params to forward.
        """
        if not callable(params):
            return self._compiled, params

        def _wrapped(ctx, u_elem, _params):
            return self._compiled(ctx, u_elem, params(ctx))

        # Carry the kernel's measure flag over to the wrapper (presumably read
        # by the assembly routines).
        _wrapped._includes_measure = getattr(self._compiled, "_includes_measure", False)  # type: ignore[attr-defined]
        return _wrapped, None

    def assemble_residual(
        self,
        u: Mapping[str, jnp.ndarray] | Sequence[jnp.ndarray] | jnp.ndarray,
        *,
        params: P | None = None,
        **kwargs,
    ) -> "LinearReturn":
        """Assemble the residual at ``u``; ``params=None`` falls back to ``self.params``."""
        effective = params if params is not None else self.params
        kernel, kernel_params = self._wrap_params(effective)
        merged = self._merge_kwargs(kwargs)
        return self.space.assemble_residual(kernel, u, kernel_params, **merged)

    def assemble_jacobian(
        self,
        u: Mapping[str, jnp.ndarray] | Sequence[jnp.ndarray] | jnp.ndarray,
        *,
        params: P | None = None,
        **kwargs,
    ) -> "JacobianReturn":
        """Assemble the Jacobian at ``u``; ``params=None`` falls back to ``self.params``."""
        effective = params if params is not None else self.params
        kernel, kernel_params = self._wrap_params(effective)
        merged = self._merge_kwargs(kwargs)
        return self.space.assemble_jacobian(kernel, u, kernel_params, **merged)

    def with_params(self, params):
        """Return a new problem identical except for ``params`` (its residuals are recompiled)."""
        return MixedProblem(
            self.space,
            self.residuals,
            params=params,
            pattern=self.pattern,
            n_chunks=self.n_chunks,
            pad_trace=self.pad_trace,
        )

    def solve(
        self,
        K,
        F,
        *,
        dirichlet=None,
        dirichlet_mode: str = "condense",
        solver=None,
        n_total: int | None = None,
    ):
        """
        Solve a mixed linear system with optional Dirichlet conditions.

        A default ``LinearSolver`` is created when ``solver`` is None. A
        ``MixedDirichletBC`` is converted to a plain ``DirichletBC`` first.
        """
        from ..solver import LinearSolver

        linear_solver = LinearSolver() if solver is None else solver
        bc = dirichlet.as_dirichlet_bc() if isinstance(dirichlet, MixedDirichletBC) else dirichlet
        return linear_solver.solve(K, F, dirichlet=bc, dirichlet_mode=dirichlet_mode, n_total=n_total)
342
+
343
@dataclass(frozen=True, eq=False)
class MixedDirichletBC:
    """
    Mixed-system Dirichlet BCs in global mixed DOF numbering.

    ``dir_dofs`` index the concatenated mixed vector (field offsets already
    applied) and ``dir_vals`` hold the matching prescribed values. Every
    solver-facing operation delegates to an equivalent ``DirichletBC``.

    ``eq=False``: both fields are NumPy arrays, so a dataclass-generated
    ``__eq__`` would raise on multi-element arrays (ambiguous truth value),
    and the generated ``__hash__`` would raise (ndarray is unhashable).
    Instances therefore compare and hash by identity, like ``MixedFESpace``
    and ``MixedProblem``.
    """

    dir_dofs: np.ndarray  # global mixed DOF indices
    dir_vals: np.ndarray  # prescribed values, aligned with dir_dofs

    def as_dirichlet_bc(self) -> DirichletBC:
        """Return the equivalent single-system ``DirichletBC``."""
        return DirichletBC(self.dir_dofs, self.dir_vals)

    def condense_system(self, A, F, *, check: bool = True):
        """Condense ``(A, F)`` via ``DirichletBC.condense_system``."""
        return self.as_dirichlet_bc().condense_system(A, F, check=check)

    def free_dofs(self, n_dofs: int) -> np.ndarray:
        """Return the DOFs in ``range(n_dofs)`` that are not constrained."""
        # Resolves to the module-level ``free_dofs`` helper, not this method:
        # class scope is not visible from inside method bodies.
        return free_dofs(n_dofs, self.dir_dofs)

    def expand_solution(self, u_free, *, free=None, n_total: int | None = None):
        """Expand a free-DOF solution to the full vector via ``DirichletBC.expand_solution``."""
        return self.as_dirichlet_bc().expand_solution(u_free, free=free, n_total=n_total)
362
+
363
+
364
@dataclass(frozen=True, eq=False)
class MixedBlockSystem:
    """
    Block system returned by ``MixedFESpace.build_block_system``.

    Holds the system operator ``K`` and right-hand side ``F`` as produced by
    the solver-level block builder. When ``free_dofs`` is a strict subset,
    these are presumably the constrained/condensed system — confirm against
    ``solver.block_system``. The instance also holds the free DOF indices and
    the applied Dirichlet data in mixed numbering.

    ``eq=False``: ``free_dofs`` is a NumPy array and ``K``/``F`` may be
    arrays or sparse matrices, so dataclass-generated ``__eq__``/``__hash__``
    would raise. Instances compare and hash by identity instead.
    """

    mixed: MixedFESpace
    K: object
    F: object
    free_dofs: np.ndarray
    dirichlet: MixedDirichletBC

    def expand(self, u_free):
        """Expand a solution on ``free_dofs`` to the full mixed vector (length ``mixed.n_dofs``)."""
        return self.dirichlet.expand_solution(u_free, free=self.free_dofs, n_total=self.mixed.n_dofs)

    def split(self, u_full: jnp.ndarray) -> dict[str, jnp.ndarray]:
        """Split a full mixed vector into per-field vectors."""
        return self.mixed.unpack_fields(u_full)

    def join(self, fields: Mapping[str, jnp.ndarray]) -> jnp.ndarray:
        """Concatenate per-field vectors into a full mixed vector."""
        return self.mixed.pack_fields(fields)
380
+
381
+
382
# Public API of this module.
__all__ = [
    "MixedFESpace",
    "MixedProblem",
    "MixedDirichletBC",
    "MixedBlockSystem",
]
@@ -0,0 +1,97 @@
1
+ """
2
+ Mixed weak-form helpers that keep core assembly untouched.
3
+
4
+ This module provides a small convenience API to build and assemble mixed
5
+ weak forms using the existing MixedWeakForm compiler.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Callable, Mapping
11
+
12
+ import jax.numpy as jnp
13
+
14
+ from .mixed_space import MixedFESpace
15
+ from .weakform import MixedWeakForm
16
+
17
# Per-field residual callables keyed by field name.
MixedResiduals = Mapping[str, Callable]


class MixedResidualForm:
    """
    Container for per-field mixed residuals, mirroring the single-field
    ``ResidualForm`` API.

    The residuals are shallow-copied into a new dict on construction, so
    later changes to the caller's mapping do not affect this form.
    """

    def __init__(self, residuals: MixedResiduals):
        self.residuals = dict(residuals)

    def get_compiled(self):
        """Compile the stored residuals into an element kernel (a fresh compile on every call)."""
        compiled = compile_mixed_weak_form(self.residuals)
        return compiled
27
+
28
+
29
def compile_mixed_weak_form(residuals: MixedResiduals):
    """
    Compile per-field mixed weak-form residuals into an element kernel.

    Each residual must return an ``Expr`` that already includes ``dOmega()``.
    A ``MixedWeakForm`` is built around a copy of ``residuals`` and its
    compiled kernel is returned.
    """
    weak_form = MixedWeakForm(residuals=dict(residuals))
    return weak_form.get_compiled()
36
+
37
+
38
+ def _wrap_params(res_form, params):
39
+ if callable(params):
40
+ def _wrapped(ctx, u_elem, _params):
41
+ return res_form(ctx, u_elem, params(ctx))
42
+
43
+ _wrapped._includes_measure = getattr(res_form, "_includes_measure", False) # type: ignore[attr-defined]
44
+ return _wrapped, None
45
+ return res_form, params
46
+
47
+
48
def _prepare_mixed_kernel(residuals, params):
    """
    Resolve a residual specification into an element kernel and adapt params.

    ``MixedWeakForm`` and ``MixedResidualForm`` supply their compiled kernel.
    A mapping of per-field residuals is compiled with
    ``compile_mixed_weak_form``. Anything else is taken to be an
    already-compiled kernel. A callable ``params`` is then folded into the
    kernel by ``_wrap_params``.

    Returns:
        ``(res_form, params)`` ready to hand to the space's assembler.
    """
    # MixedWeakForm is tested before Mapping, as in the original dispatch.
    if isinstance(residuals, (MixedWeakForm, MixedResidualForm)):
        res_form = residuals.get_compiled()
    elif isinstance(residuals, Mapping):
        res_form = compile_mixed_weak_form(residuals)
    else:
        res_form = residuals
    return _wrap_params(res_form, params)


def assemble_mixed_residual_wf(
    space: MixedFESpace,
    residuals: MixedResiduals | MixedWeakForm | MixedResidualForm | Callable,
    u: jnp.ndarray | Mapping[str, jnp.ndarray],
    params,
    **kwargs,
):
    """
    Assemble the mixed residual from weak-form definitions.

    Args:
        space: mixed space whose ``assemble_residual`` performs the assembly.
        residuals: a ``MixedWeakForm``, ``MixedResidualForm``, mapping of
            per-field residuals, or an already-compiled kernel.
        u: mixed unknown, as a packed vector or a per-field mapping.
        params: kernel parameters, or a callable ``params(ctx)`` evaluated
            inside the kernel (``None`` is then forwarded to the assembler).
        **kwargs: forwarded unchanged to ``space.assemble_residual``.
    """
    res_form, params = _prepare_mixed_kernel(residuals, params)
    return space.assemble_residual(res_form, u, params, **kwargs)


def assemble_mixed_jacobian_wf(
    space: MixedFESpace,
    residuals: MixedResiduals | MixedWeakForm | MixedResidualForm | Callable,
    u: jnp.ndarray | Mapping[str, jnp.ndarray],
    params,
    **kwargs,
):
    """
    Assemble the mixed Jacobian from weak-form definitions.

    Accepts the same ``residuals``/``u``/``params`` forms as
    ``assemble_mixed_residual_wf``. ``**kwargs`` are forwarded unchanged to
    ``space.assemble_jacobian``.
    """
    res_form, params = _prepare_mixed_kernel(residuals, params)
    return space.assemble_jacobian(res_form, u, params, **kwargs)
90
+
91
+
92
# Names exported by ``from ... import *``.
__all__ = ["MixedResidualForm", "compile_mixed_weak_form",
           "assemble_mixed_residual_wf", "assemble_mixed_jacobian_wf"]
fluxfem/core/solver.py CHANGED
@@ -24,6 +24,8 @@ def coo_to_csr(rows: Any, cols: Any, data: Any, n_dofs: int):
24
24
  return sp.csr_matrix((d, (r, c)), shape=(n_dofs, n_dofs))
25
25
 
26
26
 
27
+
28
+
27
29
  def spdirect_solve_cpu(K: Any, F: jnp.ndarray, *, use_jax: bool = False) -> np.ndarray:
28
30
  """
29
31
  Convert JAX arrays to NumPy/SciPy and solve K u = F with sparse solver.