fluxfem 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41):
  1. fluxfem/__init__.py +1 -13
  2. fluxfem/core/__init__.py +53 -71
  3. fluxfem/core/assembly.py +41 -32
  4. fluxfem/core/basis.py +2 -2
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/mixed_space.py +42 -8
  7. fluxfem/core/mixed_weakform.py +1 -1
  8. fluxfem/core/space.py +68 -28
  9. fluxfem/core/weakform.py +95 -77
  10. fluxfem/mesh/base.py +3 -3
  11. fluxfem/mesh/contact.py +33 -17
  12. fluxfem/mesh/io.py +3 -2
  13. fluxfem/mesh/mortar.py +106 -43
  14. fluxfem/mesh/supermesh.py +2 -0
  15. fluxfem/mesh/surface.py +82 -22
  16. fluxfem/mesh/tet.py +7 -4
  17. fluxfem/physics/elasticity/hyperelastic.py +32 -3
  18. fluxfem/physics/elasticity/linear.py +13 -2
  19. fluxfem/physics/elasticity/stress.py +9 -5
  20. fluxfem/physics/operators.py +12 -5
  21. fluxfem/physics/postprocess.py +29 -3
  22. fluxfem/solver/__init__.py +6 -1
  23. fluxfem/solver/block_matrix.py +165 -13
  24. fluxfem/solver/block_system.py +52 -29
  25. fluxfem/solver/cg.py +43 -30
  26. fluxfem/solver/dirichlet.py +35 -12
  27. fluxfem/solver/history.py +15 -3
  28. fluxfem/solver/newton.py +25 -12
  29. fluxfem/solver/petsc.py +13 -7
  30. fluxfem/solver/preconditioner.py +7 -4
  31. fluxfem/solver/solve_runner.py +42 -24
  32. fluxfem/solver/solver.py +23 -11
  33. fluxfem/solver/sparse.py +32 -13
  34. fluxfem/tools/jit.py +19 -7
  35. fluxfem/tools/timer.py +14 -12
  36. fluxfem/tools/visualizer.py +16 -4
  37. {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/METADATA +18 -7
  38. fluxfem-0.2.1.dist-info/RECORD +59 -0
  39. fluxfem-0.2.0.dist-info/RECORD +0 -59
  40. {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  41. {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/solver/dirichlet.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Any
4
+ from typing import Any, TYPE_CHECKING, TypeAlias
5
5
 
6
6
  import numpy as np
7
7
  import jax.numpy as jnp
@@ -13,8 +13,16 @@ except Exception: # pragma: no cover
13
13
 
14
14
  from .sparse import FluxSparseMatrix, coalesce_coo
15
15
 
16
+ if TYPE_CHECKING:
17
+ from jax import Array as JaxArray
16
18
 
17
- def _normalize_dirichlet_values(dofs, vals):
19
+ ArrayLike: TypeAlias = np.ndarray | JaxArray
20
+ else:
21
+ ArrayLike: TypeAlias = np.ndarray
22
+ DirichletLike: TypeAlias = tuple[np.ndarray, np.ndarray]
23
+
24
+
25
+ def _normalize_dirichlet_values(dofs: ArrayLike, vals: ArrayLike | None) -> np.ndarray:
18
26
  if vals is None:
19
27
  return np.zeros(np.asarray(dofs).shape[0], dtype=float)
20
28
  arr = np.asarray(vals)
@@ -23,7 +31,7 @@ def _normalize_dirichlet_values(dofs, vals):
23
31
  return arr
24
32
 
25
33
 
26
- def _normalize_dirichlet(dofs, vals):
34
+ def _normalize_dirichlet(dofs: ArrayLike, vals: ArrayLike | None) -> DirichletLike:
27
35
  dir_arr = np.asarray(dofs, dtype=int)
28
36
  return dir_arr, _normalize_dirichlet_values(dir_arr, vals)
29
37
 
@@ -37,7 +45,7 @@ class CondensedSystem:
37
45
  dir_vals: np.ndarray
38
46
  n_dofs: int
39
47
 
40
- def expand(self, u_free, *, fill_dirichlet: bool = True):
48
+ def expand(self, u_free: ArrayLike, *, fill_dirichlet: bool = True) -> np.ndarray:
41
49
  u_full = np.zeros(self.n_dofs, dtype=np.asarray(u_free).dtype)
42
50
  u_full[self.free_dofs] = np.asarray(u_free)
43
51
  if fill_dirichlet and self.dir_dofs.size:
@@ -61,7 +69,7 @@ class DirichletBC:
61
69
  object.__setattr__(self, "vals", vals)
62
70
 
63
71
  @classmethod
64
- def from_boundary_dofs(cls, mesh, predicate, *, values=None, **kwargs):
72
+ def from_boundary_dofs(cls, mesh, predicate, *, values: ArrayLike | None = None, **kwargs) -> "DirichletBC":
65
73
  """
66
74
  Build from mesh.boundary_dofs_where predicate.
67
75
 
@@ -72,7 +80,16 @@ class DirichletBC:
72
80
  return cls(dofs, vals)
73
81
 
74
82
  @classmethod
75
- def from_bbox(cls, mesh, *, mins=None, maxs=None, tol: float = 1e-8, values=None, **kwargs):
83
+ def from_bbox(
84
+ cls,
85
+ mesh,
86
+ *,
87
+ mins: ArrayLike | None = None,
88
+ maxs: ArrayLike | None = None,
89
+ tol: float = 1e-8,
90
+ values: ArrayLike | None = None,
91
+ **kwargs,
92
+ ) -> "DirichletBC":
76
93
  """
77
94
  Build from the mesh axis-aligned bounding box.
78
95
 
@@ -94,13 +111,13 @@ class DirichletBC:
94
111
  def as_tuple(self) -> tuple[np.ndarray, np.ndarray]:
95
112
  return self.dofs, self.vals
96
113
 
97
- def condense_system(self, A, F, *, check: bool = True) -> CondensedSystem:
114
+ def condense_system(self, A: Any, F: ArrayLike, *, check: bool = True) -> CondensedSystem:
98
115
  return condense_dirichlet_system(A, F, self.dofs, self.vals, check=check)
99
116
 
100
- def enforce_system(self, A, F):
117
+ def enforce_system(self, A: Any, F: ArrayLike):
101
118
  return enforce_dirichlet_system(A, F, self.dofs, self.vals)
102
119
 
103
- def condense_flux(self, A: FluxSparseMatrix, F):
120
+ def condense_flux(self, A: FluxSparseMatrix, F: ArrayLike):
104
121
  """
105
122
  Condense for FluxSparseMatrix and return (K_free, F_free, free_dofs).
106
123
  """
@@ -108,16 +125,22 @@ class DirichletBC:
108
125
  free = condensed.free_dofs
109
126
  return restrict_flux_to_free(A, free), condensed.F, free
110
127
 
111
- def enforce_flux(self, A: FluxSparseMatrix, F):
128
+ def enforce_flux(self, A: FluxSparseMatrix, F: ArrayLike):
112
129
  return enforce_dirichlet_fluxsparse(A, F, self.dofs, self.vals)
113
130
 
114
- def split_matrix(self, A, *, n_total: int | None = None):
131
+ def split_matrix(self, A: Any, *, n_total: int | None = None):
115
132
  return split_dirichlet_matrix(A, self.dofs, n_total=n_total)
116
133
 
117
134
  def free_dofs(self, n_dofs: int) -> np.ndarray:
118
135
  return free_dofs(n_dofs, self.dofs)
119
136
 
120
- def expand_solution(self, u_free, *, free=None, n_total: int | None = None):
137
+ def expand_solution(
138
+ self,
139
+ u_free: ArrayLike,
140
+ *,
141
+ free: np.ndarray | None = None,
142
+ n_total: int | None = None,
143
+ ) -> np.ndarray:
121
144
  if free is None:
122
145
  if n_total is None:
123
146
  raise ValueError("n_total is required when free is not provided.")
fluxfem/solver/history.py CHANGED
@@ -1,7 +1,19 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass, field
4
- from typing import Any, List, Optional
4
+ from typing import Any, List, Optional, TYPE_CHECKING, TypeAlias
5
+
6
+ import jax.numpy as jnp
7
+ import numpy as np
8
+
9
+ from .result import SolverResult
10
+
11
+ if TYPE_CHECKING:
12
+ from jax import Array as JaxArray
13
+
14
+ ArrayLike: TypeAlias = np.ndarray | JaxArray
15
+ else:
16
+ ArrayLike: TypeAlias = np.ndarray
5
17
 
6
18
 
7
19
  @dataclass
@@ -23,9 +35,9 @@ class NewtonIterRecord:
23
35
  @dataclass
24
36
  class LoadStepResult:
25
37
  load_factor: float
26
- info: Any
38
+ info: SolverResult
27
39
  solve_time: float
28
- u: Any
40
+ u: ArrayLike
29
41
  iter_history: List[NewtonIterRecord] = field(default_factory=list)
30
42
  exception: Optional[str] = None
31
43
  meta: dict[str, Any] = field(default_factory=dict)
fluxfem/solver/newton.py CHANGED
@@ -1,6 +1,8 @@
1
1
  from __future__ import annotations
2
2
  import time
3
3
 
4
+ from typing import Any, Callable, Mapping, TYPE_CHECKING, TypeAlias
5
+
4
6
  import numpy as np
5
7
  import jax
6
8
  import jax.numpy as jnp
@@ -8,6 +10,7 @@ import jax.numpy as jnp
8
10
  from ..core.assembly import (
9
11
  assemble_residual_scatter,
10
12
  assemble_jacobian_scatter,
13
+ ResidualForm,
11
14
  make_element_residual_kernel,
12
15
  make_element_jacobian_kernel,
13
16
  make_sparsity_pattern,
@@ -17,14 +20,22 @@ from .cg import cg_solve, cg_solve_jax
17
20
  from .preconditioner import make_block_jacobi_preconditioner
18
21
  from .result import SolverResult
19
22
  from .sparse import SparsityPattern, FluxSparseMatrix
20
- from .dirichlet import _normalize_dirichlet
23
+ from .dirichlet import DirichletBC, _normalize_dirichlet
24
+
25
+ if TYPE_CHECKING:
26
+ from jax import Array as JaxArray
27
+
28
+ ArrayLike: TypeAlias = np.ndarray | JaxArray
29
+ else:
30
+ ArrayLike: TypeAlias = np.ndarray
31
+ ExtraTerm: TypeAlias = Callable[[np.ndarray], tuple[np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray, Mapping[str, Any]] | None]
21
32
 
22
33
 
23
34
  def newton_solve(
24
35
  space,
25
- res_form,
26
- u0,
27
- params,
36
+ res_form: ResidualForm[Any],
37
+ u0: ArrayLike,
38
+ params: Any,
28
39
  *,
29
40
  tol: float = 1e-8,
30
41
  atol: float = 0.0,
@@ -32,18 +43,18 @@ def newton_solve(
32
43
  linear_solver: str = "spsolve", # "spsolve", "spdirect_solve_gpu", "cg", "cg_jax", "cg_custom", or "cg_matfree"
33
44
  linear_maxiter: int | None = None,
34
45
  linear_tol: float | None = None,
35
- linear_preconditioner=None,
46
+ linear_preconditioner: object | None = None,
36
47
  matfree_mode: str = "linearize",
37
- matfree_cache: dict | None = None,
38
- dirichlet=None,
39
- callback=None,
48
+ matfree_cache: dict[str, Any] | None = None,
49
+ dirichlet: tuple[np.ndarray, np.ndarray] | None = None,
50
+ callback: Callable[[np.ndarray, SolverResult], Any] | None = None,
40
51
  line_search: bool = False,
41
52
  max_ls: int = 10,
42
53
  ls_c: float = 1e-4,
43
- external_vector=None,
44
- jacobian_pattern=None,
45
- extra_terms=None,
46
- ):
54
+ external_vector: np.ndarray | None = None,
55
+ jacobian_pattern: SparsityPattern | None = None,
56
+ extra_terms: list[ExtraTerm] | None = None,
57
+ ) -> tuple[np.ndarray, SolverResult]:
47
58
  """
48
59
  Gridap-style Newton–Raphson solver on free DOFs only.
49
60
 
@@ -61,6 +72,8 @@ def newton_solve(
61
72
  """
62
73
 
63
74
  if dirichlet is not None:
75
+ if isinstance(dirichlet, DirichletBC):
76
+ dirichlet = dirichlet.as_tuple()
64
77
  dir_dofs, dir_vals = dirichlet
65
78
  dir_dofs, dir_vals = _normalize_dirichlet(dir_dofs, dir_vals)
66
79
  if dir_vals.ndim == 0:
fluxfem/solver/petsc.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import numpy as np
4
4
  import time
5
5
  import warnings
6
- from typing import Any, Callable
6
+ from typing import Any, Callable, TypeAlias
7
7
 
8
8
  try:
9
9
  import scipy.sparse as sp
@@ -12,6 +12,10 @@ except Exception: # pragma: no cover
12
12
 
13
13
  from .sparse import FluxSparseMatrix
14
14
 
15
+ ArrayLike: TypeAlias = np.ndarray
16
+ MatVec: TypeAlias = Callable[[np.ndarray], np.ndarray]
17
+ SolveInfo: TypeAlias = dict[str, Any]
18
+
15
19
 
16
20
  def petsc_is_available() -> bool:
17
21
  try:
@@ -31,7 +35,9 @@ def _require_petsc4py():
31
35
  raise ImportError("petsc4py is required for PETSc solves. Install with the petsc extra.") from exc
32
36
 
33
37
 
34
- def _coo_to_csr(rows, cols, data, n_dofs: int):
38
+ def _coo_to_csr(
39
+ rows: ArrayLike, cols: ArrayLike, data: ArrayLike, n_dofs: int
40
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
35
41
  r = np.asarray(rows, dtype=np.int64)
36
42
  c = np.asarray(cols, dtype=np.int64)
37
43
  d = np.asarray(data)
@@ -71,7 +77,7 @@ def _infer_n_dofs(K: Any, F: Any | None, n_dofs: int | None) -> int:
71
77
  raise ValueError("n_dofs is required when operator shape is not available.")
72
78
 
73
79
 
74
- def _matvec_builder(A: Any):
80
+ def _matvec_builder(A: Any) -> MatVec:
75
81
  if isinstance(A, FluxSparseMatrix):
76
82
  return lambda x: np.asarray(A.matvec(x))
77
83
  if hasattr(A, "matvec"):
@@ -89,7 +95,7 @@ def _matvec_builder(A: Any):
89
95
  return mv
90
96
 
91
97
 
92
- def _diag_from_coo(rows, cols, data, n_dofs: int) -> np.ndarray:
98
+ def _diag_from_coo(rows: ArrayLike, cols: ArrayLike, data: ArrayLike, n_dofs: int) -> np.ndarray:
93
99
  r = np.asarray(rows, dtype=np.int64)
94
100
  c = np.asarray(cols, dtype=np.int64)
95
101
  d = np.asarray(data)
@@ -119,7 +125,7 @@ def _diag_from_operator(A: Any, n_dofs: int) -> np.ndarray:
119
125
  raise ValueError("diag0 preconditioner requires access to the matrix diagonal.")
120
126
 
121
127
 
122
- def _as_csr(K: Any):
128
+ def _as_csr(K: Any) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:
123
129
  if isinstance(K, FluxSparseMatrix):
124
130
  rows, cols, data, n_dofs = K.to_coo()
125
131
  indptr, indices, data = _coo_to_csr(rows, cols, data, int(n_dofs))
@@ -158,7 +164,7 @@ def petsc_solve(
158
164
  atol: float | None = None,
159
165
  max_it: int | None = None,
160
166
  options: dict[str, Any] | None = None,
161
- ) -> np.ndarray | tuple[np.ndarray, dict[str, Any]]:
167
+ ) -> np.ndarray:
162
168
  """
163
169
  Solve K u = F using PETSc.
164
170
 
@@ -235,7 +241,7 @@ def petsc_shell_solve(
235
241
  options: dict[str, Any] | None = None,
236
242
  options_prefix: str | None = "fluxfem_",
237
243
  return_info: bool = False,
238
- ) -> np.ndarray:
244
+ ) -> np.ndarray | tuple[np.ndarray, SolveInfo]:
239
245
  """
240
246
  Solve A x = F using PETSc with a matrix-free Shell Mat.
241
247
 
fluxfem/solver/preconditioner.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Iterable, Sequence
3
+ from typing import Callable, Iterable, Sequence, TypeAlias
4
4
 
5
5
  import numpy as np
6
6
  import jax.numpy as jnp
@@ -12,6 +12,9 @@ except Exception: # pragma: no cover
12
12
 
13
13
  from .sparse import FluxSparseMatrix
14
14
 
15
+ ArrayLike: TypeAlias = jnp.ndarray
16
+ Preconditioner: TypeAlias = Callable[[jnp.ndarray], jnp.ndarray]
17
+
15
18
 
16
19
  def _extract_block_sizes(
17
20
  n: int,
@@ -51,11 +54,11 @@ def _extract_block_sizes(
51
54
 
52
55
 
53
56
  def make_block_jacobi_preconditioner(
54
- A,
57
+ A: FluxSparseMatrix | "jsparse.BCOO",
55
58
  *,
56
59
  dof_per_node: int | None = None,
57
60
  block_sizes: Sequence[int] | None = None,
58
- ):
61
+ ) -> Preconditioner:
59
62
  """
60
63
  Build block Jacobi preconditioner for blocked DOF layouts.
61
64
 
@@ -98,7 +101,7 @@ def make_block_jacobi_preconditioner(
98
101
  blocks = blocks + 1e-12 * jnp.eye(block_size)[None, :, :]
99
102
  inv_blocks = jnp.linalg.inv(blocks)
100
103
 
101
- def precon(r):
104
+ def precon(r: jnp.ndarray) -> jnp.ndarray:
102
105
  rb = r.reshape((n_block, block_size))
103
106
  zb = jnp.einsum("bij,bj->bi", inv_blocks, rb)
104
107
  return zb.reshape((-1,))
fluxfem/solver/solve_runner.py CHANGED
@@ -3,22 +3,34 @@ from __future__ import annotations
3
3
  import time
4
4
  from dataclasses import dataclass, field
5
5
  import warnings
6
- from typing import Any, Callable, Iterable, List, Sequence
6
+ from typing import Any, Callable, Iterable, List, Sequence, TYPE_CHECKING, TypeAlias
7
7
 
8
8
  import numpy as np
9
9
  import jax.numpy as jnp
10
10
 
11
- from ..core.assembly import assemble_bilinear_form
11
+ from ..core.assembly import FormKernel, ResidualForm, assemble_bilinear_form
12
12
  from ..core.solver import spdirect_solve_cpu, spdirect_solve_gpu
13
13
  from .cg import cg_solve, cg_solve_jax
14
14
  from .petsc import petsc_shell_solve
15
15
  from .sparse import FluxSparseMatrix
16
- from .dirichlet import expand_dirichlet_solution
16
+ from .dirichlet import DirichletBC, expand_dirichlet_solution
17
17
  from .newton import newton_solve
18
18
  from .result import SolverResult
19
19
  from .history import NewtonIterRecord, LoadStepResult
20
20
  from ..tools.timer import SectionTimer, NullTimer
21
21
 
22
+ if TYPE_CHECKING:
23
+ from jax import Array as JaxArray
24
+
25
+ ArrayLike: TypeAlias = np.ndarray | JaxArray
26
+ else:
27
+ ArrayLike: TypeAlias = np.ndarray
28
+ DirichletLike: TypeAlias = tuple[np.ndarray, np.ndarray]
29
+ ExtraTerm: TypeAlias = Callable[
30
+ [np.ndarray],
31
+ tuple[np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray, dict[str, Any]] | None,
32
+ ]
33
+
22
34
 
23
35
  @dataclass
24
36
  class NonlinearAnalysis:
@@ -46,15 +58,19 @@ class NonlinearAnalysis:
46
58
  """
47
59
 
48
60
  space: Any
49
- residual_form: Any
61
+ residual_form: ResidualForm[Any]
50
62
  params: Any
51
- base_external_vector: Any | None = None
52
- dirichlet: tuple | None = None
53
- extra_terms: list[Callable] | None = None
63
+ base_external_vector: ArrayLike | None = None
64
+ dirichlet: DirichletLike | None = None
65
+ extra_terms: list[ExtraTerm] | None = None
54
66
  jacobian_pattern: Any | None = None
55
67
  dtype: Any = jnp.float64
56
68
 
57
- def external_for_load(self, load_factor: float):
69
+ def __post_init__(self) -> None:
70
+ if isinstance(self.dirichlet, DirichletBC):
71
+ self.dirichlet = self.dirichlet.as_tuple()
72
+
73
+ def external_for_load(self, load_factor: float) -> ArrayLike | None:
58
74
  if self.base_external_vector is None:
59
75
  return None
60
76
  return jnp.asarray(load_factor * self.base_external_vector, dtype=self.dtype)
@@ -105,14 +121,14 @@ class NewtonSolveRunner:
105
121
 
106
122
  def run(
107
123
  self,
108
- u0=None,
124
+ u0: ArrayLike | None = None,
109
125
  *,
110
126
  load_sequence: Sequence[float] | None = None,
111
- newton_callback: Callable | None = None,
127
+ newton_callback: Callable[[dict[str, Any]], None] | None = None,
112
128
  step_callback: Callable[[LoadStepResult], None] | None = None,
113
129
  timer: "SectionTimer | None" = None,
114
130
  report_timing: bool = True
115
- ):
131
+ ) -> tuple[np.ndarray, list[LoadStepResult]]:
116
132
  """
117
133
  Execute Newton solves over the configured load schedule.
118
134
 
@@ -276,7 +292,9 @@ class NewtonSolveRunner:
276
292
  return u, history
277
293
 
278
294
 
279
- def _condense_flux_dirichlet(K: FluxSparseMatrix, F, dirichlet):
295
+ def _condense_flux_dirichlet(
296
+ K: FluxSparseMatrix, F: ArrayLike, dirichlet: DirichletLike
297
+ ) -> tuple[Any, np.ndarray, np.ndarray | None, np.ndarray, np.ndarray, np.ndarray]:
280
298
  dir_dofs, dir_vals = dirichlet
281
299
  dir_arr = np.asarray(dir_dofs, dtype=int)
282
300
  dir_vals_arr = np.asarray(dir_vals, dtype=float)
@@ -294,12 +312,12 @@ def _condense_flux_dirichlet(K: FluxSparseMatrix, F, dirichlet):
294
312
 
295
313
  def solve_nonlinear(
296
314
  space,
297
- residual_form,
298
- params,
315
+ residual_form: ResidualForm[Any],
316
+ params: Any,
299
317
  *,
300
- dirichlet: tuple | None = None,
301
- base_external_vector=None,
302
- extra_terms=None,
318
+ dirichlet: DirichletLike | None = None,
319
+ base_external_vector: ArrayLike | None = None,
320
+ extra_terms: list[ExtraTerm] | None = None,
303
321
  dtype=jnp.float64,
304
322
  maxiter: int = 20,
305
323
  tol: float = 1e-8,
@@ -314,8 +332,8 @@ def solve_nonlinear(
314
332
  ls_c: float = 1e-4,
315
333
  n_steps: int = 1,
316
334
  jacobian_pattern=None,
317
- u0=None,
318
- ):
335
+ u0: ArrayLike | None = None,
336
+ ) -> tuple[np.ndarray, list[LoadStepResult]]:
319
337
  """
320
338
  Convenience wrapper: build NonlinearAnalysis and run NewtonSolveRunner.
321
339
  """
@@ -359,10 +377,10 @@ class LinearAnalysis:
359
377
  """
360
378
 
361
379
  space: Any
362
- bilinear_form: Any
380
+ bilinear_form: FormKernel[Any]
363
381
  params: Any
364
- base_rhs_vector: Any
365
- dirichlet: tuple | None = None
382
+ base_rhs_vector: ArrayLike
383
+ dirichlet: DirichletLike | None = None
366
384
  pattern: Any | None = None
367
385
  dtype: Any = jnp.float64
368
386
 
@@ -373,7 +391,7 @@ class LinearAnalysis:
373
391
  pattern=self.pattern,
374
392
  )
375
393
 
376
- def rhs_for_load(self, load_factor: float):
394
+ def rhs_for_load(self, load_factor: float) -> ArrayLike:
377
395
  return jnp.asarray(load_factor * self.base_rhs_vector, dtype=self.dtype)
378
396
 
379
397
 
@@ -437,7 +455,7 @@ class LinearStepResult:
437
455
  """
438
456
  info: SolverResult
439
457
  solve_time: float
440
- u: Any
458
+ u: ArrayLike
441
459
 
442
460
 
443
461
  class LinearSolveRunner:
fluxfem/solver/solver.py CHANGED
@@ -1,5 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from typing import Any, TYPE_CHECKING, TypeAlias
4
+
3
5
  import numpy as np
4
6
  import jax.numpy as jnp
5
7
 
@@ -18,6 +20,16 @@ from .dirichlet import (
18
20
  from .sparse import FluxSparseMatrix
19
21
  from ..core.space import FESpace
20
22
 
23
+ if TYPE_CHECKING:
24
+ from jax import Array as JaxArray
25
+
26
+ ArrayLike: TypeAlias = np.ndarray | JaxArray
27
+ else:
28
+ ArrayLike: TypeAlias = np.ndarray
29
+ DirichletLike: TypeAlias = DirichletBC | tuple[np.ndarray, np.ndarray]
30
+ SolveInfo: TypeAlias = dict[str, Any]
31
+ SolveReturn: TypeAlias = tuple[np.ndarray, SolveInfo]
32
+
21
33
 
22
34
  class LinearSolver:
23
35
  """
@@ -32,7 +44,7 @@ class LinearSolver:
32
44
  self.tol = tol
33
45
  self.maxiter = maxiter
34
46
 
35
- def _solve_free(self, A, b):
47
+ def _solve_free(self, A: Any, b: Any) -> SolveReturn:
36
48
  if self.method == "cg":
37
49
  x, info = cg_solve_jax(A, b, tol=self.tol, maxiter=self.maxiter)
38
50
  return np.asarray(x), {"iters": info.get("iters"), "converged": info.get("converged", True)}
@@ -59,13 +71,13 @@ class LinearSolver:
59
71
 
60
72
  def solve(
61
73
  self,
62
- A,
63
- b,
74
+ A: Any,
75
+ b: Any,
64
76
  *,
65
- dirichlet=None,
77
+ dirichlet: DirichletLike | None = None,
66
78
  dirichlet_mode: str = "condense",
67
79
  n_total: int | None = None,
68
- ):
80
+ ) -> SolveReturn:
69
81
  if dirichlet is None:
70
82
  return self._solve_free(A, b)
71
83
 
@@ -109,8 +121,8 @@ class NonlinearSolver:
109
121
  def __init__(
110
122
  self,
111
123
  space: FESpace,
112
- res_form,
113
- params,
124
+ res_form: Any,
125
+ params: Any,
114
126
  *,
115
127
  tol: float = 1e-8,
116
128
  maxiter: int = 20,
@@ -119,10 +131,10 @@ class NonlinearSolver:
119
131
  max_ls: int = 10,
120
132
  ls_c: float = 1e-4,
121
133
  linear_tol: float | None = None,
122
- dirichlet=None,
123
- external_vector=None,
134
+ dirichlet: DirichletLike | None = None,
135
+ external_vector: ArrayLike | None = None,
124
136
  linear_maxiter: int | None = None,
125
- jacobian_pattern=None,
137
+ jacobian_pattern: Any | None = None,
126
138
  ):
127
139
  self.space = space
128
140
  self.res_form = res_form
@@ -139,7 +151,7 @@ class NonlinearSolver:
139
151
  self.linear_maxiter = linear_maxiter
140
152
  self.jacobian_pattern = jacobian_pattern
141
153
 
142
- def solve(self, u0):
154
+ def solve(self, u0: ArrayLike):
143
155
  return newton_solve(
144
156
  self.space,
145
157
  self.res_form,