fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. fluxfem/__init__.py +69 -13
  2. fluxfem/core/__init__.py +140 -53
  3. fluxfem/core/assembly.py +691 -97
  4. fluxfem/core/basis.py +75 -54
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/dtypes.py +9 -1
  7. fluxfem/core/forms.py +10 -0
  8. fluxfem/core/mixed_assembly.py +263 -0
  9. fluxfem/core/mixed_space.py +382 -0
  10. fluxfem/core/mixed_weakform.py +97 -0
  11. fluxfem/core/solver.py +2 -0
  12. fluxfem/core/space.py +315 -30
  13. fluxfem/core/weakform.py +821 -42
  14. fluxfem/helpers_wf.py +49 -0
  15. fluxfem/mesh/__init__.py +54 -2
  16. fluxfem/mesh/base.py +318 -9
  17. fluxfem/mesh/contact.py +841 -0
  18. fluxfem/mesh/dtypes.py +12 -0
  19. fluxfem/mesh/hex.py +17 -16
  20. fluxfem/mesh/io.py +9 -6
  21. fluxfem/mesh/mortar.py +3970 -0
  22. fluxfem/mesh/supermesh.py +318 -0
  23. fluxfem/mesh/surface.py +104 -26
  24. fluxfem/mesh/tet.py +16 -7
  25. fluxfem/physics/diffusion.py +3 -0
  26. fluxfem/physics/elasticity/hyperelastic.py +35 -3
  27. fluxfem/physics/elasticity/linear.py +22 -4
  28. fluxfem/physics/elasticity/stress.py +9 -5
  29. fluxfem/physics/operators.py +12 -5
  30. fluxfem/physics/postprocess.py +29 -3
  31. fluxfem/solver/__init__.py +47 -2
  32. fluxfem/solver/bc.py +38 -2
  33. fluxfem/solver/block_matrix.py +284 -0
  34. fluxfem/solver/block_system.py +477 -0
  35. fluxfem/solver/cg.py +150 -55
  36. fluxfem/solver/dirichlet.py +358 -5
  37. fluxfem/solver/history.py +15 -3
  38. fluxfem/solver/newton.py +260 -70
  39. fluxfem/solver/petsc.py +445 -0
  40. fluxfem/solver/preconditioner.py +109 -0
  41. fluxfem/solver/result.py +18 -0
  42. fluxfem/solver/solve_runner.py +208 -23
  43. fluxfem/solver/solver.py +35 -12
  44. fluxfem/solver/sparse.py +149 -15
  45. fluxfem/tools/jit.py +19 -7
  46. fluxfem/tools/timer.py +14 -12
  47. fluxfem/tools/visualizer.py +16 -4
  48. fluxfem-0.2.1.dist-info/METADATA +314 -0
  49. fluxfem-0.2.1.dist-info/RECORD +59 -0
  50. fluxfem-0.1.4.dist-info/METADATA +0 -127
  51. fluxfem-0.1.4.dist-info/RECORD +0 -48
  52. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  53. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
@@ -1,10 +1,14 @@
1
+ from typing import TypeAlias
2
+
1
3
  import jax.numpy as jnp
2
4
 
3
- from ...core.assembly import assemble_linear_form
5
+ from ...core.assembly import LinearReturn
4
6
  from ...core.forms import FormContext, vector_load_form
5
- from ...core.basis import build_B_matrices
7
+ from ...core.space import FESpace
6
8
  from ...physics.operators import sym_grad
7
9
 
10
+ ArrayLike: TypeAlias = jnp.ndarray
11
+
8
12
  # from ...mechanics.kinematics import build_B_matrices
9
13
 
10
14
 
@@ -26,21 +30,35 @@ def linear_elasticity_form(ctx: FormContext, D: jnp.ndarray) -> jnp.ndarray:
26
30
  symmetric-gradient operator for the test/trial fields.
27
31
  """
28
32
  Bu = sym_grad(ctx.trial) # (n_q, 6, ndofs_e)
29
- Bv = sym_grad(ctx.test) # (n_q, 6, ndofs_e)
33
+ Bv = Bu if ctx.test is ctx.trial else sym_grad(ctx.test)
30
34
  return jnp.einsum("qik,kl,qlm->qim", jnp.swapaxes(Bv, 1, 2), D, Bu)
31
35
 
32
36
 
37
+ linear_elasticity_form._ff_kind = "bilinear"
38
+ linear_elasticity_form._ff_domain = "volume"
39
+
40
+
33
41
  def vector_body_force_form(ctx: FormContext, load_vec: jnp.ndarray) -> jnp.ndarray:
34
42
  """Linear form for 3D vector body force f (constant in space)."""
35
43
  return vector_load_form(ctx.test, load_vec)
36
44
 
37
45
 
38
- def assemble_constant_body_force(space, gravity_vec, density: float, *, sparse: bool = False):
46
+ vector_body_force_form._ff_kind = "linear"
47
+ vector_body_force_form._ff_domain = "volume"
48
+
49
def assemble_constant_body_force(
    space: FESpace,
    gravity_vec: ArrayLike,
    density: float,
    *,
    sparse: bool = False,
) -> LinearReturn:
    """
    Convenience: assemble the body-force load from density * gravity vector.

    gravity_vec: length-3 array-like (direction and magnitude of g)
    density: scalar density (consistent with unit system)
    sparse: forwarded unchanged to ``assemble_linear_form``.
    """
    # Imported inside the function rather than at module level
    # (presumably to avoid an import cycle with core.assembly -- TODO confirm).
    from ...core.assembly import assemble_linear_form

    body_force = density * jnp.asarray(gravity_vec)
    return assemble_linear_form(
        space,
        vector_body_force_form,
        params=body_force,
        sparse=sparse,
    )
@@ -1,13 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from typing import TypeAlias
4
+
3
5
  import jax.numpy as jnp
4
6
 
7
+ ArrayLike: TypeAlias = jnp.ndarray
8
+
5
9
 
6
- def _sym(A: jnp.ndarray) -> jnp.ndarray:
10
+ def _sym(A: ArrayLike) -> ArrayLike:
7
11
  return 0.5 * (A + jnp.swapaxes(A, -1, -2))
8
12
 
9
13
 
10
- def principal_stresses(S: jnp.ndarray) -> jnp.ndarray:
14
+ def principal_stresses(S: ArrayLike) -> ArrayLike:
11
15
  """
12
16
  Return principal stresses (eigvals) for symmetric 3x3 stress tensor.
13
17
  Supports batching over leading dimensions.
@@ -16,12 +20,12 @@ def principal_stresses(S: jnp.ndarray) -> jnp.ndarray:
16
20
  return jnp.linalg.eigvalsh(S_sym)
17
21
 
18
22
 
19
- def principal_sum(S: jnp.ndarray) -> jnp.ndarray:
23
def principal_sum(S: ArrayLike) -> ArrayLike:
    """Return tr(S) over the last two axes, i.e. the sum of the principal stresses.

    Leading dimensions are treated as batch dimensions.
    """
    row_axis, col_axis = -2, -1
    return jnp.trace(S, axis1=row_axis, axis2=col_axis)
22
26
 
23
27
 
24
- def max_shear_stress(S: jnp.ndarray) -> jnp.ndarray:
28
+ def max_shear_stress(S: ArrayLike) -> ArrayLike:
25
29
  """
26
30
  Maximum shear stress = (sigma_max - sigma_min) / 2.
27
31
  """
@@ -29,7 +33,7 @@ def max_shear_stress(S: jnp.ndarray) -> jnp.ndarray:
29
33
  return 0.5 * (vals[..., -1] - vals[..., 0])
30
34
 
31
35
 
32
- def von_mises_stress(S: jnp.ndarray) -> jnp.ndarray:
36
+ def von_mises_stress(S: ArrayLike) -> ArrayLike:
33
37
  """
34
38
  von Mises equivalent stress: sqrt(3/2 * dev(S):dev(S)).
35
39
  """
@@ -2,11 +2,18 @@
2
2
 
3
3
  # fluxfem/mechanics/operators.py
4
4
  from __future__ import annotations
5
+
6
+ from typing import Any, TypeAlias
7
+
5
8
  import jax
6
9
  import jax.numpy as jnp
7
10
 
11
+ from ..core.context_types import FormFieldLike
12
+
13
+ ArrayLike: TypeAlias = jnp.ndarray
14
+
8
15
 
9
- def dot(a, b):
16
+ def dot(a: FormFieldLike | ArrayLike, b: ArrayLike) -> ArrayLike:
10
17
  """
11
18
  Batched matrix product on the last two axes.
12
19
 
@@ -19,7 +26,7 @@ def dot(a, b):
19
26
  return jnp.matmul(a, b)
20
27
 
21
28
 
22
- def ddot(a, b, c=None):
29
+ def ddot(a: ArrayLike, b: ArrayLike, c: ArrayLike | None = None) -> ArrayLike:
23
30
  """
24
31
  Double contraction on the last two axes.
25
32
 
@@ -32,12 +39,12 @@ def ddot(a, b, c=None):
32
39
  return jnp.einsum("...ik,kl,...lm->...im", a_t, b, c)
33
40
 
34
41
 
35
- def transpose_last2(a):
42
def transpose_last2(a: ArrayLike) -> ArrayLike:
    """Batched transpose: exchange the last two axes of ``a``.

    Moving the last axis into the second-to-last position is the same as
    swapping those two axes; all leading axes are left untouched.
    """
    return jnp.moveaxis(a, -1, -2)
38
45
 
39
46
 
40
- def sym_grad(field) -> jnp.ndarray:
47
+ def sym_grad(field: FormFieldLike) -> jnp.ndarray:
41
48
  """
42
49
  Symmetric gradient operator for vector mechanics (small strain).
43
50
 
@@ -89,7 +96,7 @@ def sym_grad(field) -> jnp.ndarray:
89
96
  return jax.vmap(B_single)(gradN)
90
97
 
91
98
 
92
- def sym_grad_u(field, u_elem: jnp.ndarray) -> jnp.ndarray:
99
+ def sym_grad_u(field: FormFieldLike, u_elem: jnp.ndarray) -> jnp.ndarray:
93
100
  """
94
101
  Apply sym_grad(field) to a local displacement vector.
95
102
 
@@ -1,15 +1,33 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from typing import TYPE_CHECKING, TypeAlias
4
+
3
5
  import numpy as np
4
6
  import jax
5
7
  import jax.numpy as jnp
6
8
 
7
9
  # from ..core.assembly import build_form_contexts
8
10
  from ..tools.visualizer import write_vtu
11
+ from ..mesh import BaseMesh
12
+ from ..core.space import FESpace
9
13
  from ..core.interp import interpolate_field_at_element_points
10
14
 
15
+ if TYPE_CHECKING:
16
+ from jax import Array as JaxArray
17
+
18
+ ArrayLike: TypeAlias = np.ndarray | JaxArray
19
+ else:
20
+ ArrayLike: TypeAlias = np.ndarray
21
+
11
22
 
12
- def make_point_data_displacement(mesh, space, u, *, compute_j: bool = True, deformed_scale: float = 1.0):
23
+ def make_point_data_displacement(
24
+ mesh: BaseMesh,
25
+ space: FESpace,
26
+ u: ArrayLike,
27
+ *,
28
+ compute_j: bool = True,
29
+ deformed_scale: float = 1.0,
30
+ ) -> dict[str, np.ndarray]:
13
31
  """
14
32
  Common postprocess helper to build point data dictionaries:
15
33
  - displacement
@@ -58,14 +76,22 @@ def make_point_data_displacement(mesh, space, u, *, compute_j: bool = True, defo
58
76
  return point_data
59
77
 
60
78
 
61
- def write_point_data_vtu(mesh, space, u, filepath: str, *, compute_j: bool = True, deformed_scale: float = 1.0):
79
def write_point_data_vtu(
    mesh: BaseMesh,
    space: FESpace,
    u: ArrayLike,
    filepath: str,
    *,
    compute_j: bool = True,
    deformed_scale: float = 1.0,
) -> None:
    """Write a VTU file with displacement/deformed_coords and optional J.

    The point data comes from ``make_point_data_displacement`` (same
    ``compute_j`` / ``deformed_scale`` options) and is written with
    ``write_vtu`` to ``filepath``.
    """
    point_data = make_point_data_displacement(
        mesh,
        space,
        u,
        compute_j=compute_j,
        deformed_scale=deformed_scale,
    )
    write_vtu(mesh, filepath, point_data=point_data)
65
91
 
66
92
 
67
93
  __all__ = ["make_point_data_displacement", "write_point_data_vtu", "interpolate_at_points"]
68
- def interpolate_at_points(space, u, points: np.ndarray):
94
+ def interpolate_at_points(space: FESpace, u: ArrayLike, points: np.ndarray) -> np.ndarray:
69
95
  """
70
96
  Interpolate displacement field at given physical points (Hex8 only, structured search).
71
97
  - points: (m,3) array of physical coordinates.
@@ -1,14 +1,33 @@
1
- from .sparse import SparsityPattern, FluxSparseMatrix
1
+ from .sparse import (
2
+ SparsityPattern,
3
+ FluxSparseMatrix,
4
+ coalesce_coo,
5
+ concat_flux,
6
+ block_diag_flux,
7
+ )
2
8
  from .dirichlet import (
9
+ DirichletBC,
3
10
  enforce_dirichlet_dense,
11
+ enforce_dirichlet_dense_jax,
12
+ enforce_dirichlet_fluxsparse,
4
13
  enforce_dirichlet_sparse,
5
14
  free_dofs,
15
+ split_dirichlet_matrix,
16
+ restrict_flux_to_free,
17
+ condense_dirichlet_system,
18
+ enforce_dirichlet_system,
6
19
  condense_dirichlet_fluxsparse,
20
+ condense_dirichlet_fluxsparse_coo,
7
21
  condense_dirichlet_dense,
8
22
  expand_dirichlet_solution,
9
23
  )
10
- from .cg import cg_solve, cg_solve_jax
24
+ from .cg import cg_solve, cg_solve_jax, build_cg_operator, CGOperator
25
+ from .preconditioner import make_block_jacobi_preconditioner
26
+ from .block_system import build_block_system, split_block_matrix, BlockSystem
27
+ from .block_matrix import FluxBlockMatrix, diag as block_diag, make as make_block_matrix
11
28
  from .newton import newton_solve
29
+ from .result import SolverResult
30
+ from .history import NewtonIterRecord
12
31
  from .solve_runner import (
13
32
  NonlinearAnalysis,
14
33
  NewtonLoopConfig,
@@ -21,19 +40,42 @@ from .solve_runner import (
21
40
  LinearSolveRunner,
22
41
  )
23
42
  from .solver import LinearSolver, NonlinearSolver
43
+ from .petsc import petsc_solve, petsc_shell_solve, petsc_is_available
24
44
 
25
45
  __all__ = [
26
46
  "SparsityPattern",
27
47
  "FluxSparseMatrix",
48
+ "coalesce_coo",
49
+ "concat_flux",
50
+ "block_diag_flux",
51
+ "DirichletBC",
28
52
  "enforce_dirichlet_dense",
53
+ "enforce_dirichlet_dense_jax",
54
+ "enforce_dirichlet_fluxsparse",
29
55
  "enforce_dirichlet_sparse",
56
+ "split_dirichlet_matrix",
57
+ "enforce_dirichlet_system",
30
58
  "free_dofs",
59
+ "restrict_flux_to_free",
60
+ "condense_dirichlet_system",
31
61
  "condense_dirichlet_fluxsparse",
62
+ "condense_dirichlet_fluxsparse_coo",
32
63
  "condense_dirichlet_dense",
33
64
  "expand_dirichlet_solution",
34
65
  "cg_solve",
35
66
  "cg_solve_jax",
67
+ "build_cg_operator",
68
+ "CGOperator",
69
+ "make_block_jacobi_preconditioner",
70
+ "build_block_system",
71
+ "split_block_matrix",
72
+ "BlockSystem",
73
+ "FluxBlockMatrix",
74
+ "block_diag",
75
+ "make_block_matrix",
36
76
  "newton_solve",
77
+ "SolverResult",
78
+ "NewtonIterRecord",
37
79
  "LinearAnalysis",
38
80
  "LinearSolveConfig",
39
81
  "LinearStepResult",
@@ -44,4 +86,7 @@ __all__ = [
44
86
  "solve_nonlinear",
45
87
  "LinearSolver",
46
88
  "NonlinearSolver",
89
+ "petsc_solve",
90
+ "petsc_shell_solve",
91
+ "petsc_is_available",
47
92
  ]
fluxfem/solver/bc.py CHANGED
@@ -107,6 +107,10 @@ def vector_surface_load_form(ctx: SurfaceFormContext, load: npt.ArrayLike) -> np
107
107
  return np.asarray(vector_load_form(ctx.v, load_arr))
108
108
 
109
109
 
110
+ vector_surface_load_form._ff_kind = "linear"
111
+ vector_surface_load_form._ff_domain = "surface"
112
+
113
+
110
114
  def make_vector_surface_load_form(load_fn):
111
115
  """
112
116
  Build a vector surface load form from a callable f(x_q) -> (n_q, dim).
@@ -115,9 +119,32 @@ def make_vector_surface_load_form(load_fn):
115
119
  load_q = load_fn(ctx.x_q)
116
120
  return vector_surface_load_form(ctx, load_q)
117
121
 
122
+ _form._ff_kind = "linear"
123
+ _form._ff_domain = "surface"
118
124
  return _form
119
125
 
120
126
 
127
def traction_vector(traction, traction_dir: str) -> np.ndarray:
    """
    Resolve a traction magnitude and a direction string into a 3-vector.

    traction: scalar magnitude; anything ``float()`` accepts.
    traction_dir: an axis letter ``x``/``y``/``z``, optionally followed by
        ``pos`` or ``neg``. Case and surrounding whitespace are ignored.

    Returns a float ndarray of shape (3,). Raises ValueError for an
    unrecognised direction string.
    """
    axis_index = {"x": 0, "y": 1, "z": 2}
    # A bare axis letter means the positive direction.
    suffix_sign = {"": 1.0, "pos": 1.0, "neg": -1.0}

    key = traction_dir.strip().lower()
    axis_name, suffix = key[:1], key[1:]
    if axis_name not in axis_index or suffix not in suffix_sign:
        raise ValueError("TRACTION_DIR must be one of x/xpos/xneg/y/ypos/yneg/z/zpos/zneg")

    direction = np.zeros(3, dtype=float)
    direction[axis_index[axis_name]] = suffix_sign[suffix]
    return float(traction) * direction
146
+
147
+
121
148
  def _surface_quadrature(node_coords: np.ndarray):
122
149
  m = node_coords.shape[0]
123
150
  if m == 4:
@@ -404,8 +431,17 @@ def facet_normals(surface: SurfaceMesh, *, outward_from: npt.ArrayLike | None =
404
431
  for i, facet in enumerate(facets):
405
432
  if len(facet) < 3:
406
433
  continue
407
- p0, p1, p2 = coords[facet[0]], coords[facet[1]], coords[facet[2]]
408
- n = np.cross(p1 - p0, p2 - p0)
434
+ n = None
435
+ p0 = coords[facet[0]]
436
+ for j in range(1, len(facet) - 1):
437
+ p1 = coords[facet[j]]
438
+ p2 = coords[facet[j + 1]]
439
+ n_candidate = np.cross(p1 - p0, p2 - p0)
440
+ if np.linalg.norm(n_candidate) > 0.0:
441
+ n = n_candidate
442
+ break
443
+ if n is None:
444
+ continue
409
445
  if normalize:
410
446
  norm = np.linalg.norm(n)
411
447
  n = n / norm if norm != 0.0 else n
@@ -0,0 +1,284 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Mapping as AbcMapping
4
+ from typing import Any, Iterator, Mapping, Sequence, TypeAlias
5
+
6
+ import numpy as np
7
+
8
+ try:
9
+ import scipy.sparse as sp
10
+ except Exception: # pragma: no cover
11
+ sp = None
12
+
13
+ from .block_system import split_block_matrix
14
+ from .sparse import FluxSparseMatrix
15
+
16
+ MatrixLike: TypeAlias = Any
17
+ FieldKey: TypeAlias = str | int
18
+ BlockMap: TypeAlias = dict[FieldKey, dict[FieldKey, MatrixLike]]
19
+
20
+
21
def diag(**blocks: MatrixLike) -> dict[str, MatrixLike]:
    """Collect keyword-named diagonal blocks into a new plain dict.

    Example: ``diag(u=K_uu, p=K_pp)`` -> ``{"u": K_uu, "p": K_pp}``.
    The blocks themselves are not copied.
    """
    return {name: block for name, block in blocks.items()}
23
+
24
+
25
+ def _infer_sizes_from_diag(diag_blocks: Mapping[FieldKey, MatrixLike]) -> dict[FieldKey, int]:
26
+ sizes = {}
27
+ for name, blk in diag_blocks.items():
28
+ if isinstance(blk, FluxSparseMatrix):
29
+ sizes[name] = int(blk.n_dofs)
30
+ elif sp is not None and sp.issparse(blk):
31
+ shape = blk.shape
32
+ if shape[0] != shape[1]:
33
+ raise ValueError(f"diag block {name} must be square, got {shape}")
34
+ sizes[name] = int(shape[0])
35
+ else:
36
+ arr = np.asarray(blk)
37
+ if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
38
+ raise ValueError(f"diag block {name} must be square, got {arr.shape}")
39
+ sizes[name] = int(arr.shape[0])
40
+ return sizes
41
+
42
+
43
+ def _infer_format(blocks: AbcMapping[FieldKey, AbcMapping[FieldKey, MatrixLike]], fmt: str) -> str:
44
+ if fmt != "auto":
45
+ return fmt
46
+ for row in blocks.values():
47
+ for blk in row.values():
48
+ if isinstance(blk, FluxSparseMatrix):
49
+ return "flux"
50
+ if sp is not None and sp.issparse(blk):
51
+ return "csr"
52
+ return "dense"
53
+
54
+
55
+ def _add_blocks(a: MatrixLike | None, b: MatrixLike | None) -> MatrixLike | None:
56
+ if a is None:
57
+ return b
58
+ if b is None:
59
+ return a
60
+ if isinstance(a, FluxSparseMatrix):
61
+ a = a.to_csr()
62
+ if isinstance(b, FluxSparseMatrix):
63
+ b = b.to_csr()
64
+ if sp is not None and sp.issparse(a):
65
+ if sp.issparse(b):
66
+ return a + b
67
+ return a + sp.csr_matrix(np.asarray(b))
68
+ if sp is not None and sp.issparse(b):
69
+ return sp.csr_matrix(np.asarray(a)) + b
70
+ return np.asarray(a) + np.asarray(b)
71
+
72
+
73
def _transpose_block(block: MatrixLike, rule: str) -> MatrixLike:
    """
    Transpose a block; with ``rule == "H"`` return the conjugate transpose.

    FluxSparseMatrix blocks are converted to CSR first, which needs scipy
    (ImportError otherwise). scipy-sparse inputs stay sparse; anything else
    is transposed as a numpy array.
    """
    if isinstance(block, FluxSparseMatrix):
        if sp is None:
            raise ImportError("scipy is required to transpose FluxSparseMatrix blocks.")
        block = block.to_csr()
    is_sparse = sp is not None and sp.issparse(block)
    transposed = block.T if is_sparse else np.asarray(block).T
    return transposed.conjugate() if rule == "H" else transposed
85
+
86
+
87
class FluxBlockMatrix(AbcMapping[FieldKey, Mapping[FieldKey, MatrixLike]]):
    """
    Lazy block-matrix container that assembles on demand.

    Blocks are stored as a nested mapping ``blocks[row_field][col_field]``;
    any missing entry is treated as a zero block during assembly. ``sizes``
    gives the number of dofs of each field, and ``order`` the field sequence
    used to lay the blocks out in the global matrix. ``order`` defaults to
    the key order of ``sizes``.

    ``symmetric`` and ``transpose_rule`` are recorded as metadata only; this
    class does not mirror blocks itself.

    The instance is a read-only mapping from row field to its row of blocks.
    """

    def __init__(
        self,
        blocks: BlockMap,
        *,
        sizes: Mapping[FieldKey, int],
        order: Sequence[FieldKey] | None = None,
        symmetric: bool = False,
        transpose_rule: str = "T",
    ) -> None:
        """
        Store the blocks and layout metadata.

        Raises KeyError if a field in ``order`` has no entry in ``sizes``.
        """
        self._blocks = blocks
        self.sizes = {name: int(size) for name, size in sizes.items()}
        self.field_order = tuple(order) if order is not None else tuple(self.sizes.keys())
        self.symmetric = bool(symmetric)
        self.transpose_rule = transpose_rule
        # Every field in the layout needs a size to compute its global offset.
        for name in self.field_order:
            if name not in self.sizes:
                raise KeyError(f"Missing size for field '{name}'")

    def __getitem__(self, key: FieldKey) -> Mapping[FieldKey, MatrixLike]:
        """Return the stored row of blocks for row field ``key``."""
        return self._blocks[key]

    def __iter__(self) -> Iterator[FieldKey]:
        # Iterates the stored rows, which need not coincide with field_order.
        return iter(self._blocks)

    def __len__(self) -> int:
        return len(self._blocks)

    @property
    def blocks(self) -> BlockMap:
        """The underlying nested block mapping (not copied)."""
        return self._blocks

    def assemble(self, *, format: str = "flux") -> MatrixLike:
        """
        Assemble all blocks into one global matrix laid out by ``field_order``.

        Parameters
        ----------
        format:
            ``"flux"``: a FluxSparseMatrix built from COO triplets; only the
            nonzero entries of dense blocks are kept.
            ``"csr"``: a scipy CSR matrix via ``sp.bmat``.
            ``"dense"``: a numpy array via ``np.block``; sparse blocks are
            densified.
            ``"auto"``: decided by the first sparse block found
            (FluxSparseMatrix -> flux, scipy sparse -> csr), else dense.

        Raises
        ------
        ValueError
            Unknown ``format``, or a block whose shape does not match
            ``sizes``.
        ImportError
            scipy is required but not available.
        """
        if format not in {"auto", "flux", "csr", "dense"}:
            raise ValueError("format must be one of: auto, flux, csr, dense")
        use_format = _infer_format(self._blocks, format)

        # Starting global row/column index of each field, following field_order.
        offsets: dict[FieldKey, int] = {}
        offset = 0
        for name in self.field_order:
            offsets[name] = offset
            offset += int(self.sizes[name])
        n_total = offset

        def _block_shape(name_i: FieldKey, name_j: FieldKey) -> tuple[int, int]:
            # Expected (rows, cols) of block (name_i, name_j).
            return (int(self.sizes[name_i]), int(self.sizes[name_j]))

        if use_format == "flux":
            rows_list = []
            cols_list = []
            data_list = []
            for name_i in self.field_order:
                row_blocks = self._blocks.get(name_i, {})
                for name_j in self.field_order:
                    blk = row_blocks.get(name_j)
                    if blk is None:
                        continue
                    shape = _block_shape(name_i, name_j)
                    if isinstance(blk, FluxSparseMatrix):
                        # A FluxSparseMatrix carries a single n_dofs size, so it can
                        # only fill a square block of exactly that size.
                        if shape[0] != shape[1] or int(blk.n_dofs) != shape[0]:
                            raise ValueError(f"Block {name_i},{name_j} has incompatible FluxSparseMatrix size")
                        r = np.asarray(blk.pattern.rows, dtype=np.int64)
                        c = np.asarray(blk.pattern.cols, dtype=np.int64)
                        d = np.asarray(blk.data)
                    elif sp is not None and sp.issparse(blk):
                        coo = blk.tocoo()
                        if coo.shape != shape:
                            raise ValueError(f"Block {name_i},{name_j} has shape {coo.shape}, expected {shape}")
                        r = np.asarray(coo.row, dtype=np.int64)
                        c = np.asarray(coo.col, dtype=np.int64)
                        d = np.asarray(coo.data)
                    else:
                        arr = np.asarray(blk)
                        if arr.shape != shape:
                            raise ValueError(f"Block {name_i},{name_j} has shape {arr.shape}, expected {shape}")
                        # Dense blocks contribute only their nonzero entries.
                        r, c = np.nonzero(arr)
                        d = arr[r, c]
                    if r.size:
                        rows_list.append(r + offsets[name_i])
                        cols_list.append(c + offsets[name_j])
                        data_list.append(d)
            # Empty fallbacks use the same int64 index dtype as the non-empty path.
            rows = np.concatenate(rows_list) if rows_list else np.asarray([], dtype=np.int64)
            cols = np.concatenate(cols_list) if cols_list else np.asarray([], dtype=np.int64)
            data = np.concatenate(data_list) if data_list else np.asarray([], dtype=float)
            return FluxSparseMatrix(rows, cols, data, n_total)

        if use_format == "csr" and sp is None:
            raise ImportError("scipy is required for CSR block systems.")
        dense = use_format != "csr"
        block_rows = []
        for name_i in self.field_order:
            row = []
            row_blocks = self._blocks.get(name_i, {})
            for name_j in self.field_order:
                blk = row_blocks.get(name_j)
                shape = _block_shape(name_i, name_j)
                if blk is None:
                    row.append(np.zeros(shape, dtype=float) if dense else sp.csr_matrix(shape))
                    continue
                if isinstance(blk, FluxSparseMatrix):
                    if sp is None:
                        raise ImportError("scipy is required to assemble sparse block systems.")
                    blk = blk.to_csr()
                if sp is not None and sp.issparse(blk):
                    blk = blk.tocsr()
                    if blk.shape != shape:
                        raise ValueError(f"Block {name_i},{name_j} has shape {blk.shape}, expected {shape}")
                    # np.block cannot stack scipy sparse matrices, so densify
                    # them when a dense result is requested.
                    row.append(blk.toarray() if dense else blk)
                else:
                    arr = np.asarray(blk)
                    if arr.shape != shape:
                        raise ValueError(f"Block {name_i},{name_j} has shape {arr.shape}, expected {shape}")
                    row.append(arr if dense else sp.csr_matrix(arr))
            block_rows.append(row)
        if dense:
            return np.block(block_rows)
        return sp.bmat(block_rows, format="csr")
214
+
215
+
216
def make(
    *,
    diag: Mapping[FieldKey, MatrixLike] | Sequence[MatrixLike],
    rel: Mapping[tuple[FieldKey, FieldKey], MatrixLike] | None = None,
    add_contiguous: MatrixLike | None = None,
    sizes: Mapping[FieldKey, int] | None = None,
    symmetric: bool = False,
    transpose_rule: str = "T",
) -> FluxBlockMatrix:
    """
    Build a lazy FluxBlockMatrix from diagonal blocks, optional relations, and a full matrix.

    Parameters
    ----------
    diag:
        Diagonal blocks. A mapping is keyed by field name. A sequence is keyed
        by the keys of ``sizes`` (in order) when ``sizes`` is given, and
        otherwise by position ``0, 1, ...``.
    rel:
        Optional extra blocks keyed by ``(row_field, col_field)``. Each is
        added onto whatever block already sits at that position.
    add_contiguous:
        Optional full matrix, first split into blocks with
        ``split_block_matrix(add_contiguous, sizes=sizes)``. The ``diag`` and
        ``rel`` blocks are then added on top.
    sizes:
        Number of dofs per field. When omitted, it is inferred from the square
        diagonal blocks. Its key order defines the field order.
    symmetric:
        If True, every off-diagonal ``rel`` block ``(i, j)`` is also added at
        ``(j, i)``, transformed according to ``transpose_rule``.
    transpose_rule:
        How a mirrored block is transformed: ``"T"`` (transpose), ``"H"``
        (conjugate transpose) or ``"none"`` (reuse the block unchanged).

    Returns
    -------
    FluxBlockMatrix
        Records ``symmetric`` and ``transpose_rule``.

    Raises
    ------
    ValueError
        If a ``diag`` sequence length does not match ``sizes``, a diagonal
        block is not square while sizes are inferred, or ``transpose_rule``
        is invalid.
    KeyError
        If ``diag`` or ``rel`` refers to a field missing from ``sizes``.
    """
    if isinstance(diag, Mapping):
        diag_map = dict(diag)
    else:
        diag_seq = list(diag)
        if sizes is None:
            diag_map = dict(enumerate(diag_seq))
        else:
            field_names = tuple(sizes.keys())
            if len(diag_seq) != len(field_names):
                raise ValueError("diag sequence length must match sizes")
            diag_map = dict(zip(field_names, diag_seq))

    if sizes is None:
        sizes = _infer_sizes_from_diag(diag_map)
    # Derive the layout order unconditionally so every combination of
    # diag (mapping or sequence) and sizes (given or inferred) defines it.
    order = tuple(sizes.keys())

    if add_contiguous is None:
        blocks = {name: {} for name in order}
    else:
        blocks = split_block_matrix(add_contiguous, sizes=sizes)

    for name, blk in diag_map.items():
        if name not in sizes:
            raise KeyError(f"Unknown field '{name}' in diag")
        blocks.setdefault(name, {})
        blocks[name][name] = _add_blocks(blocks[name].get(name), blk)

    if transpose_rule not in {"T", "H", "none"}:
        raise ValueError("transpose_rule must be one of: T, H, none")

    if rel is not None:
        for (name_i, name_j), blk in rel.items():
            if name_i not in sizes or name_j not in sizes:
                raise KeyError(f"Unknown field in rel: {(name_i, name_j)}")
            row_i = blocks.setdefault(name_i, {})
            row_i[name_j] = _add_blocks(row_i.get(name_j), blk)
            if symmetric and name_i != name_j:
                # Mirror the coupling block into (j, i).
                mirrored = blk if transpose_rule == "none" else _transpose_block(blk, transpose_rule)
                row_j = blocks.setdefault(name_j, {})
                row_j[name_i] = _add_blocks(row_j.get(name_i), mirrored)

    return FluxBlockMatrix(
        blocks,
        sizes=sizes,
        order=order,
        symmetric=symmetric,
        transpose_rule=transpose_rule,
    )
282
+
283
+
284
+ __all__ = ["FluxBlockMatrix", "diag", "make"]