fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +69 -13
- fluxfem/core/__init__.py +140 -53
- fluxfem/core/assembly.py +691 -97
- fluxfem/core/basis.py +75 -54
- fluxfem/core/context_types.py +36 -12
- fluxfem/core/dtypes.py +9 -1
- fluxfem/core/forms.py +10 -0
- fluxfem/core/mixed_assembly.py +263 -0
- fluxfem/core/mixed_space.py +382 -0
- fluxfem/core/mixed_weakform.py +97 -0
- fluxfem/core/solver.py +2 -0
- fluxfem/core/space.py +315 -30
- fluxfem/core/weakform.py +821 -42
- fluxfem/helpers_wf.py +49 -0
- fluxfem/mesh/__init__.py +54 -2
- fluxfem/mesh/base.py +318 -9
- fluxfem/mesh/contact.py +841 -0
- fluxfem/mesh/dtypes.py +12 -0
- fluxfem/mesh/hex.py +17 -16
- fluxfem/mesh/io.py +9 -6
- fluxfem/mesh/mortar.py +3970 -0
- fluxfem/mesh/supermesh.py +318 -0
- fluxfem/mesh/surface.py +104 -26
- fluxfem/mesh/tet.py +16 -7
- fluxfem/physics/diffusion.py +3 -0
- fluxfem/physics/elasticity/hyperelastic.py +35 -3
- fluxfem/physics/elasticity/linear.py +22 -4
- fluxfem/physics/elasticity/stress.py +9 -5
- fluxfem/physics/operators.py +12 -5
- fluxfem/physics/postprocess.py +29 -3
- fluxfem/solver/__init__.py +47 -2
- fluxfem/solver/bc.py +38 -2
- fluxfem/solver/block_matrix.py +284 -0
- fluxfem/solver/block_system.py +477 -0
- fluxfem/solver/cg.py +150 -55
- fluxfem/solver/dirichlet.py +358 -5
- fluxfem/solver/history.py +15 -3
- fluxfem/solver/newton.py +260 -70
- fluxfem/solver/petsc.py +445 -0
- fluxfem/solver/preconditioner.py +109 -0
- fluxfem/solver/result.py +18 -0
- fluxfem/solver/solve_runner.py +208 -23
- fluxfem/solver/solver.py +35 -12
- fluxfem/solver/sparse.py +149 -15
- fluxfem/tools/jit.py +19 -7
- fluxfem/tools/timer.py +14 -12
- fluxfem/tools/visualizer.py +16 -4
- fluxfem-0.2.1.dist-info/METADATA +314 -0
- fluxfem-0.2.1.dist-info/RECORD +59 -0
- fluxfem-0.1.4.dist-info/METADATA +0 -127
- fluxfem-0.1.4.dist-info/RECORD +0 -48
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/physics/elasticity/linear.py
CHANGED

@@ -1,10 +1,14 @@
+from typing import TypeAlias
+
 import jax.numpy as jnp
 
-from ...core.assembly import
+from ...core.assembly import LinearReturn
 from ...core.forms import FormContext, vector_load_form
-from ...core.
+from ...core.space import FESpace
 from ...physics.operators import sym_grad
 
+ArrayLike: TypeAlias = jnp.ndarray
+
 # from ...mechanics.kinematics import build_B_matrices
 
 
@@ -26,21 +30,35 @@ def linear_elasticity_form(ctx: FormContext, D: jnp.ndarray) -> jnp.ndarray:
     symmetric-gradient operator for the test/trial fields.
     """
     Bu = sym_grad(ctx.trial)  # (n_q, 6, ndofs_e)
-    Bv =
+    Bv = Bu if ctx.test is ctx.trial else sym_grad(ctx.test)
     return jnp.einsum("qik,kl,qlm->qim", jnp.swapaxes(Bv, 1, 2), D, Bu)
 
 
+linear_elasticity_form._ff_kind = "bilinear"
+linear_elasticity_form._ff_domain = "volume"
+
+
 def vector_body_force_form(ctx: FormContext, load_vec: jnp.ndarray) -> jnp.ndarray:
     """Linear form for 3D vector body force f (constant in space)."""
     return vector_load_form(ctx.test, load_vec)
 
 
-
+vector_body_force_form._ff_kind = "linear"
+vector_body_force_form._ff_domain = "volume"
+
+def assemble_constant_body_force(
+    space: FESpace,
+    gravity_vec: ArrayLike,
+    density: float,
+    *,
+    sparse: bool = False,
+) -> LinearReturn:
     """
     Convenience: assemble body force from density * gravity vector.
     gravity_vec: length-3 array-like (direction and magnitude of g)
     density: scalar density (consistent with unit system)
     """
+    from ...core.assembly import assemble_linear_form
     g = jnp.asarray(gravity_vec)
     f_vec = density * g
     return assemble_linear_form(space, vector_body_force_form, params=f_vec, sparse=sparse)
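For orientation, a minimal usage sketch of the new assemble_constant_body_force helper. It is based only on the signature shown in the hunk above; the FESpace object (space) is assumed to exist, and the import path is inferred from the relative imports in this file.

import jax.numpy as jnp
from fluxfem.physics.elasticity.linear import assemble_constant_body_force  # path inferred from the diff

# space: a previously constructed FESpace (assumed available)
gravity = jnp.array([0.0, 0.0, -9.81])   # direction and magnitude of g
f = assemble_constant_body_force(space, gravity, density=7850.0, sparse=False)
# f is the assembled body-force load vector (LinearReturn), ready to use as a right-hand side.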
fluxfem/physics/elasticity/stress.py
CHANGED

@@ -1,13 +1,17 @@
 from __future__ import annotations
 
+from typing import TypeAlias
+
 import jax.numpy as jnp
 
+ArrayLike: TypeAlias = jnp.ndarray
+
 
-def _sym(A:
+def _sym(A: ArrayLike) -> ArrayLike:
     return 0.5 * (A + jnp.swapaxes(A, -1, -2))
 
 
-def principal_stresses(S:
+def principal_stresses(S: ArrayLike) -> ArrayLike:
     """
     Return principal stresses (eigvals) for symmetric 3x3 stress tensor.
     Supports batching over leading dimensions.

@@ -16,12 +20,12 @@ def principal_stresses(S: jnp.ndarray) -> jnp.ndarray:
     return jnp.linalg.eigvalsh(S_sym)
 
 
-def principal_sum(S:
+def principal_sum(S: ArrayLike) -> ArrayLike:
     """Sum of principal stresses (trace)."""
     return jnp.trace(S, axis1=-2, axis2=-1)
 
 
-def max_shear_stress(S:
+def max_shear_stress(S: ArrayLike) -> ArrayLike:
     """
     Maximum shear stress = (sigma_max - sigma_min) / 2.
     """

@@ -29,7 +33,7 @@ def max_shear_stress(S: jnp.ndarray) -> jnp.ndarray:
     return 0.5 * (vals[..., -1] - vals[..., 0])
 
 
-def von_mises_stress(S:
+def von_mises_stress(S: ArrayLike) -> ArrayLike:
     """
     von Mises equivalent stress: sqrt(3/2 * dev(S):dev(S)).
     """
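As a quick reference, a small sketch of the retyped stress helpers acting on a batch of stress tensors. The batched call pattern follows the docstrings above; the import path is inferred from the file layout and the example values are illustrative.

import jax.numpy as jnp
from fluxfem.physics.elasticity.stress import principal_stresses, von_mises_stress

S = jnp.zeros((100, 3, 3)).at[:, 0, 0].set(200.0)  # 100 uniaxial stress states, sigma_xx = 200
sig = principal_stresses(S)   # (100, 3) eigenvalues per tensor, in ascending order
vm = von_mises_stress(S)      # (100,) equivalent stress, equal to 200.0 for this uniaxial case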
fluxfem/physics/operators.py
CHANGED
@@ -2,11 +2,18 @@
 
 # fluxfem/mechanics/operators.py
 from __future__ import annotations
+
+from typing import Any, TypeAlias
+
 import jax
 import jax.numpy as jnp
 
+from ..core.context_types import FormFieldLike
+
+ArrayLike: TypeAlias = jnp.ndarray
+
 
-def dot(a, b):
+def dot(a: FormFieldLike | ArrayLike, b: ArrayLike) -> ArrayLike:
     """
     Batched matrix product on the last two axes.
 

@@ -19,7 +26,7 @@ def dot(a, b):
     return jnp.matmul(a, b)
 
 
-def ddot(a, b, c=None):
+def ddot(a: ArrayLike, b: ArrayLike, c: ArrayLike | None = None) -> ArrayLike:
     """
     Double contraction on the last two axes.
 

@@ -32,12 +39,12 @@
     return jnp.einsum("...ik,kl,...lm->...im", a_t, b, c)
 
 
-def transpose_last2(a):
+def transpose_last2(a: ArrayLike) -> ArrayLike:
     """Swap the last two axes (batched transpose)."""
     return jnp.swapaxes(a, -1, -2)
 
 
-def sym_grad(field) -> jnp.ndarray:
+def sym_grad(field: FormFieldLike) -> jnp.ndarray:
     """
     Symmetric gradient operator for vector mechanics (small strain).
 

@@ -89,7 +96,7 @@ def sym_grad(field) -> jnp.ndarray:
     return jax.vmap(B_single)(gradN)
 
 
-def sym_grad_u(field, u_elem: jnp.ndarray) -> jnp.ndarray:
+def sym_grad_u(field: FormFieldLike, u_elem: jnp.ndarray) -> jnp.ndarray:
     """
     Apply sym_grad(field) to a local displacement vector.
 
fluxfem/physics/postprocess.py
CHANGED
@@ -1,15 +1,33 @@
 from __future__ import annotations
 
+from typing import TYPE_CHECKING, TypeAlias
+
 import numpy as np
 import jax
 import jax.numpy as jnp
 
 # from ..core.assembly import build_form_contexts
 from ..tools.visualizer import write_vtu
+from ..mesh import BaseMesh
+from ..core.space import FESpace
 from ..core.interp import interpolate_field_at_element_points
 
+if TYPE_CHECKING:
+    from jax import Array as JaxArray
+
+    ArrayLike: TypeAlias = np.ndarray | JaxArray
+else:
+    ArrayLike: TypeAlias = np.ndarray
+
 
-def make_point_data_displacement(
+def make_point_data_displacement(
+    mesh: BaseMesh,
+    space: FESpace,
+    u: ArrayLike,
+    *,
+    compute_j: bool = True,
+    deformed_scale: float = 1.0,
+) -> dict[str, np.ndarray]:
     """
     Common postprocess helper to build point data dictionaries:
     - displacement

@@ -58,14 +76,22 @@ def make_point_data_displacement(mesh, space, u, *, compute_j: bool = True, defo
     return point_data
 
 
-def write_point_data_vtu(
+def write_point_data_vtu(
+    mesh: BaseMesh,
+    space: FESpace,
+    u: ArrayLike,
+    filepath: str,
+    *,
+    compute_j: bool = True,
+    deformed_scale: float = 1.0,
+) -> None:
     """Write VTU with displacement/deformed_coords and optional J."""
     pdata = make_point_data_displacement(mesh, space, u, compute_j=compute_j, deformed_scale=deformed_scale)
     write_vtu(mesh, filepath, point_data=pdata)
 
 
 __all__ = ["make_point_data_displacement", "write_point_data_vtu", "interpolate_at_points"]
-def interpolate_at_points(space, u, points: np.ndarray):
+def interpolate_at_points(space: FESpace, u: ArrayLike, points: np.ndarray) -> np.ndarray:
     """
     Interpolate displacement field at given physical points (Hex8 only, structured search).
     - points: (m,3) array of physical coordinates.
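For context, a minimal sketch of the now fully typed postprocess entry point; mesh, space, and the flat solution vector u are assumed to come from an earlier solve.

from fluxfem.physics.postprocess import write_point_data_vtu

# mesh: BaseMesh, space: FESpace, u: displacement solution (assumed available)
write_point_data_vtu(mesh, space, u, "result.vtu", compute_j=True, deformed_scale=1.0)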
fluxfem/solver/__init__.py
CHANGED
@@ -1,14 +1,33 @@
-from .sparse import
+from .sparse import (
+    SparsityPattern,
+    FluxSparseMatrix,
+    coalesce_coo,
+    concat_flux,
+    block_diag_flux,
+)
 from .dirichlet import (
+    DirichletBC,
     enforce_dirichlet_dense,
+    enforce_dirichlet_dense_jax,
+    enforce_dirichlet_fluxsparse,
     enforce_dirichlet_sparse,
     free_dofs,
+    split_dirichlet_matrix,
+    restrict_flux_to_free,
+    condense_dirichlet_system,
+    enforce_dirichlet_system,
     condense_dirichlet_fluxsparse,
+    condense_dirichlet_fluxsparse_coo,
     condense_dirichlet_dense,
     expand_dirichlet_solution,
 )
-from .cg import cg_solve, cg_solve_jax
+from .cg import cg_solve, cg_solve_jax, build_cg_operator, CGOperator
+from .preconditioner import make_block_jacobi_preconditioner
+from .block_system import build_block_system, split_block_matrix, BlockSystem
+from .block_matrix import FluxBlockMatrix, diag as block_diag, make as make_block_matrix
 from .newton import newton_solve
+from .result import SolverResult
+from .history import NewtonIterRecord
 from .solve_runner import (
     NonlinearAnalysis,
     NewtonLoopConfig,

@@ -21,19 +40,42 @@ from .solve_runner import (
     LinearSolveRunner,
 )
 from .solver import LinearSolver, NonlinearSolver
+from .petsc import petsc_solve, petsc_shell_solve, petsc_is_available
 
 __all__ = [
     "SparsityPattern",
     "FluxSparseMatrix",
+    "coalesce_coo",
+    "concat_flux",
+    "block_diag_flux",
+    "DirichletBC",
     "enforce_dirichlet_dense",
+    "enforce_dirichlet_dense_jax",
+    "enforce_dirichlet_fluxsparse",
     "enforce_dirichlet_sparse",
+    "split_dirichlet_matrix",
+    "enforce_dirichlet_system",
     "free_dofs",
+    "restrict_flux_to_free",
+    "condense_dirichlet_system",
     "condense_dirichlet_fluxsparse",
+    "condense_dirichlet_fluxsparse_coo",
     "condense_dirichlet_dense",
     "expand_dirichlet_solution",
     "cg_solve",
     "cg_solve_jax",
+    "build_cg_operator",
+    "CGOperator",
+    "make_block_jacobi_preconditioner",
+    "build_block_system",
+    "split_block_matrix",
+    "BlockSystem",
+    "FluxBlockMatrix",
+    "block_diag",
+    "make_block_matrix",
     "newton_solve",
+    "SolverResult",
+    "NewtonIterRecord",
     "LinearAnalysis",
     "LinearSolveConfig",
     "LinearStepResult",

@@ -44,4 +86,7 @@ __all__ = [
     "solve_nonlinear",
     "LinearSolver",
     "NonlinearSolver",
+    "petsc_solve",
+    "petsc_shell_solve",
+    "petsc_is_available",
 ]
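The widened fluxfem.solver namespace can be exercised directly. A hedged sketch: petsc_is_available is assumed to return a bool, as its name suggests, and the PETSc backend remains optional.

from fluxfem.solver import (
    FluxBlockMatrix,
    make_block_matrix,
    make_block_jacobi_preconditioner,
    petsc_is_available,
)

if petsc_is_available():
    from fluxfem.solver import petsc_solve  # optional PETSc-backed solve path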
fluxfem/solver/bc.py
CHANGED
@@ -107,6 +107,10 @@ def vector_surface_load_form(ctx: SurfaceFormContext, load: npt.ArrayLike) -> np
     return np.asarray(vector_load_form(ctx.v, load_arr))
 
 
+vector_surface_load_form._ff_kind = "linear"
+vector_surface_load_form._ff_domain = "surface"
+
+
 def make_vector_surface_load_form(load_fn):
     """
     Build a vector surface load form from a callable f(x_q) -> (n_q, dim).
@@ -115,9 +119,32 @@ def make_vector_surface_load_form(load_fn):
         load_q = load_fn(ctx.x_q)
         return vector_surface_load_form(ctx, load_q)
 
+    _form._ff_kind = "linear"
+    _form._ff_domain = "surface"
     return _form
 
 
+def traction_vector(traction, traction_dir: str) -> np.ndarray:
+    """
+    Resolve traction magnitude and direction string into a vector.
+    """
+    dir_map = {
+        "x": (1.0, 0.0, 0.0),
+        "xpos": (1.0, 0.0, 0.0),
+        "xneg": (-1.0, 0.0, 0.0),
+        "y": (0.0, 1.0, 0.0),
+        "ypos": (0.0, 1.0, 0.0),
+        "yneg": (0.0, -1.0, 0.0),
+        "z": (0.0, 0.0, 1.0),
+        "zpos": (0.0, 0.0, 1.0),
+        "zneg": (0.0, 0.0, -1.0),
+    }
+    key = traction_dir.strip().lower()
+    if key not in dir_map:
+        raise ValueError("TRACTION_DIR must be one of x/xpos/xneg/y/ypos/yneg/z/zpos/zneg")
+    return float(traction) * np.asarray(dir_map[key], dtype=float)
+
+
 def _surface_quadrature(node_coords: np.ndarray):
     m = node_coords.shape[0]
     if m == 4:
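A quick usage example of the new traction_vector helper, following the direction map shown above exactly:

import numpy as np
from fluxfem.solver.bc import traction_vector

t = traction_vector(2.5e6, "zneg")          # magnitude 2.5e6 along -z
assert np.allclose(t, [0.0, 0.0, -2.5e6])   # -> array([0., 0., -2500000.])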
@@ -404,8 +431,17 @@ def facet_normals(surface: SurfaceMesh, *, outward_from: npt.ArrayLike | None =
     for i, facet in enumerate(facets):
         if len(facet) < 3:
             continue
-
-
+        n = None
+        p0 = coords[facet[0]]
+        for j in range(1, len(facet) - 1):
+            p1 = coords[facet[j]]
+            p2 = coords[facet[j + 1]]
+            n_candidate = np.cross(p1 - p0, p2 - p0)
+            if np.linalg.norm(n_candidate) > 0.0:
+                n = n_candidate
+                break
+        if n is None:
+            continue
         if normalize:
             norm = np.linalg.norm(n)
             n = n / norm if norm != 0.0 else n
fluxfem/solver/block_matrix.py
ADDED

@@ -0,0 +1,284 @@
+from __future__ import annotations
+
+from collections.abc import Mapping as AbcMapping
+from typing import Any, Iterator, Mapping, Sequence, TypeAlias
+
+import numpy as np
+
+try:
+    import scipy.sparse as sp
+except Exception:  # pragma: no cover
+    sp = None
+
+from .block_system import split_block_matrix
+from .sparse import FluxSparseMatrix
+
+MatrixLike: TypeAlias = Any
+FieldKey: TypeAlias = str | int
+BlockMap: TypeAlias = dict[FieldKey, dict[FieldKey, MatrixLike]]
+
+
+def diag(**blocks: MatrixLike) -> dict[str, MatrixLike]:
+    return dict(blocks)
+
+
+def _infer_sizes_from_diag(diag_blocks: Mapping[FieldKey, MatrixLike]) -> dict[FieldKey, int]:
+    sizes = {}
+    for name, blk in diag_blocks.items():
+        if isinstance(blk, FluxSparseMatrix):
+            sizes[name] = int(blk.n_dofs)
+        elif sp is not None and sp.issparse(blk):
+            shape = blk.shape
+            if shape[0] != shape[1]:
+                raise ValueError(f"diag block {name} must be square, got {shape}")
+            sizes[name] = int(shape[0])
+        else:
+            arr = np.asarray(blk)
+            if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
+                raise ValueError(f"diag block {name} must be square, got {arr.shape}")
+            sizes[name] = int(arr.shape[0])
+    return sizes
+
+
+def _infer_format(blocks: AbcMapping[FieldKey, AbcMapping[FieldKey, MatrixLike]], fmt: str) -> str:
+    if fmt != "auto":
+        return fmt
+    for row in blocks.values():
+        for blk in row.values():
+            if isinstance(blk, FluxSparseMatrix):
+                return "flux"
+            if sp is not None and sp.issparse(blk):
+                return "csr"
+    return "dense"
+
+
+def _add_blocks(a: MatrixLike | None, b: MatrixLike | None) -> MatrixLike | None:
+    if a is None:
+        return b
+    if b is None:
+        return a
+    if isinstance(a, FluxSparseMatrix):
+        a = a.to_csr()
+    if isinstance(b, FluxSparseMatrix):
+        b = b.to_csr()
+    if sp is not None and sp.issparse(a):
+        if sp.issparse(b):
+            return a + b
+        return a + sp.csr_matrix(np.asarray(b))
+    if sp is not None and sp.issparse(b):
+        return sp.csr_matrix(np.asarray(a)) + b
+    return np.asarray(a) + np.asarray(b)
+
+
+def _transpose_block(block: MatrixLike, rule: str) -> MatrixLike:
+    if isinstance(block, FluxSparseMatrix):
+        if sp is None:
+            raise ImportError("scipy is required to transpose FluxSparseMatrix blocks.")
+        block = block.to_csr()
+    if sp is not None and sp.issparse(block):
+        out = block.T
+    else:
+        out = np.asarray(block).T
+    if rule == "H":
+        return out.conjugate()
+    return out
+
+
+class FluxBlockMatrix(AbcMapping[FieldKey, Mapping[FieldKey, MatrixLike]]):
+    """
+    Lazy block-matrix container that assembles on demand.
+    """
+
+    def __init__(
+        self,
+        blocks: BlockMap,
+        *,
+        sizes: Mapping[FieldKey, int],
+        order: Sequence[FieldKey] | None = None,
+        symmetric: bool = False,
+        transpose_rule: str = "T",
+    ) -> None:
+        self._blocks = blocks
+        self.sizes = {name: int(size) for name, size in sizes.items()}
+        self.field_order = tuple(order) if order is not None else tuple(self.sizes.keys())
+        self.symmetric = bool(symmetric)
+        self.transpose_rule = transpose_rule
+        for name in self.field_order:
+            if name not in self.sizes:
+                raise KeyError(f"Missing size for field '{name}'")
+
+    def __getitem__(self, key: FieldKey) -> Mapping[FieldKey, MatrixLike]:
+        return self._blocks[key]
+
+    def __iter__(self) -> Iterator[FieldKey]:
+        return iter(self._blocks)
+
+    def __len__(self) -> int:
+        return len(self._blocks)
+
+    @property
+    def blocks(self) -> BlockMap:
+        return self._blocks
+
+    def assemble(self, *, format: str = "flux") -> MatrixLike:
+        if format not in {"auto", "flux", "csr", "dense"}:
+            raise ValueError("format must be one of: auto, flux, csr, dense")
+        use_format = _infer_format(self._blocks, format)
+
+        offsets = {}
+        offset = 0
+        for name in self.field_order:
+            size = int(self.sizes[name])
+            offsets[name] = offset
+            offset += size
+        n_total = offset
+
+        def _block_shape(name_i: FieldKey, name_j: FieldKey) -> tuple[int, int]:
+            return (int(self.sizes[name_i]), int(self.sizes[name_j]))
+
+        if use_format == "flux":
+            rows_list = []
+            cols_list = []
+            data_list = []
+            for name_i in self.field_order:
+                row_blocks = self._blocks.get(name_i, {})
+                for name_j in self.field_order:
+                    blk = row_blocks.get(name_j)
+                    if blk is None:
+                        continue
+                    shape = _block_shape(name_i, name_j)
+                    if isinstance(blk, FluxSparseMatrix):
+                        if shape[0] != shape[1] or int(blk.n_dofs) != shape[0]:
+                            raise ValueError(f"Block {name_i},{name_j} has incompatible FluxSparseMatrix size")
+                        r = np.asarray(blk.pattern.rows, dtype=np.int64)
+                        c = np.asarray(blk.pattern.cols, dtype=np.int64)
+                        d = np.asarray(blk.data)
+                    elif sp is not None and sp.issparse(blk):
+                        coo = blk.tocoo()
+                        r = np.asarray(coo.row, dtype=np.int64)
+                        c = np.asarray(coo.col, dtype=np.int64)
+                        d = np.asarray(coo.data)
+                        if coo.shape != shape:
+                            raise ValueError(f"Block {name_i},{name_j} has shape {coo.shape}, expected {shape}")
+                    else:
+                        arr = np.asarray(blk)
+                        if arr.shape != shape:
+                            raise ValueError(f"Block {name_i},{name_j} has shape {arr.shape}, expected {shape}")
+                        r, c = np.nonzero(arr)
+                        d = arr[r, c]
+                    if r.size:
+                        rows_list.append(r + offsets[name_i])
+                        cols_list.append(c + offsets[name_j])
+                        data_list.append(d)
+            rows = np.concatenate(rows_list) if rows_list else np.asarray([], dtype=np.int32)
+            cols = np.concatenate(cols_list) if cols_list else np.asarray([], dtype=np.int32)
+            data = np.concatenate(data_list) if data_list else np.asarray([], dtype=float)
+            return FluxSparseMatrix(rows, cols, data, n_total)
+
+        if use_format == "csr" and sp is None:
+            raise ImportError("scipy is required for CSR block systems.")
+        block_rows = []
+        for name_i in self.field_order:
+            row = []
+            row_blocks = self._blocks.get(name_i, {})
+            for name_j in self.field_order:
+                blk = row_blocks.get(name_j)
+                shape = _block_shape(name_i, name_j)
+                if blk is None:
+                    if use_format == "csr":
+                        row.append(sp.csr_matrix(shape))
+                    else:
+                        row.append(np.zeros(shape, dtype=float))
+                    continue
+                if isinstance(blk, FluxSparseMatrix):
+                    if sp is None:
+                        raise ImportError("scipy is required to assemble sparse block systems.")
+                    blk = blk.to_csr()
+                if sp is not None and sp.issparse(blk):
+                    blk = blk.tocsr()
+                    if blk.shape != shape:
+                        raise ValueError(f"Block {name_i},{name_j} has shape {blk.shape}, expected {shape}")
+                    row.append(blk)
+                else:
+                    arr = np.asarray(blk)
+                    if arr.shape != shape:
+                        raise ValueError(f"Block {name_i},{name_j} has shape {arr.shape}, expected {shape}")
+                    if use_format == "csr":
+                        row.append(sp.csr_matrix(arr))
+                    else:
+                        row.append(arr)
+            block_rows.append(row)
+        if use_format == "csr":
+            return sp.bmat(block_rows, format="csr")
+        return np.block(block_rows)
+
+
+def make(
+    *,
+    diag: Mapping[FieldKey, MatrixLike] | Sequence[MatrixLike],
+    rel: Mapping[tuple[FieldKey, FieldKey], MatrixLike] | None = None,
+    add_contiguous: MatrixLike | None = None,
+    sizes: Mapping[FieldKey, int] | None = None,
+    symmetric: bool = False,
+    transpose_rule: str = "T",
+) -> FluxBlockMatrix:
+    """
+    Build a lazy FluxBlockMatrix from diagonal blocks, optional relations, and a full matrix.
+    """
+    if isinstance(diag, Mapping):
+        diag_map = dict(diag)
+    else:
+        diag_seq = list(diag)
+        if sizes is None:
+            diag_map = dict(zip(range(len(diag_seq)), diag_seq))
+        else:
+            order = tuple(sizes.keys())
+            if len(diag_seq) != len(order):
+                raise ValueError("diag sequence length must match sizes")
+            diag_map = dict(zip(order, diag_seq))
+
+    if sizes is None:
+        sizes = _infer_sizes_from_diag(diag_map)
+    order = tuple(sizes.keys())
+
+    if add_contiguous is None:
+        blocks = {name: {} for name in order}
+    else:
+        blocks = split_block_matrix(add_contiguous, sizes=sizes)
+
+    for name, blk in diag_map.items():
+        if name not in sizes:
+            raise KeyError(f"Unknown field '{name}' in diag")
+        blocks.setdefault(name, {})
+        blocks[name][name] = _add_blocks(blocks[name].get(name), blk)
+
+    if transpose_rule not in {"T", "H", "none"}:
+        raise ValueError("transpose_rule must be one of: T, H, none")
+
+    if rel is not None:
+        for (name_i, name_j), blk in rel.items():
+            if name_i not in sizes or name_j not in sizes:
+                raise KeyError(f"Unknown field in rel: {(name_i, name_j)}")
+            blocks.setdefault(name_i, {})
+            blocks[name_i][name_j] = _add_blocks(blocks[name_i].get(name_j), blk)
+            if symmetric and name_i != name_j:
+                if transpose_rule == "none":
+                    blocks.setdefault(name_j, {})
+                    blocks[name_j][name_i] = _add_blocks(blocks[name_j].get(name_i), blk)
+                else:
+                    blocks.setdefault(name_j, {})
+                    blocks[name_j][name_i] = _add_blocks(
+                        blocks[name_j].get(name_i),
+                        _transpose_block(blk, transpose_rule),
+                    )
+
+    return FluxBlockMatrix(
+        blocks,
+        sizes=sizes,
+        order=order,
+        symmetric=symmetric,
+        transpose_rule=transpose_rule,
+    )
+
+
+__all__ = ["FluxBlockMatrix", "diag", "make"]
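Finally, a hedged usage sketch of the block-matrix API introduced by this new file, using small dense NumPy blocks. The field names and shapes are illustrative only; the call goes through the re-exports added to fluxfem.solver above.

import numpy as np
from fluxfem.solver import make_block_matrix

A = np.eye(3)             # "u" diagonal block (3 x 3)
M = 2.0 * np.eye(2)       # "p" diagonal block (2 x 2)
C = np.ones((3, 2))       # ("u", "p") coupling block

K = make_block_matrix(
    diag={"u": A, "p": M},
    rel={("u", "p"): C},
    symmetric=True,        # mirrors C into the ("p", "u") slot via transpose_rule="T"
)
K_dense = K.assemble(format="dense")   # (5, 5) np.block of the four sub-blocks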