fluxfem 0.1.4 (fluxfem-0.1.4-py3-none-any.whl) → 0.2.1 (fluxfem-0.2.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53):
  1. fluxfem/__init__.py +69 -13
  2. fluxfem/core/__init__.py +140 -53
  3. fluxfem/core/assembly.py +691 -97
  4. fluxfem/core/basis.py +75 -54
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/dtypes.py +9 -1
  7. fluxfem/core/forms.py +10 -0
  8. fluxfem/core/mixed_assembly.py +263 -0
  9. fluxfem/core/mixed_space.py +382 -0
  10. fluxfem/core/mixed_weakform.py +97 -0
  11. fluxfem/core/solver.py +2 -0
  12. fluxfem/core/space.py +315 -30
  13. fluxfem/core/weakform.py +821 -42
  14. fluxfem/helpers_wf.py +49 -0
  15. fluxfem/mesh/__init__.py +54 -2
  16. fluxfem/mesh/base.py +318 -9
  17. fluxfem/mesh/contact.py +841 -0
  18. fluxfem/mesh/dtypes.py +12 -0
  19. fluxfem/mesh/hex.py +17 -16
  20. fluxfem/mesh/io.py +9 -6
  21. fluxfem/mesh/mortar.py +3970 -0
  22. fluxfem/mesh/supermesh.py +318 -0
  23. fluxfem/mesh/surface.py +104 -26
  24. fluxfem/mesh/tet.py +16 -7
  25. fluxfem/physics/diffusion.py +3 -0
  26. fluxfem/physics/elasticity/hyperelastic.py +35 -3
  27. fluxfem/physics/elasticity/linear.py +22 -4
  28. fluxfem/physics/elasticity/stress.py +9 -5
  29. fluxfem/physics/operators.py +12 -5
  30. fluxfem/physics/postprocess.py +29 -3
  31. fluxfem/solver/__init__.py +47 -2
  32. fluxfem/solver/bc.py +38 -2
  33. fluxfem/solver/block_matrix.py +284 -0
  34. fluxfem/solver/block_system.py +477 -0
  35. fluxfem/solver/cg.py +150 -55
  36. fluxfem/solver/dirichlet.py +358 -5
  37. fluxfem/solver/history.py +15 -3
  38. fluxfem/solver/newton.py +260 -70
  39. fluxfem/solver/petsc.py +445 -0
  40. fluxfem/solver/preconditioner.py +109 -0
  41. fluxfem/solver/result.py +18 -0
  42. fluxfem/solver/solve_runner.py +208 -23
  43. fluxfem/solver/solver.py +35 -12
  44. fluxfem/solver/sparse.py +149 -15
  45. fluxfem/tools/jit.py +19 -7
  46. fluxfem/tools/timer.py +14 -12
  47. fluxfem/tools/visualizer.py +16 -4
  48. fluxfem-0.2.1.dist-info/METADATA +314 -0
  49. fluxfem-0.2.1.dist-info/RECORD +59 -0
  50. fluxfem-0.1.4.dist-info/METADATA +0 -127
  51. fluxfem-0.1.4.dist-info/RECORD +0 -48
  52. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  53. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
@@ -2,21 +2,35 @@ from __future__ import annotations
2
2
 
3
3
  import time
4
4
  from dataclasses import dataclass, field
5
- from typing import Any, Callable, Iterable, List, Sequence
5
+ import warnings
6
+ from typing import Any, Callable, Iterable, List, Sequence, TYPE_CHECKING, TypeAlias
6
7
 
7
8
  import numpy as np
8
9
  import jax.numpy as jnp
9
10
 
10
- from ..core.assembly import assemble_bilinear_form
11
+ from ..core.assembly import FormKernel, ResidualForm, assemble_bilinear_form
11
12
  from ..core.solver import spdirect_solve_cpu, spdirect_solve_gpu
12
13
  from .cg import cg_solve, cg_solve_jax
14
+ from .petsc import petsc_shell_solve
13
15
  from .sparse import FluxSparseMatrix
14
- from .dirichlet import expand_dirichlet_solution
16
+ from .dirichlet import DirichletBC, expand_dirichlet_solution
15
17
  from .newton import newton_solve
16
18
  from .result import SolverResult
17
19
  from .history import NewtonIterRecord, LoadStepResult
18
20
  from ..tools.timer import SectionTimer, NullTimer
19
21
 
22
+ if TYPE_CHECKING:
23
+ from jax import Array as JaxArray
24
+
25
+ ArrayLike: TypeAlias = np.ndarray | JaxArray
26
+ else:
27
+ ArrayLike: TypeAlias = np.ndarray
28
+ DirichletLike: TypeAlias = tuple[np.ndarray, np.ndarray]
29
+ ExtraTerm: TypeAlias = Callable[
30
+ [np.ndarray],
31
+ tuple[np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray, dict[str, Any]] | None,
32
+ ]
33
+
20
34
 
21
35
  @dataclass
22
36
  class NonlinearAnalysis:
@@ -35,6 +49,8 @@ class NonlinearAnalysis:
35
49
  Unscaled external load vector (scaled by load factor in `external_for_load`).
36
50
  dirichlet : tuple | None
37
51
  (dofs, values) for Dirichlet boundary conditions.
52
+ extra_terms : list[callable] | None
53
+ Optional extra term assemblers returning (K, f[, metrics]).
38
54
  jacobian_pattern : Any | None
39
55
  Optional sparsity pattern to reuse between load steps.
40
56
  dtype : Any
@@ -42,14 +58,19 @@ class NonlinearAnalysis:
42
58
  """
43
59
 
44
60
  space: Any
45
- residual_form: Any
61
+ residual_form: ResidualForm[Any]
46
62
  params: Any
47
- base_external_vector: Any | None = None
48
- dirichlet: tuple | None = None
63
+ base_external_vector: ArrayLike | None = None
64
+ dirichlet: DirichletLike | None = None
65
+ extra_terms: list[ExtraTerm] | None = None
49
66
  jacobian_pattern: Any | None = None
50
67
  dtype: Any = jnp.float64
51
68
 
52
- def external_for_load(self, load_factor: float):
69
+ def __post_init__(self) -> None:
70
+ if isinstance(self.dirichlet, DirichletBC):
71
+ self.dirichlet = self.dirichlet.as_tuple()
72
+
73
+ def external_for_load(self, load_factor: float) -> ArrayLike | None:
53
74
  if self.base_external_vector is None:
54
75
  return None
55
76
  return jnp.asarray(load_factor * self.base_external_vector, dtype=self.dtype)
@@ -71,6 +92,7 @@ class NewtonLoopConfig:
71
92
  linear_maxiter: int | None = None
72
93
  linear_tol: float | None = None
73
94
  linear_preconditioner: Any | None = None
95
+ matfree_mode: str = "linearize"
74
96
  load_sequence: Sequence[float] | None = None
75
97
  n_steps: int = 1
76
98
 
@@ -95,17 +117,18 @@ class NewtonSolveRunner:
95
117
  def __init__(self, analysis: NonlinearAnalysis, config: NewtonLoopConfig):
96
118
  self.analysis = analysis
97
119
  self.config = config
120
+ self._matfree_cache: dict[str, Any] = {}
98
121
 
99
122
  def run(
100
123
  self,
101
- u0=None,
124
+ u0: ArrayLike | None = None,
102
125
  *,
103
126
  load_sequence: Sequence[float] | None = None,
104
- newton_callback: Callable | None = None,
127
+ newton_callback: Callable[[dict[str, Any]], None] | None = None,
105
128
  step_callback: Callable[[LoadStepResult], None] | None = None,
106
129
  timer: "SectionTimer | None" = None,
107
130
  report_timing: bool = True
108
- ):
131
+ ) -> tuple[np.ndarray, list[LoadStepResult]]:
109
132
  """
110
133
  Execute Newton solves over the configured load schedule.
111
134
 
@@ -154,6 +177,16 @@ class NewtonSolveRunner:
154
177
  schedule.append(lf_clamped)
155
178
  prev = lf_clamped
156
179
  history: List[LoadStepResult] = []
180
+ matfree_cache = None
181
+ if self.config.linear_preconditioner == "diag0":
182
+ n_free = self.analysis.space.n_dofs
183
+ if self.analysis.dirichlet is not None:
184
+ n_free -= len(self.analysis.dirichlet[0])
185
+ cached_free = self._matfree_cache.get("n_free_dofs")
186
+ if cached_free is not None and cached_free != n_free:
187
+ self._matfree_cache.clear()
188
+ self._matfree_cache["n_free_dofs"] = n_free
189
+ matfree_cache = self._matfree_cache
157
190
  for step_i, load_factor in enumerate(schedule, start=1):
158
191
  with timer.section("step"):
159
192
  external = self.analysis.external_for_load(load_factor)
@@ -203,6 +236,8 @@ class NewtonSolveRunner:
203
236
  linear_maxiter=self.config.linear_maxiter,
204
237
  linear_tol=self.config.linear_tol,
205
238
  linear_preconditioner=self.config.linear_preconditioner,
239
+ matfree_cache=matfree_cache,
240
+ matfree_mode=self.config.matfree_mode,
206
241
  dirichlet=self.analysis.dirichlet,
207
242
  line_search=self.config.line_search,
208
243
  max_ls=self.config.max_ls,
@@ -210,6 +245,7 @@ class NewtonSolveRunner:
210
245
  external_vector=external,
211
246
  callback=cb,
212
247
  jacobian_pattern=self.analysis.jacobian_pattern,
248
+ extra_terms=self.analysis.extra_terms,
213
249
  )
214
250
  exception = None
215
251
  except Exception as e: # pragma: no cover - defensive
@@ -256,7 +292,9 @@ class NewtonSolveRunner:
256
292
  return u, history
257
293
 
258
294
 
259
- def _condense_flux_dirichlet(K: FluxSparseMatrix, F, dirichlet):
295
+ def _condense_flux_dirichlet(
296
+ K: FluxSparseMatrix, F: ArrayLike, dirichlet: DirichletLike
297
+ ) -> tuple[Any, np.ndarray, np.ndarray | None, np.ndarray, np.ndarray, np.ndarray]:
260
298
  dir_dofs, dir_vals = dirichlet
261
299
  dir_arr = np.asarray(dir_dofs, dtype=int)
262
300
  dir_vals_arr = np.asarray(dir_vals, dtype=float)
@@ -274,11 +312,12 @@ def _condense_flux_dirichlet(K: FluxSparseMatrix, F, dirichlet):
274
312
 
275
313
  def solve_nonlinear(
276
314
  space,
277
- residual_form,
278
- params,
315
+ residual_form: ResidualForm[Any],
316
+ params: Any,
279
317
  *,
280
- dirichlet: tuple | None = None,
281
- base_external_vector=None,
318
+ dirichlet: DirichletLike | None = None,
319
+ base_external_vector: ArrayLike | None = None,
320
+ extra_terms: list[ExtraTerm] | None = None,
282
321
  dtype=jnp.float64,
283
322
  maxiter: int = 20,
284
323
  tol: float = 1e-8,
@@ -287,13 +326,14 @@ def solve_nonlinear(
287
326
  linear_maxiter: int | None = None,
288
327
  linear_tol: float | None = None,
289
328
  linear_preconditioner=None,
329
+ matfree_mode: str = "linearize",
290
330
  line_search: bool = False,
291
331
  max_ls: int = 10,
292
332
  ls_c: float = 1e-4,
293
333
  n_steps: int = 1,
294
334
  jacobian_pattern=None,
295
- u0=None,
296
- ):
335
+ u0: ArrayLike | None = None,
336
+ ) -> tuple[np.ndarray, list[LoadStepResult]]:
297
337
  """
298
338
  Convenience wrapper: build NonlinearAnalysis and run NewtonSolveRunner.
299
339
  """
@@ -303,6 +343,7 @@ def solve_nonlinear(
303
343
  params=params,
304
344
  base_external_vector=base_external_vector,
305
345
  dirichlet=dirichlet,
346
+ extra_terms=extra_terms,
306
347
  dtype=dtype,
307
348
  jacobian_pattern=jacobian_pattern,
308
349
  )
@@ -314,6 +355,7 @@ def solve_nonlinear(
314
355
  linear_maxiter=linear_maxiter,
315
356
  linear_tol=linear_tol,
316
357
  linear_preconditioner=linear_preconditioner,
358
+ matfree_mode=matfree_mode,
317
359
  line_search=line_search,
318
360
  max_ls=max_ls,
319
361
  ls_c=ls_c,
@@ -335,10 +377,10 @@ class LinearAnalysis:
335
377
  """
336
378
 
337
379
  space: Any
338
- bilinear_form: Any
380
+ bilinear_form: FormKernel[Any]
339
381
  params: Any
340
- base_rhs_vector: Any
341
- dirichlet: tuple | None = None
382
+ base_rhs_vector: ArrayLike
383
+ dirichlet: DirichletLike | None = None
342
384
  pattern: Any | None = None
343
385
  dtype: Any = jnp.float64
344
386
 
@@ -349,7 +391,7 @@ class LinearAnalysis:
349
391
  pattern=self.pattern,
350
392
  )
351
393
 
352
- def rhs_for_load(self, load_factor: float):
394
+ def rhs_for_load(self, load_factor: float) -> ArrayLike:
353
395
  return jnp.asarray(load_factor * self.base_rhs_vector, dtype=self.dtype)
354
396
 
355
397
 
@@ -359,10 +401,42 @@ class LinearSolveConfig:
359
401
  Control parameters for the linear solve with optional load scaling.
360
402
  """
361
403
 
362
- method: str = "spsolve" # "spsolve" | "spdirect_solve_gpu" | "cg" | "cg_custom"
404
+ method: str = "spsolve" # "spsolve" | "spdirect_solve_gpu" | "cg" | "cg_custom" | "petsc_shell"
363
405
  tol: float = 1e-8
364
406
  maxiter: int | None = None
365
407
  preconditioner: Any | None = None
408
+ ksp_type: str | None = None
409
+ pc_type: str | None = None
410
+ ksp_rtol: float | None = None
411
+ ksp_atol: float | None = None
412
+ ksp_max_it: int | None = None
413
+ petsc_ksp_norm_type: str | None = None
414
+ petsc_ksp_monitor_true_residual: bool = False
415
+ petsc_ksp_converged_reason: bool = False
416
+ petsc_ksp_monitor_short: bool = False
417
+ petsc_shell_pmat: bool = False
418
+ petsc_shell_pmat_mode: str = "full"
419
+ petsc_shell_pmat_rebuild_iters: int | None = None
420
+ petsc_shell_fallback: bool = False
421
+ petsc_shell_fallback_ksp_types: tuple[str, ...] = ("bcgs", "gmres")
422
+ petsc_shell_fallback_rebuild_pmat: bool = True
423
+
424
+ @classmethod
425
+ def from_preset(cls, name: str) -> "LinearSolveConfig":
426
+ preset = name.lower()
427
+ if preset == "contact":
428
+ return cls(
429
+ method="petsc_shell",
430
+ ksp_type="bcgs",
431
+ pc_type="ilu",
432
+ petsc_shell_pmat=True,
433
+ petsc_shell_pmat_mode="full",
434
+ petsc_ksp_norm_type="unpreconditioned",
435
+ petsc_ksp_monitor_true_residual=True,
436
+ petsc_ksp_converged_reason=True,
437
+ petsc_shell_fallback=True,
438
+ )
439
+ raise ValueError(f"Unknown LinearSolveConfig preset: {name}")
366
440
 
367
441
 
368
442
  @dataclass
@@ -381,7 +455,7 @@ class LinearStepResult:
381
455
  """
382
456
  info: SolverResult
383
457
  solve_time: float
384
- u: Any
458
+ u: ArrayLike
385
459
 
386
460
 
387
461
  class LinearSolveRunner:
@@ -392,6 +466,9 @@ class LinearSolveRunner:
392
466
  def __init__(self, analysis: LinearAnalysis, config: LinearSolveConfig):
393
467
  self.analysis = analysis
394
468
  self.config = config
469
+ self._petsc_shell_pmat = None
470
+ self._petsc_shell_last_iters = None
471
+ self._petsc_shell_pmat_rebuilds = 0
395
472
 
396
473
  def run(
397
474
  self,
@@ -488,6 +565,114 @@ class LinearSolveRunner:
488
565
  stop_reason=("converged" if lin_conv else "linfail"),
489
566
  nan_detected=bool(np.isnan(lin_res)) if lin_res is not None else False,
490
567
  )
568
+ elif self.config.method == "petsc_shell":
569
+ base_ksp_type = self.config.ksp_type or "gmres"
570
+ pc_type = self.config.pc_type if self.config.pc_type is not None else "none"
571
+ ksp_rtol = self.config.ksp_rtol if self.config.ksp_rtol is not None else self.config.tol
572
+ ksp_atol = self.config.ksp_atol
573
+ ksp_max_it = self.config.ksp_max_it if self.config.ksp_max_it is not None else self.config.maxiter
574
+ petsc_options = {}
575
+ if self.config.petsc_ksp_norm_type:
576
+ petsc_options["fluxfem_ksp_norm_type"] = self.config.petsc_ksp_norm_type
577
+ if self.config.petsc_ksp_monitor_true_residual:
578
+ petsc_options["fluxfem_ksp_monitor_true_residual"] = ""
579
+ if self.config.petsc_ksp_converged_reason:
580
+ petsc_options["fluxfem_ksp_converged_reason"] = ""
581
+ if self.config.petsc_ksp_monitor_short:
582
+ petsc_options["fluxfem_ksp_monitor_short"] = ""
583
+ if not petsc_options:
584
+ petsc_options = None
585
+ use_pmat = bool(self.config.petsc_shell_pmat)
586
+ rebuild_thresh = self.config.petsc_shell_pmat_rebuild_iters
587
+ if use_pmat:
588
+ pmat_mode = (self.config.petsc_shell_pmat_mode or "full").lower()
589
+ if pmat_mode == "none":
590
+ use_pmat = False
591
+ pmat = None
592
+ elif pmat_mode == "full":
593
+ pmat = K_ff
594
+ else:
595
+ warnings.warn(
596
+ f"petsc_shell_pmat_mode='{pmat_mode}' is not supported in runner; "
597
+ "falling back to 'full'.",
598
+ RuntimeWarning,
599
+ )
600
+ pmat = K_ff
601
+ if use_pmat:
602
+ if self._petsc_shell_pmat is None:
603
+ self._petsc_shell_pmat = pmat
604
+ self._petsc_shell_pmat_rebuilds += 1
605
+ elif rebuild_thresh is not None and self._petsc_shell_last_iters is not None:
606
+ if self._petsc_shell_last_iters > rebuild_thresh:
607
+ self._petsc_shell_pmat = pmat
608
+ self._petsc_shell_pmat_rebuilds += 1
609
+ pmat = self._petsc_shell_pmat
610
+ if not use_pmat:
611
+ pmat = None
612
+
613
+ def _attempt_solve(ksp_type: str):
614
+ return petsc_shell_solve(
615
+ K_ff,
616
+ F_free,
617
+ preconditioner=self.config.preconditioner,
618
+ ksp_type=ksp_type,
619
+ pc_type=pc_type,
620
+ rtol=ksp_rtol,
621
+ atol=ksp_atol,
622
+ max_it=ksp_max_it,
623
+ pmat=pmat,
624
+ options=petsc_options,
625
+ return_info=True,
626
+ )
627
+
628
+ fallback_ksp = [base_ksp_type]
629
+ if self.config.petsc_shell_fallback:
630
+ for ksp in self.config.petsc_shell_fallback_ksp_types:
631
+ if ksp not in fallback_ksp:
632
+ fallback_ksp.append(ksp)
633
+ fallback_attempts = []
634
+ petsc_info = None
635
+ u_free = None
636
+ for ksp in fallback_ksp:
637
+ fallback_attempts.append(ksp)
638
+ u_free, petsc_info = _attempt_solve(ksp)
639
+ lin_conv = petsc_info.get("converged")
640
+ reason = petsc_info.get("reason")
641
+ if lin_conv is None and reason is not None:
642
+ lin_conv = reason > 0
643
+ if lin_conv:
644
+ break
645
+ if self.config.petsc_shell_fallback and use_pmat and self.config.petsc_shell_fallback_rebuild_pmat:
646
+ self._petsc_shell_pmat = pmat
647
+ self._petsc_shell_pmat_rebuilds += 1
648
+ lin_iters = petsc_info.get("iters")
649
+ lin_res = petsc_info.get("residual_norm")
650
+ lin_solve_dt = petsc_info.get("solve_time")
651
+ pc_setup_dt = petsc_info.get("pc_setup_time")
652
+ pmat_dt = petsc_info.get("pmat_build_time")
653
+ lin_conv = petsc_info.get("converged")
654
+ if lin_conv is None and petsc_info.get("reason") is not None:
655
+ lin_conv = petsc_info.get("reason") > 0
656
+ if lin_conv is None:
657
+ lin_conv = True
658
+ self._petsc_shell_last_iters = lin_iters
659
+ info = SolverResult(
660
+ converged=bool(lin_conv),
661
+ iters=int(lin_iters) if lin_iters is not None else 0,
662
+ linear_iters=int(lin_iters) if lin_iters is not None else None,
663
+ linear_converged=bool(lin_conv),
664
+ linear_residual=float(lin_res) if lin_res is not None else None,
665
+ linear_solve_time=float(lin_solve_dt) if lin_solve_dt is not None else None,
666
+ pc_setup_time=float(pc_setup_dt) if pc_setup_dt is not None else None,
667
+ pmat_build_time=float(pmat_dt) if pmat_dt is not None else None,
668
+ pmat_rebuilds=self._petsc_shell_pmat_rebuilds if use_pmat else None,
669
+ pmat_mode=self.config.petsc_shell_pmat_mode if use_pmat else None,
670
+ tol=self.config.tol,
671
+ stop_reason=("converged" if lin_conv else "linfail"),
672
+ nan_detected=bool(np.isnan(lin_res)) if lin_res is not None else False,
673
+ )
674
+ if len(fallback_attempts) > 1:
675
+ info.linear_fallbacks = fallback_attempts
491
676
  else:
492
677
  raise ValueError(f"Unknown linear solve method: {self.config.method}")
493
678
 
fluxfem/solver/solver.py CHANGED
@@ -1,12 +1,16 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from typing import Any, TYPE_CHECKING, TypeAlias
4
+
3
5
  import numpy as np
4
6
  import jax.numpy as jnp
5
7
 
6
8
  from .cg import cg_solve, cg_solve_jax
7
9
  from .newton import newton_solve
10
+ from .petsc import petsc_solve, petsc_shell_solve
8
11
  from ..core.solver import spdirect_solve_cpu, spdirect_solve_gpu
9
12
  from .dirichlet import (
13
+ DirichletBC,
10
14
  condense_dirichlet_dense,
11
15
  condense_dirichlet_fluxsparse,
12
16
  expand_dirichlet_solution,
@@ -16,6 +20,16 @@ from .dirichlet import (
16
20
  from .sparse import FluxSparseMatrix
17
21
  from ..core.space import FESpace
18
22
 
23
+ if TYPE_CHECKING:
24
+ from jax import Array as JaxArray
25
+
26
+ ArrayLike: TypeAlias = np.ndarray | JaxArray
27
+ else:
28
+ ArrayLike: TypeAlias = np.ndarray
29
+ DirichletLike: TypeAlias = DirichletBC | tuple[np.ndarray, np.ndarray]
30
+ SolveInfo: TypeAlias = dict[str, Any]
31
+ SolveReturn: TypeAlias = tuple[np.ndarray, SolveInfo]
32
+
19
33
 
20
34
  class LinearSolver:
21
35
  """
@@ -30,7 +44,7 @@ class LinearSolver:
30
44
  self.tol = tol
31
45
  self.maxiter = maxiter
32
46
 
33
- def _solve_free(self, A, b):
47
+ def _solve_free(self, A: Any, b: Any) -> SolveReturn:
34
48
  if self.method == "cg":
35
49
  x, info = cg_solve_jax(A, b, tol=self.tol, maxiter=self.maxiter)
36
50
  return np.asarray(x), {"iters": info.get("iters"), "converged": info.get("converged", True)}
@@ -46,25 +60,34 @@ class LinearSolver:
46
60
  elif self.method == "spdirect_solve_gpu":
47
61
  x = spdirect_solve_gpu(A, b)
48
62
  return np.asarray(x), {"iters": 1, "converged": True}
63
+ elif self.method == "petsc":
64
+ x = petsc_solve(A, b)
65
+ return np.asarray(x), {"iters": None, "converged": True}
66
+ elif self.method == "petsc_shell":
67
+ x = petsc_shell_solve(A, b)
68
+ return np.asarray(x), {"iters": None, "converged": True}
49
69
  else:
50
70
  raise ValueError(f"Unknown linear method: {self.method}")
51
71
 
52
72
  def solve(
53
73
  self,
54
- A,
55
- b,
74
+ A: Any,
75
+ b: Any,
56
76
  *,
57
- dirichlet=None,
77
+ dirichlet: DirichletLike | None = None,
58
78
  dirichlet_mode: str = "condense",
59
79
  n_total: int | None = None,
60
- ):
80
+ ) -> SolveReturn:
61
81
  if dirichlet is None:
62
82
  return self._solve_free(A, b)
63
83
 
64
84
  if dirichlet_mode not in ("condense", "enforce"):
65
85
  raise ValueError("dirichlet_mode must be 'condense' or 'enforce'.")
66
86
 
67
- dir_dofs, dir_vals = dirichlet
87
+ if isinstance(dirichlet, DirichletBC):
88
+ dir_dofs, dir_vals = dirichlet.as_tuple()
89
+ else:
90
+ dir_dofs, dir_vals = dirichlet
68
91
  if dirichlet_mode == "enforce":
69
92
  if isinstance(A, FluxSparseMatrix):
70
93
  A_bc, b_bc = enforce_dirichlet_sparse(A, b, dir_dofs, dir_vals)
@@ -98,8 +121,8 @@ class NonlinearSolver:
98
121
  def __init__(
99
122
  self,
100
123
  space: FESpace,
101
- res_form,
102
- params,
124
+ res_form: Any,
125
+ params: Any,
103
126
  *,
104
127
  tol: float = 1e-8,
105
128
  maxiter: int = 20,
@@ -108,10 +131,10 @@ class NonlinearSolver:
108
131
  max_ls: int = 10,
109
132
  ls_c: float = 1e-4,
110
133
  linear_tol: float | None = None,
111
- dirichlet=None,
112
- external_vector=None,
134
+ dirichlet: DirichletLike | None = None,
135
+ external_vector: ArrayLike | None = None,
113
136
  linear_maxiter: int | None = None,
114
- jacobian_pattern=None,
137
+ jacobian_pattern: Any | None = None,
115
138
  ):
116
139
  self.space = space
117
140
  self.res_form = res_form
@@ -128,7 +151,7 @@ class NonlinearSolver:
128
151
  self.linear_maxiter = linear_maxiter
129
152
  self.jacobian_pattern = jacobian_pattern
130
153
 
131
- def solve(self, u0):
154
+ def solve(self, u0: ArrayLike):
132
155
  return newton_solve(
133
156
  self.space,
134
157
  self.res_form,