redbirdpy-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
redbirdpy/solver.py ADDED
@@ -0,0 +1,814 @@
"""
Redbird Solver Module - Linear system solvers for FEM.

Provides:
    femsolve: Main solver interface with automatic method selection
    get_solver_info: Query available solver backends

Supported solvers:
    Direct: pardiso, umfpack, cholmod, superlu
    Iterative: blqmr, cg, cg+amg, gmres, bicgstab

Dependencies:
    - blocksolver: For BLQMR iterative solver (complex symmetric systems)
"""

__all__ = [
    "femsolve",
    "get_solver_info",
]

import numpy as np
from scipy import sparse
from scipy.sparse.linalg import spsolve, cg, gmres, bicgstab, splu
from typing import Dict, Tuple, Optional, Union, List, Any
import warnings
from concurrent.futures import ProcessPoolExecutor
import multiprocessing

# =============================================================================
# Solver Backend Detection
# =============================================================================

_DIRECT_SOLVER = "superlu"  # fallback
_HAS_UMFPACK = False
_HAS_CHOLMOD = False
_HAS_AMG = False
_HAS_BLQMR = False

# Pardiso references
_pardiso_solve = None
_pardiso_factorized = None

try:
    from pypardiso import spsolve as pardiso_solve, factorized as pardiso_factorized

    _pardiso_solve = pardiso_solve
    _pardiso_factorized = pardiso_factorized
    _DIRECT_SOLVER = "pardiso"
except ImportError:
    pass

try:
    from scikits.umfpack import spsolve as umfpack_spsolve, splu as umfpack_splu

    _HAS_UMFPACK = True
    if _DIRECT_SOLVER == "superlu":
        _DIRECT_SOLVER = "umfpack"
except ImportError:
    umfpack_spsolve = None
    umfpack_splu = None

try:
    from sksparse.cholmod import cholesky as _cholmod_cholesky

    _HAS_CHOLMOD = True
except ImportError:
    _cholmod_cholesky = None

try:
    import pyamg as _pyamg

    _HAS_AMG = True
except ImportError:
    _pyamg = None

# Import blocksolver for BLQMR
try:
    from blocksolver import (
        blqmr,
        BLQMRResult,
        BLQMR_EXT,
        HAS_NUMBA,
        make_preconditioner,
        BLQMRWorkspace,
    )

    _HAS_BLQMR = True
except ImportError:
    blqmr = None
    BLQMRResult = None
    BLQMR_EXT = False
    HAS_NUMBA = False
    make_preconditioner = None
    BLQMRWorkspace = None
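
# Note (illustrative): the same availability facts could be probed without
# importing the optional backends at all, e.g.
#
#     import importlib.util
#     has_pardiso = importlib.util.find_spec("pypardiso") is not None
#
# The try/except ImportError pattern above is used instead because it also
# confirms that the optional package imports cleanly, not merely that it is
# installed.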


# =============================================================================
# Parallel Solver Helper Functions
# =============================================================================


def _solve_blqmr_batch(args):
    """
    Worker function for parallel BLQMR solving.

    This function runs in a separate process and reconstructs the sparse
    matrix from its components before solving.
    """
    (
        A_data,
        A_indices,
        A_indptr,
        A_shape,
        A_dtype,
        rhs_batch,
        tol,
        maxiter,
        precond_type,
        droptol,
        batch_id,
        start_col,
    ) = args

    # Import blqmr in worker process
    from blocksolver import blqmr as worker_blqmr

    # Reconstruct sparse matrix in worker process
    A = sparse.csc_matrix((A_data, A_indices, A_indptr), shape=A_shape, dtype=A_dtype)

    result = worker_blqmr(
        A,
        rhs_batch,
        tol=tol,
        maxiter=maxiter,
        M1=None,
        M2=None,
        x0=None,
        workspace=None,
        precond_type=precond_type,
        droptol=droptol,
    )

    x_batch = result.x if result.x.ndim > 1 else result.x.reshape(-1, 1)
    return batch_id, start_col, x_batch, result.flag, result.iter, result.relres
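
# Note (illustrative sketch): shipping (data, indices, indptr, shape) across
# processes is sufficient, because the CSC triplet reconstructs the matrix
# exactly:
#
#     from scipy import sparse
#     A = sparse.random(50, 50, density=0.05, format="csc", random_state=0)
#     A2 = sparse.csc_matrix((A.data, A.indices, A.indptr), shape=A.shape)
#     assert (A != A2).nnz == 0  # identical pattern and values
#
# Keeping each task tuple limited to plain NumPy arrays and scalars makes it
# cheap and predictable to pickle for the worker processes.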


def _solve_iterative_column(args):
    """
    Worker function for parallel iterative solving (gmres, bicgstab, cg).

    Solves a single RHS column using the specified iterative method.
    """
    (
        A_data,
        A_indices,
        A_indptr,
        A_shape,
        A_dtype,
        rhs_col,
        col_idx,
        solver_type,
        tol,
        maxiter,
        use_amg,
    ) = args

    from scipy.sparse.linalg import gmres, bicgstab, cg
    from scipy import sparse

    # Reconstruct sparse matrix in worker process
    A = sparse.csc_matrix((A_data, A_indices, A_indptr), shape=A_shape, dtype=A_dtype)

    # Setup preconditioner if AMG requested
    M = None
    if use_amg and solver_type == "cg":
        try:
            import pyamg

            ml = pyamg.smoothed_aggregation_solver(A.tocsr())
            M = ml.aspreconditioner()
        except ImportError:
            pass

    # Select solver
    if solver_type == "gmres":
        solver_func = gmres
    elif solver_type == "bicgstab":
        solver_func = bicgstab
    elif solver_type == "cg":
        solver_func = cg
    else:
        raise ValueError(f"Unknown solver type: {solver_type}")

    # Solve
    try:
        x_col, info = solver_func(A, rhs_col, M=M, rtol=tol, maxiter=maxiter)
    except TypeError:
        # Older scipy versions use 'tol' instead of 'rtol'
        x_col, info = solver_func(A, rhs_col, M=M, tol=tol, maxiter=maxiter)

    return col_idx, x_col, info
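
# Possible refactor (illustrative only; `_call_krylov` is not part of this
# module): the rtol/tol compatibility shim above could be shared by all of the
# iterative branches in femsolve:
#
#     def _call_krylov(solver_func, A, b, M=None, tol=1e-10, maxiter=1000):
#         try:
#             return solver_func(A, b, M=M, rtol=tol, maxiter=maxiter)
#         except TypeError:  # older SciPy releases only accept 'tol'
#             return solver_func(A, b, M=M, tol=tol, maxiter=maxiter)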


def _blqmr_parallel(
    Amat: sparse.spmatrix,
    rhs: np.ndarray,
    *,
    tol: float,
    maxiter: int,
    rhsblock: int,
    precond_type: int,
    droptol: float,
    nthread: int,
    verbose: bool,
) -> Tuple[np.ndarray, int]:
    """
    Solve BLQMR with multiple RHS in parallel using multiprocessing.
    """
    n, ncol = rhs.shape

    # Determine output dtype based on matrix and RHS
    is_complex = np.iscomplexobj(Amat) or np.iscomplexobj(rhs)
    out_dtype = np.complex128 if is_complex else np.float64

    # Convert matrix to CSC for consistent serialization
    Acsc = Amat.tocsc()
    A_data = Acsc.data
    A_indices = Acsc.indices
    A_indptr = Acsc.indptr
    A_shape = Acsc.shape
    A_dtype = Acsc.dtype

    # Prepare batches
    batches = []
    for batch_id, start in enumerate(range(0, ncol, rhsblock)):
        end = min(start + rhsblock, ncol)
        rhs_batch = np.ascontiguousarray(rhs[:, start:end])
        batches.append(
            (
                A_data,
                A_indices,
                A_indptr,
                A_shape,
                A_dtype,
                rhs_batch,
                tol,
                maxiter,
                precond_type,
                droptol,
                batch_id,
                start,
            )
        )

    # Solve in parallel
    x = np.zeros((n, ncol), dtype=out_dtype)
    max_flag = 0

    with ProcessPoolExecutor(max_workers=nthread) as executor:
        results = executor.map(_solve_blqmr_batch, batches)

        for batch_id, start_col, x_batch, batch_flag, niter, relres in results:
            end_col = start_col + x_batch.shape[1]
            if is_complex and not np.iscomplexobj(x_batch):
                x[:, start_col:end_col] = x_batch.astype(out_dtype)
            else:
                x[:, start_col:end_col] = x_batch
            max_flag = max(max_flag, batch_flag)

            if verbose:
                print(
                    f"blqmr [{start_col+1}:{end_col}] (worker {batch_id}): "
                    f"iter={niter}, relres={relres:.2e}, flag={batch_flag}"
                )

    return x, max_flag
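
# Note (illustrative): the batching above simply tiles the RHS columns in
# chunks of `rhsblock`; for example ncol=10 and rhsblock=4 give the column
# ranges [0:4], [4:8], [8:10]:
#
#     >>> [(s, min(s + 4, 10)) for s in range(0, 10, 4)]
#     [(0, 4), (4, 8), (8, 10)]
#
# Each chunk is one worker task, so at most ceil(ncol / rhsblock) of the
# `nthread` workers can ever be busy at once.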


def _iterative_parallel(
    Amat: sparse.spmatrix,
    rhs: np.ndarray,
    solver_type: str,
    *,
    tol: float,
    maxiter: int,
    nthread: int,
    use_amg: bool = False,
    verbose: bool = False,
) -> Tuple[np.ndarray, int]:
    """
    Solve with iterative methods (gmres, bicgstab, cg) in parallel.

    Each RHS column is solved independently in a separate process.
    """
    n, ncol = rhs.shape

    is_complex = np.iscomplexobj(Amat) or np.iscomplexobj(rhs)
    out_dtype = np.complex128 if is_complex else np.float64

    # Convert matrix to CSC for serialization
    Acsc = Amat.tocsc()
    A_data = Acsc.data
    A_indices = Acsc.indices
    A_indptr = Acsc.indptr
    A_shape = Acsc.shape
    A_dtype = Acsc.dtype

    # Prepare tasks - one per non-zero RHS column
    tasks = []
    for i in range(ncol):
        if np.any(rhs[:, i] != 0):
            rhs_col = np.ascontiguousarray(rhs[:, i])
            tasks.append(
                (
                    A_data,
                    A_indices,
                    A_indptr,
                    A_shape,
                    A_dtype,
                    rhs_col,
                    i,
                    solver_type,
                    tol,
                    maxiter,
                    use_amg,
                )
            )

    # Solve in parallel
    x = np.zeros((n, ncol), dtype=out_dtype)
    max_flag = 0

    with ProcessPoolExecutor(max_workers=nthread) as executor:
        results = executor.map(_solve_iterative_column, tasks)

        for col_idx, x_col, info in results:
            x[:, col_idx] = x_col
            max_flag = max(max_flag, info)

            if verbose:
                status = "converged" if info == 0 else f"flag={info}"
                print(f"{solver_type} [col {col_idx+1}]: {status}")

    return x, max_flag
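
# Note (illustrative): columns of `rhs` that are identically zero are never
# submitted as tasks; their entries in `x` stay at the zeros written by
# np.zeros((n, ncol)), which is already the exact solution of A @ x = 0 for a
# nonsingular A. The scatter of results back into `x` relies only on the
# col_idx returned by each worker:
#
#     with ProcessPoolExecutor(max_workers=4) as executor:
#         for col_idx, x_col, info in executor.map(_solve_iterative_column, tasks):
#             x[:, col_idx] = x_col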


# =============================================================================
# Main Solver Interface
# =============================================================================


def femsolve(
    Amat: sparse.spmatrix,
    rhs: Union[np.ndarray, sparse.spmatrix],
    method: str = "auto",
    **kwargs,
) -> Tuple[np.ndarray, int]:
    """
    Solve FEM linear system A*x = b with automatic solver selection.

    Parameters
    ----------
    Amat : sparse matrix
        System matrix
    rhs : ndarray or sparse matrix
        Right-hand side (n,) or (n, m) for m simultaneous RHS
    method : str
        'auto': Automatically select best solver
        'pardiso': Intel MKL PARDISO (fastest, requires pypardiso)
        'umfpack': UMFPACK (fast, requires scikit-umfpack)
        'cholmod': CHOLMOD for SPD matrices (requires scikit-sparse)
        'direct': Best available direct solver
        'superlu': SuperLU (always available)
        'blqmr': Block QMR iterative (good for complex symmetric, multiple RHS)
        'cg': Conjugate gradient (SPD only)
        'cg+amg': CG with AMG preconditioner (SPD, requires pyamg)
        'gmres': GMRES
        'bicgstab': BiCGSTAB
    **kwargs : dict
        tol : float - convergence tolerance (default: 1e-10)
        maxiter : int - maximum iterations (default: 1000)
        rhsblock : int - block size for blqmr (default: 8)
        nthread : int - parallel workers for iterative solvers
            (default: min(ncol, cpu_count), set to 1 to disable)
            Supported by: blqmr, gmres, bicgstab, cg, cg+amg
        verbose : bool - print solver progress (default: False)
        spd : bool - True if matrix is symmetric positive definite
        M, M1, M2 : preconditioners (disables parallel for gmres/bicgstab/cg)
        x0 : initial guess
        workspace : BLQMRWorkspace for blqmr
        precond_type : int - preconditioner type for blqmr (default: 3, automatic)
        droptol : float - drop tolerance for ILU preconditioner (default: 0.001)

    Returns
    -------
    x : ndarray
        Solution
    flag : int
        0 = success, >0 = solver-specific warning/error code
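
    Examples
    --------
    Minimal sketch (uses only the always-available SuperLU backend):

    >>> import numpy as np
    >>> from scipy import sparse
    >>> A = sparse.diags([2.0, -1.0], [0, 1], shape=(5, 5), format="csr")
    >>> b = np.ones(5)
    >>> x, flag = femsolve(A, b, method="superlu")
    >>> flag
    0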
    """
    if sparse.issparse(rhs):
        rhs = rhs.toarray()

    rhs_was_1d = rhs.ndim == 1
    if rhs_was_1d:
        rhs = rhs.reshape(-1, 1)

    n, ncol = rhs.shape
    is_complex = np.iscomplexobj(Amat) or np.iscomplexobj(rhs)
    dtype = complex if is_complex else float
    is_spd = kwargs.get("spd", False)
    tol = kwargs.get("tol", 1e-10)
    maxiter = kwargs.get("maxiter", 1000)
    verbose = kwargs.get("verbose", False)

    x = np.zeros((n, ncol), dtype=dtype)
    flag = 0

    def get_direct_solver_for_matrix():
        """Get best direct solver for current matrix type."""
        # Pardiso is fastest and now supports complex (via real-valued formulation)
        if _DIRECT_SOLVER == "pardiso":
            return "pardiso"
        # For complex matrices without Pardiso, prefer UMFPACK
        if is_complex:
            return "umfpack" if _HAS_UMFPACK else "superlu"
        # For real matrices, use best available
        return _DIRECT_SOLVER

    # Auto-select solver
    if method == "auto":
        if n < 10000:
            method = "direct"
        elif is_spd and _HAS_AMG and not is_complex:
            method = "cg+amg"
        else:
            method = "direct"

    if method == "direct":
        method = get_direct_solver_for_matrix()

    if verbose:
        print(f"femsolve: method={method}, n={n}, ncol={ncol}, complex={is_complex}")

    # === DIRECT SOLVERS ===

    if method == "pardiso":
        if _DIRECT_SOLVER != "pardiso":
            warnings.warn("pypardiso not available, falling back")
            return femsolve(Amat, rhs, method="direct", **kwargs)

        if is_complex:
            # Convert complex system to real-valued form:
            #     [A_r  -A_i] [x_r]   [b_r]
            #     [A_i   A_r] [x_i] = [b_i]
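            # Worked check: for the 1x1 system (2+3j) * x = 1+1j the real form is
            #     [ 2  -3 ] [x_r]   [1]
            #     [ 3   2 ] [x_i] = [1]
            # with solution x_r = 5/13, x_i = -1/13, and indeed
            # (2+3j) * (5/13 - 1j/13) = 1 + 1j, matching the complex solve.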
            A_r = Amat.real
            A_i = Amat.imag

            # Build block matrix [A_r, -A_i; A_i, A_r]
            if sparse.issparse(Amat):
                A_real = sparse.bmat([[A_r, -A_i], [A_i, A_r]], format="csr")
            else:
                A_real = np.block([[A_r, -A_i], [A_i, A_r]])

            # Stack RHS: [b_r; b_i]
            rhs_real = np.vstack([rhs.real, rhs.imag])

            # Solve real system (batch solve for all RHS at once)
            x_real = _pardiso_solve(A_real, rhs_real)

            # Reconstruct complex solution: x = x_r + j*x_i
            x = x_real[:n, :] + 1j * x_real[n:, :]
        else:
            # Real matrix - batch solve all RHS at once
            Acsr = Amat.tocsr()
            x = _pardiso_solve(Acsr, rhs)

    elif method == "umfpack":
        if not _HAS_UMFPACK:
            warnings.warn("scikit-umfpack not available, falling back to superlu")
            return femsolve(Amat, rhs, method="superlu", **kwargs)

        Acsc = Amat.tocsc()
        # Use UMFPACK via scipy's spsolve (auto-selects UMFPACK when installed)
        # spsolve can handle matrix RHS directly
        if ncol > 1:
            # Check for zero columns and solve non-zero ones
            nonzero_cols = [i for i in range(ncol) if np.any(rhs[:, i] != 0)]
            if len(nonzero_cols) == ncol:
                # All columns non-zero - solve all at once
                x[:] = spsolve(Acsc, rhs)
            else:
                # Some zero columns - solve only non-zero ones
                for i in nonzero_cols:
                    x[:, i] = spsolve(Acsc, rhs[:, i])
        else:
            x[:, 0] = spsolve(Acsc, rhs[:, 0])

    elif method == "cholmod":
        if not _HAS_CHOLMOD:
            warnings.warn("scikit-sparse not available, falling back")
            return femsolve(Amat, rhs, method="direct", **kwargs)

        if is_complex:
            fallback = get_direct_solver_for_matrix()
            if verbose:
                print(f"cholmod doesn't support complex, using {fallback}")
            return femsolve(Amat, rhs, method=fallback, **kwargs)

        if not is_spd:
            warnings.warn("cholmod requires SPD matrix, falling back")
            return femsolve(Amat, rhs, method="direct", **kwargs)

        Acsc = Amat.tocsc()
        factor = _cholmod_cholesky(Acsc)
        for i in range(ncol):
            if np.any(rhs[:, i] != 0):
                x[:, i] = factor(rhs[:, i])

    elif method == "superlu":
        Acsc = Amat.tocsc()
        # For multiple RHS: factorize once, then solve
        # lu.solve() can handle 2D arrays directly
        if ncol > 1:
            try:
                lu = splu(Acsc)
                # Check for zero columns
                nonzero_cols = [i for i in range(ncol) if np.any(rhs[:, i] != 0)]
                if len(nonzero_cols) == ncol:
                    # All columns non-zero - solve all at once
                    x[:] = lu.solve(rhs)
                else:
                    # Some zero columns - solve only non-zero ones
                    if len(nonzero_cols) > 0:
                        rhs_nonzero = rhs[:, nonzero_cols]
                        x_nonzero = lu.solve(rhs_nonzero)
                        for idx, col in enumerate(nonzero_cols):
                            x[:, col] = (
                                x_nonzero[:, idx] if x_nonzero.ndim > 1 else x_nonzero
                            )
            except Exception:
                # Fallback to individual solves if batch fails
                for i in range(ncol):
                    if np.any(rhs[:, i] != 0):
                        x[:, i] = spsolve(Acsc, rhs[:, i])
        else:
            x[:, 0] = spsolve(Acsc, rhs[:, 0])

    # === ITERATIVE SOLVERS ===

    elif method == "blqmr":
        if not _HAS_BLQMR:
            warnings.warn("blocksolver not available, falling back to gmres")
            return femsolve(Amat, rhs, method="gmres", **kwargs)

        M1 = kwargs.get("M1", None)
        M2 = kwargs.get("M2", None)
        x0 = kwargs.get("x0", None)
        rhsblock = kwargs.get("rhsblock", 8)
        workspace = kwargs.get("workspace", None)
        precond_type = kwargs.get("precond_type", 3)
        droptol = kwargs.get("droptol", 0.001)
        nthread = kwargs.get("nthread", None)
        if nthread is None:
            nthread = min(ncol, multiprocessing.cpu_count())

        if rhsblock <= 0 or ncol <= rhsblock:
            # Single batch - no parallelization needed
            result = blqmr(
                Amat,
                rhs,
                tol=tol,
                maxiter=maxiter,
                M1=M1,
                M2=M2,
                x0=x0,
                workspace=workspace,
                precond_type=precond_type,
                droptol=droptol,
            )
            x = result.x if result.x.ndim > 1 else result.x.reshape(-1, 1)
            flag = result.flag
            if verbose:
                print(
                    f"blqmr: iter={result.iter}, relres={result.relres:.2e}, "
                    f"flag={flag}, BLQMR_EXT={BLQMR_EXT}"
                )
        elif nthread > 1:
            # Parallel batch solving using multiprocessing
            x, flag = _blqmr_parallel(
                Amat,
                rhs,
                tol=tol,
                maxiter=maxiter,
                rhsblock=rhsblock,
                precond_type=precond_type,
                droptol=droptol,
                nthread=nthread,
                verbose=verbose,
            )
        else:
            # Sequential batch solving
            max_flag = 0
            for start in range(0, ncol, rhsblock):
                end = min(start + rhsblock, ncol)
                rhs_batch = rhs[:, start:end]
                x0_batch = x0[:, start:end] if x0 is not None else None

                result = blqmr(
                    Amat,
                    rhs_batch,
                    tol=tol,
                    maxiter=maxiter,
                    M1=M1,
                    M2=M2,
                    x0=x0_batch,
                    workspace=workspace,
                    precond_type=precond_type,
                    droptol=droptol,
                )
                x_batch = result.x if result.x.ndim > 1 else result.x.reshape(-1, 1)
                x[:, start:end] = x_batch
                max_flag = max(max_flag, result.flag)

                if verbose:
                    print(
                        f"blqmr [{start+1}:{end}]: iter={result.iter}, relres={result.relres:.2e}"
                    )
            flag = max_flag

    elif method == "cg+amg":
        if not _HAS_AMG:
            warnings.warn("pyamg not available, falling back to CG")
            return femsolve(Amat, rhs, method="cg", **kwargs)

        if is_complex:
            warnings.warn("cg+amg doesn't support complex, falling back to gmres")
            return femsolve(Amat, rhs, method="gmres", **kwargs)

        nthread = kwargs.get("nthread", None)
        if nthread is None:
            nthread = min(ncol, multiprocessing.cpu_count())

        if nthread > 1 and ncol > 1:
            # Parallel solving
            x, flag = _iterative_parallel(
                Amat,
                rhs,
                "cg",
                tol=tol,
                maxiter=maxiter,
                nthread=nthread,
                use_amg=True,
                verbose=verbose,
            )
        else:
            # Sequential solving
            ml = _pyamg.smoothed_aggregation_solver(Amat.tocsr())
            M = ml.aspreconditioner()

            for i in range(ncol):
                if np.any(rhs[:, i] != 0):
                    try:
                        x[:, i], info = cg(
                            Amat, rhs[:, i], M=M, rtol=tol, maxiter=maxiter
                        )
                    except TypeError:
                        x[:, i], info = cg(
                            Amat, rhs[:, i], M=M, tol=tol, maxiter=maxiter
                        )
                    flag = max(flag, info)
                    if verbose:
                        status = "converged" if info == 0 else f"flag={info}"
                        print(f"cg+amg [col {i+1}]: {status}")

    elif method == "cg":
        if is_complex:
            warnings.warn("cg requires Hermitian matrix, falling back to gmres")
            return femsolve(Amat, rhs, method="gmres", **kwargs)

        nthread = kwargs.get("nthread", None)
        if nthread is None:
            nthread = min(ncol, multiprocessing.cpu_count())
        M = kwargs.get("M", None)

        if nthread > 1 and ncol > 1 and M is None:
            # Parallel solving (without custom preconditioner)
            x, flag = _iterative_parallel(
                Amat,
                rhs,
                "cg",
                tol=tol,
                maxiter=maxiter,
                nthread=nthread,
                use_amg=False,
                verbose=verbose,
            )
        else:
            # Sequential solving
            for i in range(ncol):
                if np.any(rhs[:, i] != 0):
                    try:
                        x[:, i], info = cg(
                            Amat, rhs[:, i], M=M, rtol=tol, maxiter=maxiter
                        )
                    except TypeError:
                        x[:, i], info = cg(
                            Amat, rhs[:, i], M=M, tol=tol, maxiter=maxiter
                        )
                    flag = max(flag, info)
                    if verbose:
                        status = "converged" if info == 0 else f"flag={info}"
                        print(f"cg [col {i+1}]: {status}")

    elif method == "gmres":
        nthread = kwargs.get("nthread", None)
        if nthread is None:
            nthread = min(ncol, multiprocessing.cpu_count())
        M = kwargs.get("M", None)

        if nthread > 1 and ncol > 1 and M is None:
            # Parallel solving (without custom preconditioner)
            x, flag = _iterative_parallel(
                Amat,
                rhs,
                "gmres",
                tol=tol,
                maxiter=maxiter,
                nthread=nthread,
                use_amg=False,
                verbose=verbose,
            )
        else:
            # Sequential solving
            for i in range(ncol):
                if np.any(rhs[:, i] != 0):
                    try:
                        x[:, i], info = gmres(
                            Amat, rhs[:, i], M=M, rtol=tol, maxiter=maxiter
                        )
                    except TypeError:
                        x[:, i], info = gmres(
                            Amat, rhs[:, i], M=M, tol=tol, maxiter=maxiter
                        )
                    flag = max(flag, info)
                    if verbose:
                        status = "converged" if info == 0 else f"flag={info}"
                        print(f"gmres [col {i+1}]: {status}")

    elif method == "bicgstab":
        nthread = kwargs.get("nthread", None)
        if nthread is None:
            nthread = min(ncol, multiprocessing.cpu_count())
        M = kwargs.get("M", None)

        if nthread > 1 and ncol > 1 and M is None:
            # Parallel solving (without custom preconditioner)
            x, flag = _iterative_parallel(
                Amat,
                rhs,
                "bicgstab",
                tol=tol,
                maxiter=maxiter,
                nthread=nthread,
                use_amg=False,
                verbose=verbose,
            )
        else:
            # Sequential solving
            for i in range(ncol):
                if np.any(rhs[:, i] != 0):
                    try:
                        x[:, i], info = bicgstab(
                            Amat, rhs[:, i], M=M, rtol=tol, maxiter=maxiter
                        )
                    except TypeError:
                        x[:, i], info = bicgstab(
                            Amat, rhs[:, i], M=M, tol=tol, maxiter=maxiter
                        )
                    flag = max(flag, info)
                    if verbose:
                        status = "converged" if info == 0 else f"flag={info}"
                        print(f"bicgstab [col {i+1}]: {status}")

    else:
        raise ValueError(f"Unknown solver: {method}")

    # Flatten output if input was 1D
    if rhs_was_1d:
        x = x.ravel()

    return x, flag
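
# Illustrative sketch (assumes the optional blocksolver backend; without it the
# call below transparently falls back to gmres): many right-hand sides can be
# passed as one (n, m) array and are batched or parallelized internally.
# (Reusing A and n from the quick-start sketch near the top of this file.)
#
#     B = np.random.rand(n, 16)                     # 16 simultaneous RHS
#     X, flag = femsolve(A, B, method="blqmr", rhsblock=8, nthread=2, tol=1e-8)
#     assert X.shape == (n, 16)
#
# With rhsblock=8 the 16 columns form two batches, each handled by one of the
# two workers.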


def get_solver_info() -> dict:
    """Return information about available solvers."""
    info = {
        "direct_solver": _DIRECT_SOLVER,
        "has_pardiso": _DIRECT_SOLVER == "pardiso",
        "has_umfpack": _HAS_UMFPACK,
        "has_cholmod": _HAS_CHOLMOD,
        "has_amg": _HAS_AMG,
        "has_blqmr": _HAS_BLQMR,
        "complex_direct": "umfpack" if _HAS_UMFPACK else "superlu",
        "complex_iterative": ["gmres", "bicgstab"],
        "cpu_count": multiprocessing.cpu_count(),
    }

    if _HAS_BLQMR:
        info["blqmr_backend"] = "fortran" if BLQMR_EXT else "native"
        info["blqmr_has_numba"] = HAS_NUMBA
        info["complex_iterative"].insert(0, "blqmr")

    return info
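

# Illustrative sketch: get_solver_info() can drive an explicit method choice
# instead of method="auto", e.g. for a complex symmetric system:
#
#     info = get_solver_info()
#     method = "blqmr" if info["has_blqmr"] else "gmres"
#     x, flag = femsolve(Amat, rhs, method=method, tol=1e-8, verbose=True)
#
# Either choice degrades gracefully inside femsolve if the requested backend
# turns out to be unavailable at call time.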