fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. fluxfem/__init__.py +69 -13
  2. fluxfem/core/__init__.py +140 -53
  3. fluxfem/core/assembly.py +691 -97
  4. fluxfem/core/basis.py +75 -54
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/dtypes.py +9 -1
  7. fluxfem/core/forms.py +10 -0
  8. fluxfem/core/mixed_assembly.py +263 -0
  9. fluxfem/core/mixed_space.py +382 -0
  10. fluxfem/core/mixed_weakform.py +97 -0
  11. fluxfem/core/solver.py +2 -0
  12. fluxfem/core/space.py +315 -30
  13. fluxfem/core/weakform.py +821 -42
  14. fluxfem/helpers_wf.py +49 -0
  15. fluxfem/mesh/__init__.py +54 -2
  16. fluxfem/mesh/base.py +318 -9
  17. fluxfem/mesh/contact.py +841 -0
  18. fluxfem/mesh/dtypes.py +12 -0
  19. fluxfem/mesh/hex.py +17 -16
  20. fluxfem/mesh/io.py +9 -6
  21. fluxfem/mesh/mortar.py +3970 -0
  22. fluxfem/mesh/supermesh.py +318 -0
  23. fluxfem/mesh/surface.py +104 -26
  24. fluxfem/mesh/tet.py +16 -7
  25. fluxfem/physics/diffusion.py +3 -0
  26. fluxfem/physics/elasticity/hyperelastic.py +35 -3
  27. fluxfem/physics/elasticity/linear.py +22 -4
  28. fluxfem/physics/elasticity/stress.py +9 -5
  29. fluxfem/physics/operators.py +12 -5
  30. fluxfem/physics/postprocess.py +29 -3
  31. fluxfem/solver/__init__.py +47 -2
  32. fluxfem/solver/bc.py +38 -2
  33. fluxfem/solver/block_matrix.py +284 -0
  34. fluxfem/solver/block_system.py +477 -0
  35. fluxfem/solver/cg.py +150 -55
  36. fluxfem/solver/dirichlet.py +358 -5
  37. fluxfem/solver/history.py +15 -3
  38. fluxfem/solver/newton.py +260 -70
  39. fluxfem/solver/petsc.py +445 -0
  40. fluxfem/solver/preconditioner.py +109 -0
  41. fluxfem/solver/result.py +18 -0
  42. fluxfem/solver/solve_runner.py +208 -23
  43. fluxfem/solver/solver.py +35 -12
  44. fluxfem/solver/sparse.py +149 -15
  45. fluxfem/tools/jit.py +19 -7
  46. fluxfem/tools/timer.py +14 -12
  47. fluxfem/tools/visualizer.py +16 -4
  48. fluxfem-0.2.1.dist-info/METADATA +314 -0
  49. fluxfem-0.2.1.dist-info/RECORD +59 -0
  50. fluxfem-0.1.4.dist-info/METADATA +0 -127
  51. fluxfem-0.1.4.dist-info/RECORD +0 -48
  52. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  53. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/solver/sparse.py CHANGED
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
+ from typing import Any, Iterable, Sequence, TYPE_CHECKING, TypeAlias
4
5
 
5
6
  import numpy as np
6
7
  import jax
@@ -12,6 +13,90 @@ except Exception: # pragma: no cover
12
13
  sp = None
13
14
 
14
15
 
16
+ if TYPE_CHECKING:
17
+ from jax import Array as JaxArray
18
+
19
+ ArrayLike: TypeAlias = np.ndarray | JaxArray
20
+ else:
21
+ ArrayLike: TypeAlias = np.ndarray
22
+ COOTuple: TypeAlias = tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int]
23
+
24
+
25
+ def coalesce_coo(
26
+ rows: ArrayLike, cols: ArrayLike, data: ArrayLike
27
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
28
+ """
29
+ Sum duplicate COO entries by sorting (CPU-friendly).
30
+ Returns (rows_u, cols_u, data_u) as NumPy arrays.
31
+ """
32
+ r = np.asarray(rows, dtype=np.int64)
33
+ c = np.asarray(cols, dtype=np.int64)
34
+ d = np.asarray(data)
35
+ if r.size == 0:
36
+ return r, c, d
37
+ order = np.lexsort((c, r))
38
+ r_s = r[order]
39
+ c_s = c[order]
40
+ d_s = d[order]
41
+ new_group = np.ones(r_s.size, dtype=bool)
42
+ new_group[1:] = (r_s[1:] != r_s[:-1]) | (c_s[1:] != c_s[:-1])
43
+ starts = np.nonzero(new_group)[0]
44
+ r_u = r_s[starts]
45
+ c_u = c_s[starts]
46
+ d_u = np.add.reduceat(d_s, starts)
47
+ return r_u, c_u, d_u
48
+
49
+
50
+ def _normalize_flux_mats(mats: Sequence["FluxSparseMatrix"]) -> tuple["FluxSparseMatrix", ...]:
51
+ if len(mats) == 1 and isinstance(mats[0], (list, tuple)):
52
+ mats = tuple(mats[0])
53
+ if not mats:
54
+ raise ValueError("At least one FluxSparseMatrix is required.")
55
+ return mats
56
+
57
+
58
+ def concat_flux(*mats: "FluxSparseMatrix", n_dofs: int | None = None) -> "FluxSparseMatrix":
59
+ """
60
+ Concatenate COO entries from multiple FluxSparseMatrix objects.
61
+ All matrices must share the same n_dofs unless n_dofs is provided.
62
+ """
63
+ mats = _normalize_flux_mats(mats)
64
+ if n_dofs is None:
65
+ n_dofs = int(mats[0].n_dofs)
66
+ for mat in mats[1:]:
67
+ if int(mat.n_dofs) != n_dofs:
68
+ raise ValueError("All matrices must share n_dofs for concat_flux.")
69
+ rows_list = [np.asarray(mat.pattern.rows, dtype=np.int32) for mat in mats]
70
+ cols_list = [np.asarray(mat.pattern.cols, dtype=np.int32) for mat in mats]
71
+ data_list = [np.asarray(mat.data) for mat in mats]
72
+ rows = np.concatenate(rows_list) if rows_list else np.asarray([], dtype=np.int32)
73
+ cols = np.concatenate(cols_list) if cols_list else np.asarray([], dtype=np.int32)
74
+ data = np.concatenate(data_list) if data_list else np.asarray([], dtype=float)
75
+ return FluxSparseMatrix(rows, cols, data, int(n_dofs))
76
+
77
+
78
+ def block_diag_flux(*mats: "FluxSparseMatrix") -> "FluxSparseMatrix":
79
+ """Block-diagonal concatenation for FluxSparseMatrix objects."""
80
+ mats = _normalize_flux_mats(mats)
81
+ rows_out = []
82
+ cols_out = []
83
+ data_out = []
84
+ offset = 0
85
+ for mat in mats:
86
+ rows = np.asarray(mat.pattern.rows, dtype=np.int32)
87
+ cols = np.asarray(mat.pattern.cols, dtype=np.int32)
88
+ data = np.asarray(mat.data)
89
+ if rows.size:
90
+ rows_out.append(rows + offset)
91
+ cols_out.append(cols + offset)
92
+ data_out.append(data)
93
+ offset += int(mat.n_dofs)
94
+ rows = np.concatenate(rows_out) if rows_out else np.asarray([], dtype=np.int32)
95
+ cols = np.concatenate(cols_out) if cols_out else np.asarray([], dtype=np.int32)
96
+ data = np.concatenate(data_out) if data_out else np.asarray([], dtype=float)
97
+ return FluxSparseMatrix(rows, cols, data, int(offset))
98
+
99
+
15
100
  @jax.tree_util.register_pytree_node_class
16
101
  @dataclass(frozen=True)
17
102
  class SparsityPattern:
@@ -80,7 +165,14 @@ class FluxSparseMatrix:
80
165
  - data stores the numeric values for the current nonlinear iterate
81
166
  """
82
167
 
83
- def __init__(self, rows_or_pattern, cols=None, data=None, n_dofs: int | None = None):
168
+ def __init__(
169
+ self,
170
+ rows_or_pattern: SparsityPattern | ArrayLike,
171
+ cols: ArrayLike | None = None,
172
+ data: ArrayLike | None = None,
173
+ n_dofs: int | None = None,
174
+ meta: dict | None = None,
175
+ ):
84
176
  # New signature: FluxSparseMatrix(pattern, data)
85
177
  if isinstance(rows_or_pattern, SparsityPattern):
86
178
  pattern = rows_or_pattern
@@ -88,15 +180,22 @@ class FluxSparseMatrix:
88
180
  values = jnp.asarray(values)
89
181
  else:
90
182
  # Legacy signature: FluxSparseMatrix(rows, cols, data, n_dofs)
91
- r_np = np.asarray(rows_or_pattern, dtype=np.int32)
92
- c_np = np.asarray(cols, dtype=np.int32)
93
- diag_idx_np = np.nonzero(r_np == c_np)[0].astype(np.int32)
183
+ r_j = jnp.asarray(rows_or_pattern, dtype=jnp.int32)
184
+ c_j = jnp.asarray(cols, dtype=jnp.int32)
185
+ is_tracer = isinstance(rows_or_pattern, jax.core.Tracer) or isinstance(cols, jax.core.Tracer)
186
+ diag_idx_j = None
187
+ if not is_tracer:
188
+ diag_idx_j = jnp.nonzero(r_j == c_j)[0].astype(jnp.int32)
189
+ if n_dofs is None:
190
+ if is_tracer:
191
+ raise ValueError("n_dofs must be provided when constructing FluxSparseMatrix under JIT.")
192
+ n_dofs = int(np.asarray(cols).max()) + 1
94
193
  pattern = SparsityPattern(
95
- rows=jnp.asarray(r_np),
96
- cols=jnp.asarray(c_np),
97
- n_dofs=int(n_dofs) if n_dofs is not None else int(c_np.max()) + 1,
194
+ rows=r_j,
195
+ cols=c_j,
196
+ n_dofs=int(n_dofs) if n_dofs is not None else int(np.asarray(cols).max()) + 1,
98
197
  idx=None,
99
- diag_idx=jnp.asarray(diag_idx_np),
198
+ diag_idx=diag_idx_j,
100
199
  )
101
200
  values = jnp.asarray(data)
102
201
 
@@ -105,27 +204,42 @@ class FluxSparseMatrix:
105
204
  self.cols = pattern.cols
106
205
  self.n_dofs = int(pattern.n_dofs)
107
206
  self.data = values
207
+ self.meta = dict(meta) if meta is not None else None
108
208
 
109
209
  @classmethod
110
- def from_bilinear(cls, coo_tuple):
210
+ def from_bilinear(cls, coo_tuple: COOTuple) -> "FluxSparseMatrix":
111
211
  """Construct from assemble_bilinear_dense(..., sparse=True)."""
112
212
  rows, cols, data, n_dofs = coo_tuple
113
213
  return cls(rows, cols, data, n_dofs)
114
214
 
115
215
  @classmethod
116
- def from_linear(cls, coo_tuple):
216
+ def from_linear(cls, coo_tuple: tuple[jnp.ndarray, jnp.ndarray, int]) -> "FluxSparseMatrix":
117
217
  """Construct from assemble_linear_form(..., sparse=True) (matrix interpretation only)."""
118
218
  rows, data, n_dofs = coo_tuple
119
219
  cols = jnp.zeros_like(rows)
120
220
  return cls(rows, cols, data, n_dofs)
121
221
 
122
- def with_data(self, data):
222
+ def with_data(self, data: ArrayLike) -> "FluxSparseMatrix":
123
223
  """Return a new FluxSparseMatrix sharing the same pattern with updated data."""
124
- return FluxSparseMatrix(self.pattern, data)
224
+ return FluxSparseMatrix(self.pattern, data, meta=self.meta)
225
+
226
+ def add_dense(self, dense: ArrayLike) -> "FluxSparseMatrix":
227
+ """Return a new FluxSparseMatrix with dense entries added on the pattern."""
228
+ dense_vals = jnp.asarray(dense)[self.pattern.rows, self.pattern.cols]
229
+ return FluxSparseMatrix(self.pattern, self.data + dense_vals)
125
230
 
126
- def to_coo(self):
231
+ def to_coo(self) -> COOTuple:
127
232
  return self.pattern.rows, self.pattern.cols, self.data, self.pattern.n_dofs
128
233
 
234
+ @property
235
+ def nnz(self) -> int:
236
+ return int(self.data.shape[0])
237
+
238
+ def coalesce(self) -> "FluxSparseMatrix":
239
+ """Return a new FluxSparseMatrix with duplicate entries summed."""
240
+ rows_u, cols_u, data_u = coalesce_coo(self.pattern.rows, self.pattern.cols, self.data)
241
+ return FluxSparseMatrix(rows_u, cols_u, data_u, self.pattern.n_dofs)
242
+
129
243
  def to_csr(self):
130
244
  if sp is None:
131
245
  raise ImportError("scipy is required for to_csr()")
@@ -143,7 +257,7 @@ class FluxSparseMatrix:
143
257
  d = np.array(self.data, copy=True)
144
258
  return sp.csr_matrix((d, (r, c)), shape=(self.pattern.n_dofs, self.pattern.n_dofs))
145
259
 
146
- def to_dense(self):
260
+ def to_dense(self) -> jnp.ndarray:
147
261
  # small debug helper
148
262
  dense = jnp.zeros((self.pattern.n_dofs, self.pattern.n_dofs), dtype=self.data.dtype)
149
263
  dense = dense.at[self.pattern.rows, self.pattern.cols].add(self.data)
@@ -158,7 +272,7 @@ class FluxSparseMatrix:
158
272
  idx = jnp.stack([self.pattern.rows, self.pattern.cols], axis=-1)
159
273
  return jsparse.BCOO((self.data, idx), shape=(self.pattern.n_dofs, self.pattern.n_dofs))
160
274
 
161
- def matvec(self, x):
275
+ def matvec(self, x: ArrayLike) -> jnp.ndarray:
162
276
  """Compute y = A x in JAX (iterative solvers)."""
163
277
  xj = jnp.asarray(x)
164
278
  contrib = self.data * xj[self.pattern.cols]
@@ -167,6 +281,26 @@ class FluxSparseMatrix:
167
281
  out = jnp.zeros(self.pattern.n_dofs, dtype=contrib.dtype)
168
282
  return out.at[self.pattern.rows].add(contrib)
169
283
 
284
+ def as_cg_operator(
285
+ self,
286
+ *,
287
+ matvec: str = "flux",
288
+ preconditioner=None,
289
+ solver: str = "cg",
290
+ dof_per_node: int | None = None,
291
+ block_sizes=None,
292
+ ):
293
+ from .cg import build_cg_operator
294
+
295
+ return build_cg_operator(
296
+ self,
297
+ matvec=matvec,
298
+ preconditioner=preconditioner,
299
+ solver=solver,
300
+ dof_per_node=dof_per_node,
301
+ block_sizes=block_sizes,
302
+ )
303
+
170
304
  def diag(self):
171
305
  """Diagonal entries aggregated for Jacobi preconditioning."""
172
306
  if self.pattern.diag_idx is not None:
fluxfem/tools/jit.py CHANGED
@@ -1,10 +1,22 @@
1
+ from __future__ import annotations
2
+
1
3
  import jax
2
4
 
3
- from ..core.assembly import assemble_residual, assemble_jacobian
5
+ from typing import Callable, TypeVar
6
+
7
+ from ..core.assembly import JacobianReturn, LinearReturn, ResidualForm, assemble_jacobian, assemble_residual
4
8
  from ..core.space import FESpace
5
9
 
10
+ P = TypeVar("P")
11
+
6
12
 
7
- def make_jitted_residual(space: FESpace, res_form, params, *, sparse: bool = False):
13
+ def make_jitted_residual(
14
+ space: FESpace,
15
+ res_form: ResidualForm[P],
16
+ params: P,
17
+ *,
18
+ sparse: bool = False,
19
+ ) -> Callable[[jax.Array], LinearReturn]:
8
20
  """
9
21
  Create a jitted residual assembler: u -> R(u).
10
22
  params and space are closed over.
@@ -13,7 +25,7 @@ def make_jitted_residual(space: FESpace, res_form, params, *, sparse: bool = Fal
13
25
  params_jax = params
14
26
 
15
27
  @jax.jit
16
- def residual(u):
28
+ def residual(u: jax.Array) -> LinearReturn:
17
29
  return assemble_residual(space_jax, res_form, u, params_jax, sparse=sparse)
18
30
 
19
31
  return residual
@@ -21,12 +33,12 @@ def make_jitted_residual(space: FESpace, res_form, params, *, sparse: bool = Fal
21
33
 
22
34
  def make_jitted_jacobian(
23
35
  space: FESpace,
24
- res_form,
25
- params,
36
+ res_form: ResidualForm[P],
37
+ params: P,
26
38
  *,
27
39
  sparse: bool = False,
28
40
  return_flux_matrix: bool = False,
29
- ):
41
+ ) -> Callable[[jax.Array], JacobianReturn]:
30
42
  """
31
43
  Create a jitted Jacobian assembler: u -> J(u).
32
44
  params and space are closed over.
@@ -35,7 +47,7 @@ def make_jitted_jacobian(
35
47
  params_jax = params
36
48
 
37
49
  @jax.jit
38
- def jacobian(u):
50
+ def jacobian(u: jax.Array) -> JacobianReturn:
39
51
  return assemble_jacobian(
40
52
  space_jax,
41
53
  res_form,
fluxfem/tools/timer.py CHANGED
@@ -5,13 +5,15 @@ from contextlib import AbstractContextManager
5
5
  from collections import defaultdict
6
6
  from contextlib import contextmanager
7
7
  from dataclasses import dataclass
8
- from typing import Callable, DefaultDict, Dict, Iterator, List, Optional
8
+ from typing import Any, Callable, DefaultDict, Dict, Iterator, List, Optional, TypeAlias
9
9
 
10
10
  import logging
11
11
 
12
12
  logging.basicConfig(level=logging.INFO)
13
13
  logger = logging.getLogger(__name__)
14
14
 
15
+ PlotResult: TypeAlias = tuple[Any, Any]
16
+
15
17
 
16
18
  @dataclass
17
19
  class SectionStats:
@@ -30,7 +32,7 @@ class BaseTimer(ABC):
30
32
 
31
33
  class NullTimer(BaseTimer):
32
34
  @contextmanager
33
- def section(self, name: str):
35
+ def section(self, name: str) -> Iterator[None]:
34
36
  yield
35
37
 
36
38
 
@@ -85,7 +87,7 @@ class SectionTimer(BaseTimer):
85
87
  self._records[full_name].append(duration)
86
88
  self._stack.pop()
87
89
 
88
- def wrap(self, name: str):
90
+ def wrap(self, name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
89
91
  """
90
92
  Decorator form of :meth:`section`.
91
93
 
@@ -97,8 +99,8 @@ class SectionTimer(BaseTimer):
97
99
  ...
98
100
  """
99
101
 
100
- def _decorator(func):
101
- def _wrapper(*args, **kwargs):
102
+ def _decorator(func: Callable[..., Any]) -> Callable[..., Any]:
103
+ def _wrapper(*args: Any, **kwargs: Any) -> Any:
102
104
  with self.section(name):
103
105
  return func(*args, **kwargs)
104
106
  return _wrapper
@@ -211,7 +213,7 @@ class SectionTimer(BaseTimer):
211
213
  self,
212
214
  sort_by: str = "total",
213
215
  descending: bool = True,
214
- logger_instance=None,
216
+ logger_instance: logging.Logger | None = None,
215
217
  ) -> str:
216
218
  stats = self.summary(sort_by=sort_by, descending=descending)
217
219
  if not stats:
@@ -231,7 +233,7 @@ class SectionTimer(BaseTimer):
231
233
 
232
234
  def plot_bar(
233
235
  self,
234
- ax=None,
236
+ ax: Any | None = None,
235
237
  sort_by: str = "total",
236
238
  value: str = "total",
237
239
  descending: bool = True,
@@ -240,7 +242,7 @@ class SectionTimer(BaseTimer):
240
242
  stacked_nested: bool = False,
241
243
  moving_average: bool = False,
242
244
  use_self_time: bool = False,
243
- ):
245
+ ) -> PlotResult:
244
246
  """
245
247
  Plot timing results as a horizontal bar chart without relying on pyplot state.
246
248
 
@@ -417,7 +419,7 @@ class SectionTimer(BaseTimer):
417
419
 
418
420
  def plot_pie(
419
421
  self,
420
- ax=None,
422
+ ax: Any | None = None,
421
423
  sort_by: str = "total",
422
424
  value: str = "total",
423
425
  descending: bool = True,
@@ -430,7 +432,7 @@ class SectionTimer(BaseTimer):
430
432
  show_total: bool = True,
431
433
  moving_average: bool = False,
432
434
  use_self_time: bool = False,
433
- ):
435
+ ) -> PlotResult:
434
436
  """
435
437
  Plot timing results as a pie chart to show relative time share.
436
438
 
@@ -552,7 +554,7 @@ class SectionTimer(BaseTimer):
552
554
 
553
555
  def plot(
554
556
  self,
555
- ax=None,
557
+ ax: Any | None = None,
556
558
  sort_by: str = "total",
557
559
  value: str = "total",
558
560
  descending: bool = True,
@@ -563,7 +565,7 @@ class SectionTimer(BaseTimer):
563
565
  moving_average: bool = False,
564
566
  use_self_time: bool = False,
565
567
  **kwargs,
566
- ):
568
+ ) -> PlotResult:
567
569
  """
568
570
  Plot timing results choosing between pie (default) or bar chart.
569
571
 
@@ -3,14 +3,14 @@ from typing import Mapping
3
3
 
4
4
  import numpy as np
5
5
 
6
- from ..mesh import HexMesh, TetMesh
6
+ from ..mesh import BaseMesh, HexMesh, TetMesh
7
7
 
8
8
 
9
9
  VTK_HEXAHEDRON = 12
10
10
  VTK_TETRA = 10
11
11
 
12
12
 
13
- def _cell_type_for_mesh(mesh):
13
+ def _cell_type_for_mesh(mesh: BaseMesh) -> int:
14
14
  if isinstance(mesh, HexMesh):
15
15
  return VTK_HEXAHEDRON
16
16
  if isinstance(mesh, TetMesh):
@@ -25,7 +25,13 @@ def _write_dataarray(name: str, data: np.ndarray, ncomp: int = 1) -> str:
25
25
  return f'<DataArray type="Float32" Name="{name}" format="ascii"{comp_attr}>{values}</DataArray>'
26
26
 
27
27
 
28
- def write_vtu(mesh, filepath: str, *, point_data: Mapping[str, np.ndarray] | None = None, cell_data: Mapping[str, np.ndarray] | None = None):
28
+ def write_vtu(
29
+ mesh: BaseMesh,
30
+ filepath: str,
31
+ *,
32
+ point_data: Mapping[str, np.ndarray] | None = None,
33
+ cell_data: Mapping[str, np.ndarray] | None = None,
34
+ ) -> None:
29
35
  """
30
36
  Write an UnstructuredGrid VTU for HexMesh or TetMesh.
31
37
  point_data/cell_data: dict name -> ndarray. Point data length must match n_points;
@@ -87,7 +93,13 @@ def write_vtu(mesh, filepath: str, *, point_data: Mapping[str, np.ndarray] | Non
87
93
  f.write("\n".join(lines))
88
94
 
89
95
 
90
- def write_displacement_vtu(mesh, u, filepath: str, *, name: str = "displacement"):
96
+ def write_displacement_vtu(
97
+ mesh: BaseMesh,
98
+ u: np.ndarray,
99
+ filepath: str,
100
+ *,
101
+ name: str = "displacement",
102
+ ) -> None:
91
103
  """
92
104
  Convenience wrapper: reshape displacement vector to point data and write VTU.
93
105
  Assumes 3 dof/node ordering [u0,v0,w0, u1,v1,w1, ...].