redbirdpy-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- redbirdpy/__init__.py +112 -0
- redbirdpy/analytical.py +927 -0
- redbirdpy/forward.py +589 -0
- redbirdpy/property.py +602 -0
- redbirdpy/recon.py +893 -0
- redbirdpy/solver.py +814 -0
- redbirdpy/utility.py +1117 -0
- redbirdpy-0.1.0.dist-info/METADATA +596 -0
- redbirdpy-0.1.0.dist-info/RECORD +13 -0
- redbirdpy-0.1.0.dist-info/WHEEL +5 -0
- redbirdpy-0.1.0.dist-info/licenses/LICENSE.txt +674 -0
- redbirdpy-0.1.0.dist-info/top_level.txt +1 -0
- redbirdpy-0.1.0.dist-info/zip-safe +1 -0
redbirdpy/solver.py
ADDED
@@ -0,0 +1,814 @@
"""
Redbird Solver Module - Linear system solvers for FEM.

Provides:
    femsolve: Main solver interface with automatic method selection
    get_solver_info: Query available solver backends

Supported solvers:
    Direct: pardiso, umfpack, cholmod, superlu
    Iterative: blqmr, cg, cg+amg, gmres, bicgstab

Dependencies:
    - blocksolver: For BLQMR iterative solver (complex symmetric systems)
"""

__all__ = [
    "femsolve",
    "get_solver_info",
]

import numpy as np
from scipy import sparse
from scipy.sparse.linalg import spsolve, cg, gmres, bicgstab, splu
from typing import Dict, Tuple, Optional, Union, List, Any
import warnings
from concurrent.futures import ProcessPoolExecutor
import multiprocessing

# =============================================================================
# Solver Backend Detection
# =============================================================================

_DIRECT_SOLVER = "superlu"  # fallback
_HAS_UMFPACK = False
_HAS_CHOLMOD = False
_HAS_AMG = False
_HAS_BLQMR = False

# Pardiso references
_pardiso_solve = None
_pardiso_factorized = None

try:
    from pypardiso import spsolve as pardiso_solve, factorized as pardiso_factorized

    _pardiso_solve = pardiso_solve
    _pardiso_factorized = pardiso_factorized
    _DIRECT_SOLVER = "pardiso"
except ImportError:
    pass

try:
    from scikits.umfpack import spsolve as umfpack_spsolve, splu as umfpack_splu

    _HAS_UMFPACK = True
    if _DIRECT_SOLVER == "superlu":
        _DIRECT_SOLVER = "umfpack"
except ImportError:
    umfpack_spsolve = None
    umfpack_splu = None

try:
    from sksparse.cholmod import cholesky as _cholmod_cholesky

    _HAS_CHOLMOD = True
except ImportError:
    _cholmod_cholesky = None

try:
    import pyamg as _pyamg

    _HAS_AMG = True
except ImportError:
    _pyamg = None

# Import blocksolver for BLQMR
try:
    from blocksolver import (
        blqmr,
        BLQMRResult,
        BLQMR_EXT,
        HAS_NUMBA,
        make_preconditioner,
        BLQMRWorkspace,
    )

    _HAS_BLQMR = True
except ImportError:
    blqmr = None
    BLQMRResult = None
    BLQMR_EXT = False
    HAS_NUMBA = False
    make_preconditioner = None
    BLQMRWorkspace = None

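# --- Illustrative sketch (editor's addition, not part of the published file). ---
# The module-level flags above record which optional backends imported cleanly;
# the default direct solver resolves as pardiso > umfpack > superlu in that order.
# `_example_report_backends` is a hypothetical helper for inspecting those flags.
def _example_report_backends():
    print("default direct solver:", _DIRECT_SOLVER)
    print("umfpack:", _HAS_UMFPACK, "| cholmod:", _HAS_CHOLMOD)
    print("pyamg (cg+amg):", _HAS_AMG, "| blocksolver (blqmr):", _HAS_BLQMR)
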
# =============================================================================
# Parallel Solver Helper Functions
# =============================================================================


def _solve_blqmr_batch(args):
    """
    Worker function for parallel BLQMR solving.

    This function runs in a separate process and reconstructs the sparse
    matrix from its components before solving.
    """
    (
        A_data,
        A_indices,
        A_indptr,
        A_shape,
        A_dtype,
        rhs_batch,
        tol,
        maxiter,
        precond_type,
        droptol,
        batch_id,
        start_col,
    ) = args

    # Import blqmr in worker process
    from blocksolver import blqmr as worker_blqmr

    # Reconstruct sparse matrix in worker process
    A = sparse.csc_matrix((A_data, A_indices, A_indptr), shape=A_shape, dtype=A_dtype)

    result = worker_blqmr(
        A,
        rhs_batch,
        tol=tol,
        maxiter=maxiter,
        M1=None,
        M2=None,
        x0=None,
        workspace=None,
        precond_type=precond_type,
        droptol=droptol,
    )

    x_batch = result.x if result.x.ndim > 1 else result.x.reshape(-1, 1)
    return batch_id, start_col, x_batch, result.flag, result.iter, result.relres

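# --- Illustrative sketch (editor's addition, not part of the published file). ---
# The worker above receives only picklable CSC components (data, indices, indptr)
# and rebuilds the matrix inside the process; this standalone round-trip shows that
# the rebuild is lossless. `_example_csc_roundtrip` is a hypothetical name.
def _example_csc_roundtrip():
    import numpy as np
    from scipy import sparse
    A = sparse.random(50, 50, density=0.05, format="csc", dtype=np.float64)
    parts = (A.data, A.indices, A.indptr)            # what gets shipped to a worker
    B = sparse.csc_matrix(parts, shape=A.shape, dtype=A.dtype)
    assert (A != B).nnz == 0                         # same sparsity pattern and values
    return B
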
def _solve_iterative_column(args):
    """
    Worker function for parallel iterative solving (gmres, bicgstab, cg).

    Solves a single RHS column using the specified iterative method.
    """
    (
        A_data,
        A_indices,
        A_indptr,
        A_shape,
        A_dtype,
        rhs_col,
        col_idx,
        solver_type,
        tol,
        maxiter,
        use_amg,
    ) = args

    from scipy.sparse.linalg import gmres, bicgstab, cg
    from scipy import sparse

    # Reconstruct sparse matrix in worker process
    A = sparse.csc_matrix((A_data, A_indices, A_indptr), shape=A_shape, dtype=A_dtype)

    # Setup preconditioner if AMG requested
    M = None
    if use_amg and solver_type == "cg":
        try:
            import pyamg

            ml = pyamg.smoothed_aggregation_solver(A.tocsr())
            M = ml.aspreconditioner()
        except ImportError:
            pass

    # Select solver
    if solver_type == "gmres":
        solver_func = gmres
    elif solver_type == "bicgstab":
        solver_func = bicgstab
    elif solver_type == "cg":
        solver_func = cg
    else:
        raise ValueError(f"Unknown solver type: {solver_type}")

    # Solve
    try:
        x_col, info = solver_func(A, rhs_col, M=M, rtol=tol, maxiter=maxiter)
    except TypeError:
        # Older scipy versions use 'tol' instead of 'rtol'
        x_col, info = solver_func(A, rhs_col, M=M, tol=tol, maxiter=maxiter)

    return col_idx, x_col, info

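# --- Illustrative sketch (editor's addition, not part of the published file). ---
# The try/except TypeError above is a SciPy keyword-compatibility shim: newer SciPy
# releases accept `rtol` for the iterative solvers, while older ones only accept
# `tol`. `_example_cg_compat` is a hypothetical standalone version of that shim.
def _example_cg_compat(A, b, tol=1e-10, maxiter=1000):
    from scipy.sparse.linalg import cg
    try:
        return cg(A, b, rtol=tol, maxiter=maxiter)    # SciPy with the rtol keyword
    except TypeError:
        return cg(A, b, tol=tol, maxiter=maxiter)     # older SciPy without rtol
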
def _blqmr_parallel(
    Amat: sparse.spmatrix,
    rhs: np.ndarray,
    *,
    tol: float,
    maxiter: int,
    rhsblock: int,
    precond_type: int,
    droptol: float,
    nthread: int,
    verbose: bool,
) -> Tuple[np.ndarray, int]:
    """
    Solve BLQMR with multiple RHS in parallel using multiprocessing.
    """
    n, ncol = rhs.shape

    # Determine output dtype based on matrix and RHS
    is_complex = np.iscomplexobj(Amat) or np.iscomplexobj(rhs)
    out_dtype = np.complex128 if is_complex else np.float64

    # Convert matrix to CSC for consistent serialization
    Acsc = Amat.tocsc()
    A_data = Acsc.data
    A_indices = Acsc.indices
    A_indptr = Acsc.indptr
    A_shape = Acsc.shape
    A_dtype = Acsc.dtype

    # Prepare batches
    batches = []
    for batch_id, start in enumerate(range(0, ncol, rhsblock)):
        end = min(start + rhsblock, ncol)
        rhs_batch = np.ascontiguousarray(rhs[:, start:end])
        batches.append(
            (
                A_data,
                A_indices,
                A_indptr,
                A_shape,
                A_dtype,
                rhs_batch,
                tol,
                maxiter,
                precond_type,
                droptol,
                batch_id,
                start,
            )
        )

    # Solve in parallel
    x = np.zeros((n, ncol), dtype=out_dtype)
    max_flag = 0

    with ProcessPoolExecutor(max_workers=nthread) as executor:
        results = executor.map(_solve_blqmr_batch, batches)

        for batch_id, start_col, x_batch, batch_flag, niter, relres in results:
            end_col = start_col + x_batch.shape[1]
            if is_complex and not np.iscomplexobj(x_batch):
                x[:, start_col:end_col] = x_batch.astype(out_dtype)
            else:
                x[:, start_col:end_col] = x_batch
            max_flag = max(max_flag, batch_flag)

            if verbose:
                print(
                    f"blqmr [{start_col+1}:{end_col}] (worker {batch_id}): "
                    f"iter={niter}, relres={relres:.2e}, flag={batch_flag}"
                )

    return x, max_flag

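# --- Illustrative sketch (editor's addition, not part of the published file). ---
# How the loop above carves `ncol` right-hand-side columns into blocks of
# `rhsblock` columns, one worker task per block. `_example_batch_bounds` is a
# hypothetical helper.
def _example_batch_bounds(ncol=10, rhsblock=4):
    bounds = [(start, min(start + rhsblock, ncol))
              for start in range(0, ncol, rhsblock)]
    return bounds   # ncol=10, rhsblock=4 -> [(0, 4), (4, 8), (8, 10)]
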
def _iterative_parallel(
    Amat: sparse.spmatrix,
    rhs: np.ndarray,
    solver_type: str,
    *,
    tol: float,
    maxiter: int,
    nthread: int,
    use_amg: bool = False,
    verbose: bool = False,
) -> Tuple[np.ndarray, int]:
    """
    Solve iterative methods (gmres, bicgstab, cg) in parallel.

    Each RHS column is solved independently in a separate process.
    """
    n, ncol = rhs.shape

    is_complex = np.iscomplexobj(Amat) or np.iscomplexobj(rhs)
    out_dtype = np.complex128 if is_complex else np.float64

    # Convert matrix to CSC for serialization
    Acsc = Amat.tocsc()
    A_data = Acsc.data
    A_indices = Acsc.indices
    A_indptr = Acsc.indptr
    A_shape = Acsc.shape
    A_dtype = Acsc.dtype

    # Prepare tasks - one per non-zero RHS column
    tasks = []
    for i in range(ncol):
        if np.any(rhs[:, i] != 0):
            rhs_col = np.ascontiguousarray(rhs[:, i])
            tasks.append(
                (
                    A_data,
                    A_indices,
                    A_indptr,
                    A_shape,
                    A_dtype,
                    rhs_col,
                    i,
                    solver_type,
                    tol,
                    maxiter,
                    use_amg,
                )
            )

    # Solve in parallel
    x = np.zeros((n, ncol), dtype=out_dtype)
    max_flag = 0

    with ProcessPoolExecutor(max_workers=nthread) as executor:
        results = executor.map(_solve_iterative_column, tasks)

        for col_idx, x_col, info in results:
            x[:, col_idx] = x_col
            max_flag = max(max_flag, info)

            if verbose:
                status = "converged" if info == 0 else f"flag={info}"
                print(f"{solver_type} [col {col_idx+1}]: {status}")

    return x, max_flag

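# --- Illustrative sketch (editor's addition, not part of the published file). ---
# Only right-hand-side columns with at least one non-zero entry become worker
# tasks above; all-zero columns simply keep x[:, i] == 0. A hypothetical
# standalone check of that filtering:
def _example_nonzero_columns():
    import numpy as np
    rhs = np.zeros((5, 3))
    rhs[0, 1] = 1.0                                   # only column 1 carries a source
    active = [i for i in range(rhs.shape[1]) if np.any(rhs[:, i] != 0)]
    return active                                     # -> [1]
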
# =============================================================================
# Main Solver Interface
# =============================================================================


def femsolve(
    Amat: sparse.spmatrix,
    rhs: Union[np.ndarray, sparse.spmatrix],
    method: str = "auto",
    **kwargs,
) -> Tuple[np.ndarray, int]:
    """
    Solve FEM linear system A*x = b with automatic solver selection.

    Parameters
    ----------
    Amat : sparse matrix
        System matrix
    rhs : ndarray or sparse matrix
        Right-hand side (n,) or (n, m) for m simultaneous RHS
    method : str
        'auto': Automatically select best solver
        'pardiso': Intel MKL PARDISO (fastest, requires pypardiso)
        'umfpack': UMFPACK (fast, requires scikit-umfpack)
        'cholmod': CHOLMOD for SPD matrices (requires scikit-sparse)
        'direct': Best available direct solver
        'superlu': SuperLU (always available)
        'blqmr': Block QMR iterative (good for complex symmetric, multiple RHS)
        'cg': Conjugate gradient (SPD only)
        'cg+amg': CG with AMG preconditioner (SPD, requires pyamg)
        'gmres': GMRES
        'bicgstab': BiCGSTAB
    **kwargs : dict
        tol : float - convergence tolerance (default: 1e-10)
        maxiter : int - maximum iterations (default: 1000)
        rhsblock : int - block size for blqmr (default: 8)
        nthread : int - parallel workers for iterative solvers
            (default: min(ncol, cpu_count), set to 1 to disable)
            Supported by: blqmr, gmres, bicgstab, cg, cg+amg
        verbose : bool - print solver progress (default: False)
        spd : bool - True if matrix is symmetric positive definite
        M, M1, M2 : preconditioners (disables parallel for gmres/bicgstab/cg)
        x0 : initial guess
        workspace : BLQMRWorkspace for blqmr
        precond_type : int - preconditioner type for blqmr (default: 3)
        droptol : float - drop tolerance for ILU preconditioner (default: 0.001)

    Returns
    -------
    x : ndarray
        Solution
    flag : int
        0 = success, >0 = solver-specific warning/error code
    """
    if sparse.issparse(rhs):
        rhs = rhs.toarray()

    rhs_was_1d = rhs.ndim == 1
    if rhs_was_1d:
        rhs = rhs.reshape(-1, 1)

    n, ncol = rhs.shape
    is_complex = np.iscomplexobj(Amat) or np.iscomplexobj(rhs)
    dtype = complex if is_complex else float
    is_spd = kwargs.get("spd", False)
    tol = kwargs.get("tol", 1e-10)
    maxiter = kwargs.get("maxiter", 1000)
    verbose = kwargs.get("verbose", False)

    x = np.zeros((n, ncol), dtype=dtype)
    flag = 0

    def get_direct_solver_for_matrix():
        """Get best direct solver for current matrix type."""
        # Pardiso is fastest and now supports complex (via real-valued formulation)
        if _DIRECT_SOLVER == "pardiso":
            return "pardiso"
        # For complex matrices without Pardiso, prefer UMFPACK
        if is_complex:
            return "umfpack" if _HAS_UMFPACK else "superlu"
        # For real matrices, use best available
        return _DIRECT_SOLVER

    # Auto-select solver
    if method == "auto":
        if n < 10000:
            method = "direct"
        elif is_spd and _HAS_AMG and not is_complex:
            method = "cg+amg"
        else:
            method = "direct"

    if method == "direct":
        method = get_direct_solver_for_matrix()

    if verbose:
        print(f"femsolve: method={method}, n={n}, ncol={ncol}, complex={is_complex}")

    # === DIRECT SOLVERS ===

    if method == "pardiso":
        if _DIRECT_SOLVER != "pardiso":
            warnings.warn("pypardiso not available, falling back")
            return femsolve(Amat, rhs, method="direct", **kwargs)

        if is_complex:
            # Convert complex system to real-valued form:
            #   [A_r  -A_i] [x_r]   [b_r]
            #   [A_i   A_r] [x_i] = [b_i]
            A_r = Amat.real
            A_i = Amat.imag

            # Build block matrix [A_r, -A_i; A_i, A_r]
            if sparse.issparse(Amat):
                A_real = sparse.bmat([[A_r, -A_i], [A_i, A_r]], format="csr")
            else:
                A_real = np.block([[A_r, -A_i], [A_i, A_r]])

            # Stack RHS: [b_r; b_i]
            rhs_real = np.vstack([rhs.real, rhs.imag])

            # Solve real system (batch solve for all RHS at once)
            x_real = _pardiso_solve(A_real, rhs_real)

            # Reconstruct complex solution: x = x_r + j*x_i
            x = x_real[:n, :] + 1j * x_real[n:, :]
        else:
            # Real matrix - batch solve all RHS at once
            Acsr = Amat.tocsr()
            x = _pardiso_solve(Acsr, rhs)

    elif method == "umfpack":
        if not _HAS_UMFPACK:
            warnings.warn("scikit-umfpack not available, falling back to superlu")
            return femsolve(Amat, rhs, method="superlu", **kwargs)

        Acsc = Amat.tocsc()
        # Use UMFPACK via scipy's spsolve (auto-selects UMFPACK when installed)
        # spsolve can handle matrix RHS directly
        if ncol > 1:
            # Check for zero columns and solve non-zero ones
            nonzero_cols = [i for i in range(ncol) if np.any(rhs[:, i] != 0)]
            if len(nonzero_cols) == ncol:
                # All columns non-zero - solve all at once
                x[:] = spsolve(Acsc, rhs)
            else:
                # Some zero columns - solve only non-zero ones
                for i in nonzero_cols:
                    x[:, i] = spsolve(Acsc, rhs[:, i])
        else:
            x[:, 0] = spsolve(Acsc, rhs[:, 0])

    elif method == "cholmod":
        if not _HAS_CHOLMOD:
            warnings.warn("scikit-sparse not available, falling back")
            return femsolve(Amat, rhs, method="direct", **kwargs)

        if is_complex:
            fallback = get_direct_solver_for_matrix()
            if verbose:
                print(f"cholmod doesn't support complex, using {fallback}")
            return femsolve(Amat, rhs, method=fallback, **kwargs)

        if not is_spd:
            warnings.warn("cholmod requires SPD matrix, falling back")
            return femsolve(Amat, rhs, method="direct", **kwargs)

        Acsc = Amat.tocsc()
        factor = _cholmod_cholesky(Acsc)
        for i in range(ncol):
            if np.any(rhs[:, i] != 0):
                x[:, i] = factor(rhs[:, i])

    elif method == "superlu":
        Acsc = Amat.tocsc()
        # For multiple RHS: factorize once, then solve
        # lu.solve() can handle 2D arrays directly
        if ncol > 1:
            try:
                lu = splu(Acsc)
                # Check for zero columns
                nonzero_cols = [i for i in range(ncol) if np.any(rhs[:, i] != 0)]
                if len(nonzero_cols) == ncol:
                    # All columns non-zero - solve all at once
                    x[:] = lu.solve(rhs)
                else:
                    # Some zero columns - solve only non-zero ones
                    if len(nonzero_cols) > 0:
                        rhs_nonzero = rhs[:, nonzero_cols]
                        x_nonzero = lu.solve(rhs_nonzero)
                        for idx, col in enumerate(nonzero_cols):
                            x[:, col] = (
                                x_nonzero[:, idx] if x_nonzero.ndim > 1 else x_nonzero
                            )
            except Exception:
                # Fallback to individual solves if batch fails
                for i in range(ncol):
                    if np.any(rhs[:, i] != 0):
                        x[:, i] = spsolve(Acsc, rhs[:, i])
        else:
            x[:, 0] = spsolve(Acsc, rhs[:, 0])

    # === ITERATIVE SOLVERS ===

    elif method == "blqmr":
        if not _HAS_BLQMR:
            warnings.warn("blocksolver not available, falling back to gmres")
            return femsolve(Amat, rhs, method="gmres", **kwargs)

        M1 = kwargs.get("M1", None)
        M2 = kwargs.get("M2", None)
        x0 = kwargs.get("x0", None)
        rhsblock = kwargs.get("rhsblock", 8)
        workspace = kwargs.get("workspace", None)
        precond_type = kwargs.get("precond_type", 3)
        droptol = kwargs.get("droptol", 0.001)
        nthread = kwargs.get("nthread", None)
        if nthread is None:
            nthread = min(ncol, multiprocessing.cpu_count())

        if rhsblock <= 0 or ncol <= rhsblock:
            # Single batch - no parallelization needed
            result = blqmr(
                Amat,
                rhs,
                tol=tol,
                maxiter=maxiter,
                M1=M1,
                M2=M2,
                x0=x0,
                workspace=workspace,
                precond_type=precond_type,
                droptol=droptol,
            )
            x = result.x if result.x.ndim > 1 else result.x.reshape(-1, 1)
            flag = result.flag
            if verbose:
                print(
                    f"blqmr: iter={result.iter}, relres={result.relres:.2e}, "
                    f"flag={flag}, BLQMR_EXT={BLQMR_EXT}"
                )
        elif nthread > 1:
            # Parallel batch solving using multiprocessing
            x, flag = _blqmr_parallel(
                Amat,
                rhs,
                tol=tol,
                maxiter=maxiter,
                rhsblock=rhsblock,
                precond_type=precond_type,
                droptol=droptol,
                nthread=nthread,
                verbose=verbose,
            )
        else:
            # Sequential batch solving
            max_flag = 0
            for start in range(0, ncol, rhsblock):
                end = min(start + rhsblock, ncol)
                rhs_batch = rhs[:, start:end]
                x0_batch = x0[:, start:end] if x0 is not None else None

                result = blqmr(
                    Amat,
                    rhs_batch,
                    tol=tol,
                    maxiter=maxiter,
                    M1=M1,
                    M2=M2,
                    x0=x0_batch,
                    workspace=workspace,
                    precond_type=precond_type,
                    droptol=droptol,
                )
                x_batch = result.x if result.x.ndim > 1 else result.x.reshape(-1, 1)
                x[:, start:end] = x_batch
                max_flag = max(max_flag, result.flag)

                if verbose:
                    print(
                        f"blqmr [{start+1}:{end}]: iter={result.iter}, relres={result.relres:.2e}"
                    )
            flag = max_flag

    elif method == "cg+amg":
        if not _HAS_AMG:
            warnings.warn("pyamg not available, falling back to CG")
            return femsolve(Amat, rhs, method="cg", **kwargs)

        if is_complex:
            warnings.warn("cg+amg doesn't support complex, falling back to gmres")
            return femsolve(Amat, rhs, method="gmres", **kwargs)

        nthread = kwargs.get("nthread", None)
        if nthread is None:
            nthread = min(ncol, multiprocessing.cpu_count())

        if nthread > 1 and ncol > 1:
            # Parallel solving
            x, flag = _iterative_parallel(
                Amat,
                rhs,
                "cg",
                tol=tol,
                maxiter=maxiter,
                nthread=nthread,
                use_amg=True,
                verbose=verbose,
            )
        else:
            # Sequential solving
            ml = _pyamg.smoothed_aggregation_solver(Amat.tocsr())
            M = ml.aspreconditioner()

            for i in range(ncol):
                if np.any(rhs[:, i] != 0):
                    try:
                        x[:, i], info = cg(
                            Amat, rhs[:, i], M=M, rtol=tol, maxiter=maxiter
                        )
                    except TypeError:
                        x[:, i], info = cg(
                            Amat, rhs[:, i], M=M, tol=tol, maxiter=maxiter
                        )
                    flag = max(flag, info)
                    if verbose:
                        status = "converged" if info == 0 else f"flag={info}"
                        print(f"cg+amg [col {i+1}]: {status}")

    elif method == "cg":
        if is_complex:
            warnings.warn("cg requires Hermitian matrix, falling back to gmres")
            return femsolve(Amat, rhs, method="gmres", **kwargs)

        nthread = kwargs.get("nthread", None)
        if nthread is None:
            nthread = min(ncol, multiprocessing.cpu_count())
        M = kwargs.get("M", None)

        if nthread > 1 and ncol > 1 and M is None:
            # Parallel solving (without custom preconditioner)
            x, flag = _iterative_parallel(
                Amat,
                rhs,
                "cg",
                tol=tol,
                maxiter=maxiter,
                nthread=nthread,
                use_amg=False,
                verbose=verbose,
            )
        else:
            # Sequential solving
            for i in range(ncol):
                if np.any(rhs[:, i] != 0):
                    try:
                        x[:, i], info = cg(
                            Amat, rhs[:, i], M=M, rtol=tol, maxiter=maxiter
                        )
                    except TypeError:
                        x[:, i], info = cg(
                            Amat, rhs[:, i], M=M, tol=tol, maxiter=maxiter
                        )
                    flag = max(flag, info)
                    if verbose:
                        status = "converged" if info == 0 else f"flag={info}"
                        print(f"cg [col {i+1}]: {status}")

    elif method == "gmres":
        nthread = kwargs.get("nthread", None)
        if nthread is None:
            nthread = min(ncol, multiprocessing.cpu_count())
        M = kwargs.get("M", None)

        if nthread > 1 and ncol > 1 and M is None:
            # Parallel solving (without custom preconditioner)
            x, flag = _iterative_parallel(
                Amat,
                rhs,
                "gmres",
                tol=tol,
                maxiter=maxiter,
                nthread=nthread,
                use_amg=False,
                verbose=verbose,
            )
        else:
            # Sequential solving
            for i in range(ncol):
                if np.any(rhs[:, i] != 0):
                    try:
                        x[:, i], info = gmres(
                            Amat, rhs[:, i], M=M, rtol=tol, maxiter=maxiter
                        )
                    except TypeError:
                        x[:, i], info = gmres(
                            Amat, rhs[:, i], M=M, tol=tol, maxiter=maxiter
                        )
                    flag = max(flag, info)
                    if verbose:
                        status = "converged" if info == 0 else f"flag={info}"
                        print(f"gmres [col {i+1}]: {status}")

    elif method == "bicgstab":
        nthread = kwargs.get("nthread", None)
        if nthread is None:
            nthread = min(ncol, multiprocessing.cpu_count())
        M = kwargs.get("M", None)

        if nthread > 1 and ncol > 1 and M is None:
            # Parallel solving (without custom preconditioner)
            x, flag = _iterative_parallel(
                Amat,
                rhs,
                "bicgstab",
                tol=tol,
                maxiter=maxiter,
                nthread=nthread,
                use_amg=False,
                verbose=verbose,
            )
        else:
            # Sequential solving
            for i in range(ncol):
                if np.any(rhs[:, i] != 0):
                    try:
                        x[:, i], info = bicgstab(
                            Amat, rhs[:, i], M=M, rtol=tol, maxiter=maxiter
                        )
                    except TypeError:
                        x[:, i], info = bicgstab(
                            Amat, rhs[:, i], M=M, tol=tol, maxiter=maxiter
                        )
                    flag = max(flag, info)
                    if verbose:
                        status = "converged" if info == 0 else f"flag={info}"
                        print(f"bicgstab [col {i+1}]: {status}")

    else:
        raise ValueError(f"Unknown solver: {method}")

    # Flatten output if input was 1D
    if rhs_was_1d:
        x = x.ravel()

    return x, flag

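# --- Illustrative usage sketch (editor's addition, not part of the published file). ---
# A small SPD test system with three simultaneous right-hand sides, solved with the
# defaults documented in the femsolve docstring above; `_example_femsolve_spd` is a
# hypothetical name, and the assert assumes the always-available SuperLU fallback or
# a better direct backend.
def _example_femsolve_spd():
    import numpy as np
    from scipy import sparse
    n = 200
    # 1-D Laplacian: symmetric positive definite, a typical FEM-style matrix
    A = sparse.diags([-1.0, 2.0, -1.0], [-1, 0, 1], shape=(n, n), format="csr")
    b = np.ones((n, 3))                               # three simultaneous RHS columns
    x, flag = femsolve(A, b, method="auto", spd=True, tol=1e-10, nthread=1)
    assert x.shape == (n, 3) and flag == 0
    return x
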
def get_solver_info() -> dict:
    """Return information about available solvers."""
    info = {
        "direct_solver": _DIRECT_SOLVER,
        "has_pardiso": _DIRECT_SOLVER == "pardiso",
        "has_umfpack": _HAS_UMFPACK,
        "has_cholmod": _HAS_CHOLMOD,
        "has_amg": _HAS_AMG,
        "has_blqmr": _HAS_BLQMR,
        "complex_direct": "umfpack" if _HAS_UMFPACK else "superlu",
        "complex_iterative": ["gmres", "bicgstab"],
        "cpu_count": multiprocessing.cpu_count(),
    }

    if _HAS_BLQMR:
        info["blqmr_backend"] = "fortran" if BLQMR_EXT else "native"
        info["blqmr_has_numba"] = HAS_NUMBA
        info["complex_iterative"].insert(0, "blqmr")

    return info
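
# --- Illustrative quick-start (editor's addition, not part of the published file). ---
# Ties together the two public entry points from the module docstring: query the
# detected backends with get_solver_info, then let femsolve pick a method. The
# diagonal complex test matrix and `_example_quickstart` name are illustrative only.
def _example_quickstart():
    import numpy as np
    from scipy import sparse

    info = get_solver_info()
    print("direct solver in use:", info["direct_solver"])

    # Small complex symmetric system, the case BLQMR targets when available
    n = 100
    A = sparse.diags([2.0 + 0.1j], [0], shape=(n, n), format="csc")
    b = np.ones((n, 2), dtype=complex)
    method = "blqmr" if info["has_blqmr"] else "auto"
    x, flag = femsolve(A, b, method=method, tol=1e-10, nthread=1)
    print("flag:", flag, "residual norm:", np.linalg.norm(A @ x - b))
    return x, flag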