fluxfem-0.2.0-py3-none-any.whl → fluxfem-0.2.1-py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
Files changed (41)
  1. fluxfem/__init__.py +1 -13
  2. fluxfem/core/__init__.py +53 -71
  3. fluxfem/core/assembly.py +41 -32
  4. fluxfem/core/basis.py +2 -2
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/mixed_space.py +42 -8
  7. fluxfem/core/mixed_weakform.py +1 -1
  8. fluxfem/core/space.py +68 -28
  9. fluxfem/core/weakform.py +95 -77
  10. fluxfem/mesh/base.py +3 -3
  11. fluxfem/mesh/contact.py +33 -17
  12. fluxfem/mesh/io.py +3 -2
  13. fluxfem/mesh/mortar.py +106 -43
  14. fluxfem/mesh/supermesh.py +2 -0
  15. fluxfem/mesh/surface.py +82 -22
  16. fluxfem/mesh/tet.py +7 -4
  17. fluxfem/physics/elasticity/hyperelastic.py +32 -3
  18. fluxfem/physics/elasticity/linear.py +13 -2
  19. fluxfem/physics/elasticity/stress.py +9 -5
  20. fluxfem/physics/operators.py +12 -5
  21. fluxfem/physics/postprocess.py +29 -3
  22. fluxfem/solver/__init__.py +6 -1
  23. fluxfem/solver/block_matrix.py +165 -13
  24. fluxfem/solver/block_system.py +52 -29
  25. fluxfem/solver/cg.py +43 -30
  26. fluxfem/solver/dirichlet.py +35 -12
  27. fluxfem/solver/history.py +15 -3
  28. fluxfem/solver/newton.py +25 -12
  29. fluxfem/solver/petsc.py +13 -7
  30. fluxfem/solver/preconditioner.py +7 -4
  31. fluxfem/solver/solve_runner.py +42 -24
  32. fluxfem/solver/solver.py +23 -11
  33. fluxfem/solver/sparse.py +32 -13
  34. fluxfem/tools/jit.py +19 -7
  35. fluxfem/tools/timer.py +14 -12
  36. fluxfem/tools/visualizer.py +16 -4
  37. {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/METADATA +18 -7
  38. fluxfem-0.2.1.dist-info/RECORD +59 -0
  39. fluxfem-0.2.0.dist-info/RECORD +0 -59
  40. {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  41. {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/solver/block_matrix.py CHANGED
@@ -1,6 +1,7 @@
  from __future__ import annotations

- from typing import Mapping, Sequence
+ from collections.abc import Mapping as AbcMapping
+ from typing import Any, Iterator, Mapping, Sequence, TypeAlias

  import numpy as np

@@ -12,12 +13,16 @@ except Exception: # pragma: no cover
  from .block_system import split_block_matrix
  from .sparse import FluxSparseMatrix

+ MatrixLike: TypeAlias = Any
+ FieldKey: TypeAlias = str | int
+ BlockMap: TypeAlias = dict[FieldKey, dict[FieldKey, MatrixLike]]

- def diag(**blocks):
+
+ def diag(**blocks: MatrixLike) -> dict[str, MatrixLike]:
      return dict(blocks)


- def _infer_sizes_from_diag(diag_blocks):
+ def _infer_sizes_from_diag(diag_blocks: Mapping[FieldKey, MatrixLike]) -> dict[FieldKey, int]:
      sizes = {}
      for name, blk in diag_blocks.items():
          if isinstance(blk, FluxSparseMatrix):
@@ -35,7 +40,19 @@ def _infer_sizes_from_diag(diag_blocks):
      return sizes


- def _add_blocks(a, b):
+ def _infer_format(blocks: AbcMapping[FieldKey, AbcMapping[FieldKey, MatrixLike]], fmt: str) -> str:
+     if fmt != "auto":
+         return fmt
+     for row in blocks.values():
+         for blk in row.values():
+             if isinstance(blk, FluxSparseMatrix):
+                 return "flux"
+             if sp is not None and sp.issparse(blk):
+                 return "csr"
+     return "dense"
+
+
+ def _add_blocks(a: MatrixLike | None, b: MatrixLike | None) -> MatrixLike | None:
      if a is None:
          return b
      if b is None:
@@ -53,7 +70,7 @@ def _add_blocks(a, b):
      return np.asarray(a) + np.asarray(b)


- def _transpose_block(block, rule: str):
+ def _transpose_block(block: MatrixLike, rule: str) -> MatrixLike:
      if isinstance(block, FluxSparseMatrix):
          if sp is None:
              raise ImportError("scipy is required to transpose FluxSparseMatrix blocks.")
@@ -67,17 +84,146 @@ def _transpose_block(block, rule: str):
      return out


+ class FluxBlockMatrix(AbcMapping[FieldKey, Mapping[FieldKey, MatrixLike]]):
+     """
+     Lazy block-matrix container that assembles on demand.
+     """
+
+     def __init__(
+         self,
+         blocks: BlockMap,
+         *,
+         sizes: Mapping[FieldKey, int],
+         order: Sequence[FieldKey] | None = None,
+         symmetric: bool = False,
+         transpose_rule: str = "T",
+     ) -> None:
+         self._blocks = blocks
+         self.sizes = {name: int(size) for name, size in sizes.items()}
+         self.field_order = tuple(order) if order is not None else tuple(self.sizes.keys())
+         self.symmetric = bool(symmetric)
+         self.transpose_rule = transpose_rule
+         for name in self.field_order:
+             if name not in self.sizes:
+                 raise KeyError(f"Missing size for field '{name}'")
+
+     def __getitem__(self, key: FieldKey) -> Mapping[FieldKey, MatrixLike]:
+         return self._blocks[key]
+
+     def __iter__(self) -> Iterator[FieldKey]:
+         return iter(self._blocks)
+
+     def __len__(self) -> int:
+         return len(self._blocks)
+
+     @property
+     def blocks(self) -> BlockMap:
+         return self._blocks
+
+     def assemble(self, *, format: str = "flux") -> MatrixLike:
+         if format not in {"auto", "flux", "csr", "dense"}:
+             raise ValueError("format must be one of: auto, flux, csr, dense")
+         use_format = _infer_format(self._blocks, format)
+
+         offsets = {}
+         offset = 0
+         for name in self.field_order:
+             size = int(self.sizes[name])
+             offsets[name] = offset
+             offset += size
+         n_total = offset
+
+         def _block_shape(name_i: FieldKey, name_j: FieldKey) -> tuple[int, int]:
+             return (int(self.sizes[name_i]), int(self.sizes[name_j]))
+
+         if use_format == "flux":
+             rows_list = []
+             cols_list = []
+             data_list = []
+             for name_i in self.field_order:
+                 row_blocks = self._blocks.get(name_i, {})
+                 for name_j in self.field_order:
+                     blk = row_blocks.get(name_j)
+                     if blk is None:
+                         continue
+                     shape = _block_shape(name_i, name_j)
+                     if isinstance(blk, FluxSparseMatrix):
+                         if shape[0] != shape[1] or int(blk.n_dofs) != shape[0]:
+                             raise ValueError(f"Block {name_i},{name_j} has incompatible FluxSparseMatrix size")
+                         r = np.asarray(blk.pattern.rows, dtype=np.int64)
+                         c = np.asarray(blk.pattern.cols, dtype=np.int64)
+                         d = np.asarray(blk.data)
+                     elif sp is not None and sp.issparse(blk):
+                         coo = blk.tocoo()
+                         r = np.asarray(coo.row, dtype=np.int64)
+                         c = np.asarray(coo.col, dtype=np.int64)
+                         d = np.asarray(coo.data)
+                         if coo.shape != shape:
+                             raise ValueError(f"Block {name_i},{name_j} has shape {coo.shape}, expected {shape}")
+                     else:
+                         arr = np.asarray(blk)
+                         if arr.shape != shape:
+                             raise ValueError(f"Block {name_i},{name_j} has shape {arr.shape}, expected {shape}")
+                         r, c = np.nonzero(arr)
+                         d = arr[r, c]
+                     if r.size:
+                         rows_list.append(r + offsets[name_i])
+                         cols_list.append(c + offsets[name_j])
+                         data_list.append(d)
+             rows = np.concatenate(rows_list) if rows_list else np.asarray([], dtype=np.int32)
+             cols = np.concatenate(cols_list) if cols_list else np.asarray([], dtype=np.int32)
+             data = np.concatenate(data_list) if data_list else np.asarray([], dtype=float)
+             return FluxSparseMatrix(rows, cols, data, n_total)
+
+         if use_format == "csr" and sp is None:
+             raise ImportError("scipy is required for CSR block systems.")
+         block_rows = []
+         for name_i in self.field_order:
+             row = []
+             row_blocks = self._blocks.get(name_i, {})
+             for name_j in self.field_order:
+                 blk = row_blocks.get(name_j)
+                 shape = _block_shape(name_i, name_j)
+                 if blk is None:
+                     if use_format == "csr":
+                         row.append(sp.csr_matrix(shape))
+                     else:
+                         row.append(np.zeros(shape, dtype=float))
+                     continue
+                 if isinstance(blk, FluxSparseMatrix):
+                     if sp is None:
+                         raise ImportError("scipy is required to assemble sparse block systems.")
+                     blk = blk.to_csr()
+                 if sp is not None and sp.issparse(blk):
+                     blk = blk.tocsr()
+                     if blk.shape != shape:
+                         raise ValueError(f"Block {name_i},{name_j} has shape {blk.shape}, expected {shape}")
+                     row.append(blk)
+                 else:
+                     arr = np.asarray(blk)
+                     if arr.shape != shape:
+                         raise ValueError(f"Block {name_i},{name_j} has shape {arr.shape}, expected {shape}")
+                     if use_format == "csr":
+                         row.append(sp.csr_matrix(arr))
+                     else:
+                         row.append(arr)
+             block_rows.append(row)
+         if use_format == "csr":
+             return sp.bmat(block_rows, format="csr")
+         return np.block(block_rows)
+
+
  def make(
      *,
-     diag: Mapping[str, object] | Sequence[object],
-     rel: Mapping[tuple[str, str], object] | None = None,
-     add_contiguous: object | None = None,
-     sizes: Mapping[str, int] | None = None,
+     diag: Mapping[FieldKey, MatrixLike] | Sequence[MatrixLike],
+     rel: Mapping[tuple[FieldKey, FieldKey], MatrixLike] | None = None,
+     add_contiguous: MatrixLike | None = None,
+     sizes: Mapping[FieldKey, int] | None = None,
      symmetric: bool = False,
      transpose_rule: str = "T",
- ):
+ ) -> FluxBlockMatrix:
      """
-     Build a blocks dict from diagonal blocks, optional relations, and a full matrix.
+     Build a lazy FluxBlockMatrix from diagonal blocks, optional relations, and a full matrix.
      """
      if isinstance(diag, Mapping):
          diag_map = dict(diag)
@@ -126,7 +272,13 @@ def make(
                  _transpose_block(blk, transpose_rule),
              )

-     return blocks
+     return FluxBlockMatrix(
+         blocks,
+         sizes=sizes,
+         order=order,
+         symmetric=symmetric,
+         transpose_rule=transpose_rule,
+     )


- __all__ = ["diag", "make"]
+ __all__ = ["FluxBlockMatrix", "diag", "make"]
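
Note (not part of the diff): a minimal usage sketch of the API above, to make the new return type concrete. The field names, block values, and the fluxfem.solver.block_matrix import path are illustrative; the make() keyword arguments and FluxBlockMatrix.assemble() formats are taken from the signatures shown in this hunk.

    import numpy as np
    from fluxfem.solver import block_matrix

    # Hypothetical two-field layout: 3 "u" dofs and 2 "p" dofs, dense blocks.
    A_uu = np.eye(3)
    A_pp = 2.0 * np.eye(2)
    B_up = np.ones((3, 2))

    M = block_matrix.make(
        diag={"u": A_uu, "p": A_pp},
        rel={("u", "p"): B_up, ("p", "u"): B_up.T},
        sizes={"u": 3, "p": 2},   # explicit sizes; inference from the diagonal blocks also exists
    )

    M["u"]["p"]                           # Mapping access to one block, no assembly yet
    K_dense = M.assemble(format="dense")  # 5x5 ndarray via np.block
    K_csr = M.assemble(format="csr")      # scipy CSR via sp.bmat (requires scipy)

The key behavioural change is that make() no longer returns a plain dict of blocks; callers can still index the blocks directly, but the monolithic operator is only materialized when assemble() is called.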
fluxfem/solver/block_system.py CHANGED
@@ -1,7 +1,7 @@
  from __future__ import annotations

  from dataclasses import dataclass
- from typing import Mapping, Sequence
+ from typing import Any, Mapping, Sequence, TypeAlias

  import numpy as np

@@ -13,23 +13,27 @@ except Exception: # pragma: no cover
  from .dirichlet import DirichletBC, free_dofs
  from .sparse import FluxSparseMatrix

+ MatrixLike: TypeAlias = Any
+ FieldKey: TypeAlias = str | int
+ BlockMap: TypeAlias = Mapping[FieldKey, Mapping[FieldKey, MatrixLike]]
+

  @dataclass(frozen=True)
  class BlockSystem:
-     K: object
+     K: MatrixLike
      F: np.ndarray
      free_dofs: np.ndarray
      dirichlet: DirichletBC
-     field_order: tuple[str, ...]
-     field_slices: dict[str, slice]
+     field_order: tuple[FieldKey, ...]
+     field_slices: dict[FieldKey, slice]

-     def expand(self, u_free):
+     def expand(self, u_free: np.ndarray) -> np.ndarray:
          return self.dirichlet.expand_solution(u_free, free=self.free_dofs, n_total=self.F.shape[0])

-     def split(self, u_full: np.ndarray) -> dict[str, np.ndarray]:
+     def split(self, u_full: np.ndarray) -> dict[FieldKey, np.ndarray]:
          return {name: np.asarray(u_full)[self.field_slices[name]] for name in self.field_order}

-     def join(self, fields: Mapping[str, np.ndarray]) -> np.ndarray:
+     def join(self, fields: Mapping[FieldKey, np.ndarray]) -> np.ndarray:
          parts = []
          for name in self.field_order:
              if name not in fields:
@@ -38,7 +42,9 @@ class BlockSystem:
          return np.concatenate(parts, axis=0)


- def _build_field_slices(order, sizes):
+ def _build_field_slices(
+     order: Sequence[FieldKey], sizes: Mapping[FieldKey, int]
+ ) -> tuple[dict[FieldKey, int], dict[FieldKey, slice], int]:
      offsets = {}
      slices = {}
      offset = 0
@@ -50,7 +56,12 @@ def _build_field_slices(order, sizes):
      return offsets, slices, offset


- def split_block_matrix(matrix, *, sizes: Mapping[str, int], order: Sequence[str] | None = None):
+ def split_block_matrix(
+     matrix: MatrixLike,
+     *,
+     sizes: Mapping[FieldKey, int],
+     order: Sequence[FieldKey] | None = None,
+ ) -> dict[FieldKey, dict[FieldKey, MatrixLike]]:
      """
      Split a block matrix into a dict-of-dicts by field order and sizes.
      """
@@ -72,7 +83,7 @@ def split_block_matrix(matrix, *, sizes: Mapping[str, int], order: Sequence[str]
      if mat.shape != (n_total, n_total):
          raise ValueError(f"matrix has shape {mat.shape}, expected {(n_total, n_total)}")

-     blocks: dict[str, dict[str, object]] = {}
+     blocks: dict[FieldKey, dict[FieldKey, MatrixLike]] = {}
      for name_i in field_order:
          row = {}
          i0 = offsets[name_i]
@@ -85,7 +96,7 @@ def split_block_matrix(matrix, *, sizes: Mapping[str, int], order: Sequence[str]
      return blocks


- def _infer_format(blocks, fmt):
+ def _infer_format(blocks: BlockMap, fmt: str) -> str:
      if fmt != "auto":
          return fmt
      for row in blocks.values():
@@ -97,7 +108,7 @@ def _infer_format(blocks, fmt):
      return "dense"


- def _infer_sizes_from_diag_seq(diag_seq):
+ def _infer_sizes_from_diag_seq(diag_seq: Sequence[MatrixLike]) -> dict[int, int]:
      sizes = {}
      for idx, blk in enumerate(diag_seq):
          if isinstance(blk, FluxSparseMatrix):
@@ -115,7 +126,11 @@ def _infer_sizes_from_diag_seq(diag_seq):
      return sizes


- def _coerce_rhs(rhs, order, sizes):
+ def _coerce_rhs(
+     rhs: MatrixLike | Sequence[MatrixLike] | Mapping[FieldKey, MatrixLike] | None,
+     order: Sequence[FieldKey],
+     sizes: Mapping[FieldKey, int],
+ ) -> np.ndarray:
      if rhs is None:
          return np.zeros(sum(int(sizes[n]) for n in order), dtype=float)
      if isinstance(rhs, Mapping):
@@ -139,7 +154,9 @@ def _coerce_rhs(rhs, order, sizes):
      return np.concatenate(parts, axis=0)


- def _build_dirichlet_from_fields(fields, offsets, *, merge: str):
+ def _build_dirichlet_from_fields(
+     fields: Mapping[FieldKey, object], offsets: Mapping[FieldKey, int], *, merge: str
+ ) -> DirichletBC:
      if merge not in {"check_equal", "error", "first", "last"}:
          raise ValueError("merge must be one of: check_equal, error, first, last")
      dof_map: dict[int, float] = {}
@@ -174,7 +191,13 @@ def _build_dirichlet_from_fields(fields, offsets, *, merge: str):
      return DirichletBC(dofs_sorted, vals_sorted)


- def _build_dirichlet_from_sequence(seq, order, offsets, *, merge: str):
+ def _build_dirichlet_from_sequence(
+     seq: Sequence[object | None],
+     order: Sequence[FieldKey],
+     offsets: Mapping[FieldKey, int],
+     *,
+     merge: str,
+ ) -> DirichletBC:
      if merge not in {"check_equal", "error", "first", "last"}:
          raise ValueError("merge must be one of: check_equal, error, first, last")
      if len(seq) != len(order):
@@ -211,7 +234,7 @@ def _build_dirichlet_from_sequence(seq, order, offsets, *, merge: str):
      return DirichletBC(dofs_sorted, vals_sorted)


- def _transpose_block(block, rule: str):
+ def _transpose_block(block: MatrixLike, rule: str) -> MatrixLike:
      if isinstance(block, FluxSparseMatrix):
          if sp is None:
              raise ImportError("scipy is required to transpose FluxSparseMatrix blocks.")
@@ -225,7 +248,7 @@ def _transpose_block(block, rule: str):
      return out


- def _add_blocks(a, b):
+ def _add_blocks(a: MatrixLike | None, b: MatrixLike | None) -> MatrixLike | None:
      if a is None:
          return b
      if b is None:
@@ -245,14 +268,14 @@ def _add_blocks(a, b):

  def _blocks_from_diag_rel(
      *,
-     diag: Mapping[str, object] | Sequence[object],
-     sizes: Mapping[str, int],
-     order: Sequence[str],
-     rel: Mapping[tuple[str, str], object] | None = None,
-     add_contiguous: object | None = None,
+     diag: Mapping[FieldKey, MatrixLike] | Sequence[MatrixLike],
+     sizes: Mapping[FieldKey, int],
+     order: Sequence[FieldKey],
+     rel: Mapping[tuple[FieldKey, FieldKey], MatrixLike] | None = None,
+     add_contiguous: MatrixLike | None = None,
      symmetric: bool = False,
      transpose_rule: str = "T",
- ) -> Mapping[str, Mapping[str, object]]:
+ ) -> BlockMap:
      if isinstance(diag, Mapping):
          diag_map = dict(diag)
      else:
@@ -296,12 +319,12 @@ def _blocks_from_diag_rel(

  def build_block_system(
      *,
-     diag: Mapping[str, object] | Sequence[object],
-     sizes: Mapping[str, int] | None = None,
-     rel: Mapping[tuple[str, str], object] | None = None,
-     add_contiguous: object | None = None,
-     rhs: Mapping[str, object] | Sequence[object] | np.ndarray | None = None,
-     constraints=None,
+     diag: Mapping[FieldKey, MatrixLike] | Sequence[MatrixLike],
+     sizes: Mapping[FieldKey, int] | None = None,
+     rel: Mapping[tuple[FieldKey, FieldKey], MatrixLike] | None = None,
+     add_contiguous: MatrixLike | None = None,
+     rhs: Mapping[FieldKey, MatrixLike] | Sequence[MatrixLike] | np.ndarray | None = None,
+     constraints: object | None = None,
      merge: str = "check_equal",
      format: str = "auto",
      symmetric: bool = False,
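
Note (not part of the diff): split_block_matrix keeps its behaviour; the annotations above only widen the key type to FieldKey = str | int. A small illustrative call, with made-up field names and sizes:

    import numpy as np
    from fluxfem.solver.block_system import split_block_matrix

    # Carve a monolithic 5x5 operator into a dict-of-dicts of named blocks.
    K = np.arange(25.0).reshape(5, 5)
    blocks = split_block_matrix(K, sizes={"u": 3, "p": 2})

    assert blocks["u"]["u"].shape == (3, 3)   # rows/cols 0:3
    assert blocks["u"]["p"].shape == (3, 2)   # rows 0:3, cols 3:5
    assert blocks["p"]["p"].shape == (2, 2)   # rows/cols 3:5

Integer keys (e.g. sizes={0: 3, 1: 2}) are equally valid under the new FieldKey alias, which is what the sequence-based helpers such as _infer_sizes_from_diag_seq rely on.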
fluxfem/solver/cg.py CHANGED
@@ -1,5 +1,7 @@
  from __future__ import annotations

+ from typing import Any, Callable, TypeAlias
+
  import jax
  import jax.numpy as jnp
  import jax.scipy as jsp
@@ -14,8 +16,12 @@ from dataclasses import dataclass

  from .preconditioner import make_block_jacobi_preconditioner

+ ArrayLike: TypeAlias = jnp.ndarray
+ MatVec: TypeAlias = Callable[[jnp.ndarray], jnp.ndarray]
+ CGInfo: TypeAlias = dict[str, Any]
+

- def _matvec_builder(A):
+ def _matvec_builder(A: Any) -> MatVec:
      if jsparse is not None and isinstance(A, jsparse.BCOO):
          return lambda x: A @ x
      if isinstance(A, FluxSparseMatrix):
@@ -33,7 +39,7 @@ def _matvec_builder(A):
      return mv


- def _coo_tuple_from_any(A):
+ def _coo_tuple_from_any(A: Any):
      if isinstance(A, FluxSparseMatrix):
          return A.to_coo()
      if isinstance(A, tuple) and len(A) == 4:
@@ -53,7 +59,7 @@ def _coo_tuple_from_any(A):
      return None


- def _to_flux_matrix(A):
+ def _to_flux_matrix(A: Any) -> FluxSparseMatrix:
      if isinstance(A, FluxSparseMatrix):
          return A
      coo = _coo_tuple_from_any(A)
@@ -62,7 +68,7 @@ def _to_flux_matrix(A):
      return FluxSparseMatrix.from_bilinear(coo)


- def _to_bcoo_matrix(A):
+ def _to_bcoo_matrix(A: Any):
      if jsparse is None:
          raise ImportError("jax.experimental.sparse is required for BCOO matvec")
      if jsparse is not None and isinstance(A, jsparse.BCOO):
@@ -75,7 +81,7 @@ def _to_bcoo_matrix(A):
      return jsparse.BCOO((data, idx), shape=(n, n))


- def _normalize_matvec_matrix(A, matvec: str):
+ def _normalize_matvec_matrix(A: Any, matvec: str):
      if matvec == "flux":
          return _to_flux_matrix(A)
      if matvec == "bcoo":
@@ -101,7 +107,14 @@ class CGOperator:
      preconditioner: object | None = None
      solver: str = "cg"

-     def solve(self, b, *, x0=None, tol: float = 1e-8, maxiter: int | None = None):
+     def solve(
+         self,
+         b: jnp.ndarray,
+         *,
+         x0: jnp.ndarray | None = None,
+         tol: float = 1e-8,
+         maxiter: int | None = None,
+     ):
          if self.solver == "cg":
              return cg_solve(
                  self.A,
@@ -124,13 +137,13 @@


  def build_cg_operator(
-     A,
+     A: Any,
      *,
      matvec: str = "flux",
-     preconditioner=None,
+     preconditioner: object | None = None,
      solver: str = "cg",
      dof_per_node: int | None = None,
-     block_sizes=None,
+     block_sizes: object | None = None,
  ) -> CGOperator:
      """
      Normalize CG inputs into a single operator interface.
@@ -144,7 +157,7 @@ def build_cg_operator(
      return CGOperator(A=A_mat, preconditioner=precon, solver=solver)


- def _diag_builder(A, n: int):
+ def _diag_builder(A: Any, n: int) -> jnp.ndarray:
      """
      Build diagonal for a Jacobi preconditioner when available.
      """
@@ -171,14 +184,14 @@


  def _cg_solve_single(
-     A,
-     b,
+     A: Any,
+     b: jnp.ndarray,
      *,
-     x0=None,
+     x0: jnp.ndarray | None = None,
      tol: float = 1e-8,
      maxiter: int | None = None,
-     preconditioner=None,
- ):
+     preconditioner: object | None = None,
+ ) -> tuple[jnp.ndarray, CGInfo]:
      """
      Conjugate gradient (Ax=b) in JAX.
      A: FluxSparseMatrix / (rows, cols, data, n) / dense array
@@ -241,14 +254,14 @@


  def cg_solve(
-     A,
-     b,
+     A: Any,
+     b: jnp.ndarray,
      *,
-     x0=None,
+     x0: jnp.ndarray | None = None,
      tol: float = 1e-8,
      maxiter: int | None = None,
-     preconditioner=None,
- ):
+     preconditioner: object | None = None,
+ ) -> tuple[jnp.ndarray, CGInfo]:
      """
      Conjugate gradient (Ax=b) in JAX.
      Supports single RHS (n,) or multiple RHS (n, n_rhs).
@@ -290,14 +303,14 @@


  def _cg_solve_jax_single(
-     A,
-     b,
+     A: Any,
+     b: jnp.ndarray,
      *,
-     x0=None,
+     x0: jnp.ndarray | None = None,
      tol: float = 1e-8,
      maxiter: int | None = None,
-     preconditioner=None,
- ):
+     preconditioner: object | None = None,
+ ) -> tuple[jnp.ndarray, CGInfo]:
      """
      Conjugate gradient via jax.scipy.sparse.linalg.cg.
      A: FluxSparseMatrix / (rows, cols, data, n) / dense array / callable
@@ -359,14 +372,14 @@


  def cg_solve_jax(
-     A,
-     b,
+     A: Any,
+     b: jnp.ndarray,
      *,
-     x0=None,
+     x0: jnp.ndarray | None = None,
      tol: float = 1e-8,
      maxiter: int | None = None,
-     preconditioner=None,
- ):
+     preconditioner: object | None = None,
+ ) -> tuple[jnp.ndarray, CGInfo]:
      """
      Conjugate gradient via jax.scipy.sparse.linalg.cg.
      Supports single RHS (n,) or multiple RHS (n, n_rhs).
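
Note (not part of the diff): the cg_solve signature above now advertises a tuple[jnp.ndarray, CGInfo] return, where CGInfo is just dict[str, Any]. A minimal illustrative call on a small SPD system (the matrix and tolerance are made up; the import path assumes the module is used directly as fluxfem.solver.cg):

    import jax.numpy as jnp
    from fluxfem.solver.cg import cg_solve

    # Dense 2x2 SPD matrix; a FluxSparseMatrix or a (rows, cols, data, n) COO tuple
    # are also accepted per the docstring above.
    A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
    b = jnp.array([1.0, 2.0])

    x, info = cg_solve(A, b, tol=1e-10)        # info: CGInfo = dict[str, Any]
    residual = jnp.linalg.norm(A @ x - b)      # near zero for this well-conditioned system

    # Multiple right-hand sides are supported as an (n, n_rhs) array, per the docstring.
    X, info = cg_solve(A, jnp.stack([b, 2.0 * b], axis=1), tol=1e-10)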