fluxfem 0.1.3a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fluxfem might be problematic. Click here for more details.

Files changed (47)
  1. fluxfem/__init__.py +343 -0
  2. fluxfem/core/__init__.py +318 -0
  3. fluxfem/core/assembly.py +788 -0
  4. fluxfem/core/basis.py +996 -0
  5. fluxfem/core/data.py +64 -0
  6. fluxfem/core/dtypes.py +4 -0
  7. fluxfem/core/forms.py +234 -0
  8. fluxfem/core/interp.py +55 -0
  9. fluxfem/core/solver.py +113 -0
  10. fluxfem/core/space.py +419 -0
  11. fluxfem/core/weakform.py +828 -0
  12. fluxfem/helpers_ts.py +11 -0
  13. fluxfem/helpers_wf.py +44 -0
  14. fluxfem/mesh/__init__.py +29 -0
  15. fluxfem/mesh/base.py +244 -0
  16. fluxfem/mesh/hex.py +327 -0
  17. fluxfem/mesh/io.py +87 -0
  18. fluxfem/mesh/predicate.py +45 -0
  19. fluxfem/mesh/surface.py +257 -0
  20. fluxfem/mesh/tet.py +246 -0
  21. fluxfem/physics/__init__.py +53 -0
  22. fluxfem/physics/diffusion.py +18 -0
  23. fluxfem/physics/elasticity/__init__.py +39 -0
  24. fluxfem/physics/elasticity/hyperelastic.py +99 -0
  25. fluxfem/physics/elasticity/linear.py +58 -0
  26. fluxfem/physics/elasticity/materials.py +32 -0
  27. fluxfem/physics/elasticity/stress.py +46 -0
  28. fluxfem/physics/operators.py +109 -0
  29. fluxfem/physics/postprocess.py +113 -0
  30. fluxfem/solver/__init__.py +47 -0
  31. fluxfem/solver/bc.py +439 -0
  32. fluxfem/solver/cg.py +326 -0
  33. fluxfem/solver/dirichlet.py +126 -0
  34. fluxfem/solver/history.py +31 -0
  35. fluxfem/solver/newton.py +400 -0
  36. fluxfem/solver/result.py +62 -0
  37. fluxfem/solver/solve_runner.py +534 -0
  38. fluxfem/solver/solver.py +148 -0
  39. fluxfem/solver/sparse.py +188 -0
  40. fluxfem/tools/__init__.py +7 -0
  41. fluxfem/tools/jit.py +51 -0
  42. fluxfem/tools/timer.py +659 -0
  43. fluxfem/tools/visualizer.py +101 -0
  44. fluxfem-0.1.3a0.dist-info/LICENSE +201 -0
  45. fluxfem-0.1.3a0.dist-info/METADATA +125 -0
  46. fluxfem-0.1.3a0.dist-info/RECORD +47 -0
  47. fluxfem-0.1.3a0.dist-info/WHEEL +4 -0
@@ -0,0 +1,534 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from dataclasses import dataclass, field
5
+ from typing import Any, Callable, Iterable, List, Sequence
6
+
7
+ import numpy as np
8
+ import jax.numpy as jnp
9
+
10
+ from ..core.assembly import assemble_bilinear_form
11
+ from ..core.solver import spdirect_solve_cpu, spdirect_solve_gpu
12
+ from .cg import cg_solve, cg_solve_jax
13
+ from .sparse import FluxSparseMatrix
14
+ from .dirichlet import expand_dirichlet_solution
15
+ from .newton import newton_solve
16
+ from .result import SolverResult
17
+ from .history import NewtonIterRecord, LoadStepResult
18
+ from ..tools.timer import SectionTimer, NullTimer
19
+
20
+
21
@dataclass
class NonlinearAnalysis:
    """
    Bundle problem data needed for a Newton solve with load scaling.

    Attributes
    ----------
    space : Any
        FE space containing topology/dofs.
    residual_form : Any
        Internal residual form (e.g., neo_hookean_residual_form).
    params : Any
        Parameters forwarded to the residual form.
    base_external_vector : Any | None
        Unscaled external load vector (scaled by load factor in `external_for_load`).
        Any array-like is accepted, including plain Python sequences.
    dirichlet : tuple | None
        (dofs, values) for Dirichlet boundary conditions.
    jacobian_pattern : Any | None
        Optional sparsity pattern to reuse between load steps.
    dtype : Any
        dtype for the solution vector (defaults to float64).
    """

    space: Any
    residual_form: Any
    params: Any
    base_external_vector: Any | None = None
    dirichlet: tuple | None = None
    jacobian_pattern: Any | None = None
    dtype: Any = jnp.float64

    def external_for_load(self, load_factor: float):
        """
        Return the external load vector scaled by ``load_factor``.

        Returns None when no external load was supplied. The result is a JAX
        array of ``self.dtype``.
        """
        if self.base_external_vector is None:
            return None
        # Convert to an array *before* scaling: ``0.5 * [1.0, 2.0]`` would raise
        # TypeError for a plain Python list.
        base = jnp.asarray(self.base_external_vector, dtype=self.dtype)
        # Re-cast so a non-weak scalar (e.g. np.float64) cannot promote the dtype.
        return jnp.asarray(load_factor * base, dtype=self.dtype)
56
+
57
+
58
@dataclass
class NewtonLoopConfig:
    """
    Settings that drive the load-stepped Newton iteration.

    Tolerances and iteration limits are forwarded to ``newton_solve``. The
    load schedule is either given explicitly through ``load_sequence`` or
    generated from ``n_steps`` equal increments ending at 1.0.
    """

    # Newton convergence controls.
    tol: float = 1e-8
    atol: float = 0.0
    maxiter: int = 20
    # Backtracking line-search controls.
    line_search: bool = False
    max_ls: int = 10
    ls_c: float = 1e-4
    # Inner linear-solver controls.
    linear_solver: str = "spsolve"
    linear_maxiter: int | None = None
    linear_tol: float | None = None
    linear_preconditioner: Any | None = None
    # Load-schedule controls.
    load_sequence: Sequence[float] | None = None
    n_steps: int = 1

    def schedule(self) -> List[float]:
        """
        Build the list of load factors to visit.

        An explicit ``load_sequence`` is returned as a list without any
        validation. Otherwise ``max(1, n_steps)`` evenly spaced factors in
        (0, 1] are produced, the last of which is 1.0.
        """
        explicit = self.load_sequence
        if explicit is None:
            steps = self.n_steps if self.n_steps > 1 else 1
            grid = np.linspace(0.0, 1.0, steps + 1, endpoint=True)
            return list(grid[1:])
        return list(explicit)
85
+
86
+
87
class NewtonSolveRunner:
    """
    Run one or more Newton solves across load factors.

    This orchestrates load stepping, assembles external load per step,
    and returns the full (Dirichlet-expanded) solution and per-step history.
    Each step starts from the previous step's solution. If ``newton_solve``
    raises, the step keeps the previous ``u`` and records the exception text.
    """

    def __init__(self, analysis: NonlinearAnalysis, config: NewtonLoopConfig):
        self.analysis = analysis
        self.config = config

    @staticmethod
    def _sanitize_schedule(schedule_raw):
        """Keep finite factors in [0, 1] that do not decrease; return (schedule, dropped).

        ``dropped`` is a list of ``(reason, value)`` pairs, with reason one of
        ``"nonfinite"``, ``"out_of_range"`` or ``"nonmonotone"``.
        """
        schedule: List[float] = []
        dropped = []
        prev = 0.0
        for lf in schedule_raw:
            value = float(lf)
            if not np.isfinite(value):
                dropped.append(("nonfinite", value))
            elif value < 0.0 or value > 1.0:
                dropped.append(("out_of_range", value))
            elif value < prev:
                dropped.append(("nonmonotone", value))
            else:
                schedule.append(value)
                prev = value
        return schedule, dropped

    @staticmethod
    def _external_stats(external):
        """Return ``(inf_norm, nnz)`` of the external load; ``(0.0, 0)`` when absent/empty."""
        if external is None:
            return 0.0, 0
        ext = jnp.asarray(external)
        if ext.size == 0:
            return 0.0, 0
        return float(jnp.max(jnp.abs(ext))), int(jnp.count_nonzero(ext))

    @staticmethod
    def _max_nodal_norm(u) -> float:
        """Diagnostic max |u|.

        Treats ``u`` as 3 dofs per node (xyz) when its size allows. Otherwise it
        falls back to the largest absolute dof value instead of crashing.
        """
        arr = np.asarray(u)
        if arr.size == 0:
            return 0.0
        if arr.size % 3 == 0:
            return float(np.linalg.norm(arr.reshape(-1, 3), axis=1).max())
        return float(np.abs(arr).max())

    def run(
        self,
        u0=None,
        *,
        load_sequence: Sequence[float] | None = None,
        newton_callback: Callable | None = None,
        step_callback: Callable[[LoadStepResult], None] | None = None,
        timer: "SectionTimer | None" = None,
        report_timing: bool = True
    ):
        """
        Execute Newton solves over the configured load schedule.

        Parameters
        ----------
        u0 : array-like | None
            Initial guess (defaults to zeros).
        load_sequence : sequence | None
            Optional per-call load schedule; if omitted, uses config.schedule().
            Non-finite, out-of-[0, 1] and decreasing entries are dropped with a
            RuntimeWarning and listed in each step's ``meta["schedule_dropped"]``.
        newton_callback : callable | None
            Per-iteration callback passed through after the iteration is logged.
        step_callback : callable | None
            Optional hook called after each load step with LoadStepResult.
        timer : SectionTimer | None
            Timer used for section timings; a hierarchical SectionTimer by default.
        report_timing : bool
            When True, call ``timer.report(sort_by="total")`` at the end.

        Returns
        -------
        u : full solution (Dirichlet expanded), dtype per analysis.dtype
        history : list of LoadStepResult per load factor
        """
        timer = timer or SectionTimer(hierarchical=True)
        with timer.section("run_total"):
            dtype = self.analysis.dtype

            with timer.section("preprocess"):
                if u0 is None:
                    u = jnp.zeros(self.analysis.space.n_dofs, dtype=dtype)
                else:
                    u = jnp.asarray(u0, dtype=dtype)

                if load_sequence is not None:
                    schedule_raw = list(load_sequence)
                else:
                    schedule_raw = self.config.schedule()
                # Enforce a non-decreasing 0->1 schedule and warn about anything dropped.
                schedule, dropped = self._sanitize_schedule(schedule_raw)
                if dropped:
                    import warnings

                    warnings.warn(
                        f"NewtonSolveRunner: dropped {len(dropped)} load factor(s) "
                        f"from the schedule: {dropped}",
                        RuntimeWarning,
                        stacklevel=2,
                    )

            history: List[LoadStepResult] = []
            n_sched = len(schedule)
            for step_i, load_factor in enumerate(schedule, start=1):
                with timer.section("step"):
                    external = self.analysis.external_for_load(load_factor)
                    iter_log: List[NewtonIterRecord] = []
                    # Real infinity norm (not a count); safe when there is no external load.
                    ext_inf, ext_nnz = self._external_stats(external)
                    max_u0 = self._max_nodal_norm(u)
                    print(
                        f"[load factor step {step_i}/{n_sched}] lf={load_factor:.3f} "
                        f"||F||_inf={ext_inf:.3e} nnz={ext_nnz} max|u|_start={max_u0:.3e}"
                    )

                    def cb(d):
                        # Record one Newton iteration, then forward to the user callback.
                        lin_iters = d.get("linear_iters")
                        lin_res = d.get("linear_residual")
                        lin_conv = d.get("linear_converged")
                        iter_log.append(
                            NewtonIterRecord(
                                iter=int(d.get("iter", -1)),
                                res_inf=float(d.get("res_inf", float("nan"))),
                                res_two=float(d.get("res_two", float("nan"))),
                                rel_res_inf=float(d.get("rel_residual", float("nan"))),
                                alpha=float(d.get("alpha", 1.0)),
                                step_norm=float(d.get("step_norm", float("nan"))),
                                lin_iters=int(lin_iters) if lin_iters is not None else None,
                                lin_converged=bool(lin_conv) if lin_conv is not None else None,
                                lin_residual=float(lin_res) if lin_res is not None else None,
                                nan_detected=bool(d.get("nan_detected", False)),
                            )
                        )
                        if newton_callback is not None:
                            newton_callback(d)

                    exception = None
                    # Measure directly: the timer's internal record keys are not a
                    # stable interface (and the old "step>newton_solve" key never existed).
                    t_start = time.perf_counter()
                    try:
                        with timer.section("newton_solve"):
                            u, info = newton_solve(
                                self.analysis.space,
                                self.analysis.residual_form,
                                u,
                                self.analysis.params,
                                tol=self.config.tol,
                                atol=self.config.atol,
                                maxiter=self.config.maxiter,
                                linear_solver=self.config.linear_solver,
                                linear_maxiter=self.config.linear_maxiter,
                                linear_tol=self.config.linear_tol,
                                linear_preconditioner=self.config.linear_preconditioner,
                                dirichlet=self.analysis.dirichlet,
                                line_search=self.config.line_search,
                                max_ls=self.config.max_ls,
                                ls_c=self.config.ls_c,
                                external_vector=external,
                                callback=cb,
                                jacobian_pattern=self.analysis.jacobian_pattern,
                            )
                    except Exception as e:  # pragma: no cover - defensive
                        info = SolverResult(converged=False, iters=0, stop_reason="exception", nan_detected=False)
                        exception = repr(e)
                    step_solve_time = time.perf_counter() - t_start

                    max_u1 = self._max_nodal_norm(u)
                    print(
                        f" -> converged={getattr(info,'converged',None)} iters={getattr(info,'iters',None)} "
                        f"time={step_solve_time:.3f}s max|u|_end={max_u1:.3e}"
                        + (f" EXC={exception}" if exception else "")
                    )

                    meta = {
                        "load_factor": load_factor,
                        "linear_solver": self.config.linear_solver,
                        "line_search": self.config.line_search,
                        "maxiter": self.config.maxiter,
                        "n_dofs": self.analysis.space.n_dofs,
                        "dtype": str(self.analysis.dtype),
                        "u_layout": "full",
                        "schedule": schedule,
                        "schedule_dropped": dropped,
                    }

                    result = LoadStepResult(
                        load_factor=load_factor,
                        info=info,
                        solve_time=step_solve_time,
                        u=u,
                        iter_history=iter_log,
                        exception=exception,
                        meta=meta,
                    )
                    history.append(result)
                    if step_callback is not None:
                        step_callback(result)

        if report_timing:
            timer.report(sort_by="total")
        return u, history
257
+
258
+
259
+ def _condense_flux_dirichlet(K: FluxSparseMatrix, F, dirichlet):
260
+ dir_dofs, dir_vals = dirichlet
261
+ dir_arr = np.asarray(dir_dofs, dtype=int)
262
+ dir_vals_arr = np.asarray(dir_vals, dtype=float)
263
+ K_csr = K.to_csr()
264
+ mask = np.ones(K_csr.shape[0], dtype=bool)
265
+ mask[dir_arr] = False
266
+ free = np.nonzero(mask)[0]
267
+ F_full = np.asarray(F, dtype=float)
268
+ K_fd = K_csr[free][:, dir_arr] if dir_arr.size > 0 else None
269
+ F_free_base = F_full[free]
270
+ offset = K_fd @ dir_vals_arr if K_fd is not None and dir_arr.size > 0 else None
271
+ K_ff = K_csr[free][:, free]
272
+ return K_ff, F_free_base, offset, free, dir_arr, dir_vals_arr
273
+
274
+
275
def solve_nonlinear(
    space,
    residual_form,
    params,
    *,
    dirichlet: tuple | None = None,
    base_external_vector=None,
    dtype=jnp.float64,
    maxiter: int = 20,
    tol: float = 1e-8,
    atol: float = 1e-10,
    linear_solver: str = "spsolve",
    linear_maxiter: int | None = None,
    linear_tol: float | None = None,
    linear_preconditioner=None,
    line_search: bool = False,
    max_ls: int = 10,
    ls_c: float = 1e-4,
    n_steps: int = 1,
    jacobian_pattern=None,
    u0=None,
):
    """
    One-call nonlinear solve.

    Packs the problem arguments into a ``NonlinearAnalysis`` and the solver
    controls into a ``NewtonLoopConfig``, then runs ``NewtonSolveRunner``.
    The initial guess defaults to a zero vector of length ``space.n_dofs``.

    Returns
    -------
    (u, history) exactly as produced by ``NewtonSolveRunner.run``.
    """
    runner = NewtonSolveRunner(
        NonlinearAnalysis(
            space=space,
            residual_form=residual_form,
            params=params,
            base_external_vector=base_external_vector,
            dirichlet=dirichlet,
            dtype=dtype,
            jacobian_pattern=jacobian_pattern,
        ),
        NewtonLoopConfig(
            maxiter=maxiter,
            tol=tol,
            atol=atol,
            linear_solver=linear_solver,
            linear_maxiter=linear_maxiter,
            linear_tol=linear_tol,
            linear_preconditioner=linear_preconditioner,
            line_search=line_search,
            max_ls=max_ls,
            ls_c=ls_c,
            n_steps=n_steps,
        ),
    )
    start = u0 if u0 is not None else jnp.zeros(space.n_dofs, dtype=dtype)
    u, history = runner.run(u0=start)
    return u, history
326
+
327
+
328
@dataclass
class LinearAnalysis:
    """
    Bundle linear problem data for a single solve or a load-scaled sequence.

    The matrix is assembled once from ``bilinear_form``; the RHS is scaled
    by the load factor.

    Attributes
    ----------
    space : Any
        FE space; must provide ``assemble_bilinear_form(form, params=, pattern=)``.
    bilinear_form : Any
        Bilinear form passed to the space's assembler.
    params : Any
        Parameters forwarded to the bilinear form.
    base_rhs_vector : Any
        Unscaled right-hand side (any array-like, including Python sequences).
    dirichlet : tuple | None
        (dofs, values) for Dirichlet boundary conditions.
    pattern : Any | None
        Optional sparsity pattern forwarded to the assembler.
    dtype : Any
        dtype of the scaled RHS (defaults to float64).
    """

    space: Any
    bilinear_form: Any
    params: Any
    base_rhs_vector: Any
    dirichlet: tuple | None = None
    pattern: Any | None = None
    dtype: Any = jnp.float64

    def assemble_matrix(self):
        """Assemble the global matrix through ``space.assemble_bilinear_form``."""
        return self.space.assemble_bilinear_form(
            self.bilinear_form,
            params=self.params,
            pattern=self.pattern,
        )

    def rhs_for_load(self, load_factor: float):
        """Return ``load_factor * base_rhs_vector`` as a JAX array of ``self.dtype``."""
        # Convert before scaling: ``0.5 * [1.0, 2.0]`` raises TypeError for lists.
        base = jnp.asarray(self.base_rhs_vector, dtype=self.dtype)
        return jnp.asarray(load_factor * base, dtype=self.dtype)
354
+
355
+
356
@dataclass
class LinearSolveConfig:
    """
    Options for the linear solve performed by ``LinearSolveRunner``.

    ``method`` picks the backend:

    - ``"spsolve"``: CPU sparse direct solve.
    - ``"spdirect_solve_gpu"``: GPU sparse direct solve.
    - ``"cg"``: ``cg_solve_jax``.
    - ``"cg_custom"``: ``cg_solve``.

    ``maxiter`` and ``preconditioner`` are used only by the CG backends.
    ``tol`` is used by the CG backends and is also recorded in the SolverResult.
    """

    # Backend selector: "spsolve" | "spdirect_solve_gpu" | "cg" | "cg_custom"
    method: str = field(default="spsolve")
    # Iterative tolerance.
    tol: float = field(default=1e-8)
    # Iteration cap for CG (None lets the solver choose).
    maxiter: int | None = field(default=None)
    # Optional CG preconditioner, passed through unchanged.
    preconditioner: Any | None = field(default=None)
366
+
367
+
368
@dataclass
class LinearStepResult:
    """
    Outcome of a single linear solve.

    Attributes
    ----------
    info : SolverResult
        Convergence flag, iteration counts and stop reason.
    solve_time : float
        Seconds spent in the solve itself.
    u : Any
        Full (Dirichlet-expanded) solution, or None if the solve failed.
    """

    info: SolverResult  # solver status record
    solve_time: float  # wall-clock seconds of the solve section
    u: Any  # full solution vector; None signals failure
385
+
386
+
387
class LinearSolveRunner:
    """
    Solve linear systems for one or more load factors using a unified interface.

    The current implementation performs a single solve with the unscaled
    right-hand side ``analysis.base_rhs_vector``.
    """

    def __init__(self, analysis: LinearAnalysis, config: LinearSolveConfig):
        self.analysis = analysis
        self.config = config

    def _solve_condensed(self, K_ff, F_free):
        """Solve ``K_ff u = F_free`` with the configured backend.

        Returns ``(u_free, SolverResult)``.

        Raises
        ------
        ValueError
            If ``config.method`` is not a known backend.
        """
        method = self.config.method
        if method in ("spsolve", "spdirect_solve_gpu"):
            direct = spdirect_solve_cpu if method == "spsolve" else spdirect_solve_gpu
            u_free = direct(K_ff, F_free)
            info = SolverResult(
                converged=True,
                iters=1,
                linear_iters=1,
                linear_converged=True,
                linear_residual=None,
                tol=self.config.tol,
                stop_reason="converged",
            )
            return u_free, info

        if method in ("cg", "cg_custom"):
            # Re-wrap the condensed CSR block as a FluxSparseMatrix for the CG solvers.
            coo = K_ff.tocoo()
            A_cg = FluxSparseMatrix.from_bilinear(
                (
                    jnp.asarray(coo.row, dtype=jnp.int32),
                    jnp.asarray(coo.col, dtype=jnp.int32),
                    jnp.asarray(coo.data),
                    K_ff.shape[0],
                )
            )
            cg_solver = cg_solve_jax if method == "cg" else cg_solve
            u_free, cg_info = cg_solver(
                A_cg, jnp.asarray(F_free),
                tol=self.config.tol,
                maxiter=self.config.maxiter,
                preconditioner=self.config.preconditioner
            )
            lin_iters = cg_info.get("iters")
            lin_conv = bool(cg_info.get("converged", True))
            lin_res = cg_info.get("residual_norm", cg_info.get("residual"))
            info = SolverResult(
                converged=lin_conv,
                iters=int(lin_iters) if lin_iters is not None else 0,
                linear_iters=int(lin_iters) if lin_iters is not None else None,
                linear_converged=lin_conv,
                linear_residual=float(lin_res) if lin_res is not None else None,
                tol=self.config.tol,
                stop_reason=("converged" if lin_conv else "linfail"),
                nan_detected=bool(np.isnan(lin_res)) if lin_res is not None else False,
            )
            return u_free, info

        raise ValueError(f"Unknown linear solve method: {method}")

    def run(
        self,
        *,
        step_callback: Callable[[LinearStepResult], None] | None = None,
        timer: "SectionTimer | None" = None,
        report_timing: bool = True
    ) -> tuple[np.ndarray | None, list[LinearStepResult]]:
        """
        Assemble, condense Dirichlet dofs, solve, and expand the solution.

        Parameters
        ----------
        step_callback : callable | None
            Called once with the LinearStepResult (also on failure).
        timer : SectionTimer | None
            Timer for section timings; a hierarchical SectionTimer by default.
        report_timing : bool
            When True, call ``timer.report(sort_by="total")`` at the end.

        Returns
        -------
        (u_full, [result])
            Always a 2-tuple. ``u_full`` is None when the solve raised; the
            exception text is then stored in ``result.info.stop_reason``.
        """
        timer = timer or SectionTimer(hierarchical=True)
        with timer.section("linear_run_total"):
            with timer.section("assemble_matrix"):
                K = self.analysis.assemble_matrix()

            with timer.section("build_rhs"):
                base_rhs = jnp.asarray(
                    self.analysis.base_rhs_vector, dtype=self.analysis.dtype
                )
                if self.analysis.dirichlet is not None:
                    K_ff, F_free_base, offset, free, dir_arr, dir_vals_arr = _condense_flux_dirichlet(
                        K, base_rhs, self.analysis.dirichlet
                    )
                    n_total = K.shape[0] if hasattr(K, "shape") else self.analysis.space.n_dofs
                else:
                    K_ff = K.to_csr()
                    F_free_base = np.asarray(base_rhs, dtype=float)
                    offset = None
                    free = None
                    dir_arr = dir_vals_arr = None
                    n_total = K_ff.shape[0]

                F_free = np.asarray(F_free_base, dtype=float)
                if offset is not None:
                    # Move the known Dirichlet contribution to the right-hand side.
                    F_free = F_free - offset

            failed = False
            # Measure directly: the timer's internal record keys are not a stable
            # interface (hierarchical sections are prefixed by their parents).
            t_start = time.perf_counter()
            try:
                with timer.section(f"solve>{self.config.method}"):
                    u_free, info = self._solve_condensed(K_ff, F_free)
            except Exception as e:  # pragma: no cover - defensive
                failed = True
                exception = repr(e)
                info = SolverResult(
                    converged=False,
                    iters=0,
                    linear_iters=None,
                    linear_converged=False,
                    linear_residual=None,
                    tol=self.config.tol,
                    stop_reason=f"exception: {exception}",
                    nan_detected=False,
                )
            solve_time = time.perf_counter() - t_start

            if failed:
                # Keep u as None to signal failure; callers treat it as Optional.
                u_full = None
            else:
                with timer.section("expand_dirichlet"):
                    if free is None:
                        u_full = np.asarray(u_free, dtype=float)
                    else:
                        u_full = expand_dirichlet_solution(
                            u_free, free, dir_arr, dir_vals_arr, n_total
                        )

            result = LinearStepResult(info=info, solve_time=solve_time, u=u_full)
            if step_callback is not None:
                step_callback(result)

        if report_timing:
            timer.report(sort_by="total")

        return u_full, [result]
@@ -0,0 +1,148 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ import jax.numpy as jnp
5
+
6
+ from .cg import cg_solve, cg_solve_jax
7
+ from .newton import newton_solve
8
+ from ..core.solver import spdirect_solve_cpu, spdirect_solve_gpu
9
+ from .dirichlet import (
10
+ condense_dirichlet_dense,
11
+ condense_dirichlet_fluxsparse,
12
+ expand_dirichlet_solution,
13
+ enforce_dirichlet_dense,
14
+ enforce_dirichlet_sparse,
15
+ )
16
+ from .sparse import FluxSparseMatrix
17
+ from ..core.space import FESpace
18
+
19
+
20
class LinearSolver:
    """
    Lightweight wrapper for solving linear systems with optional Dirichlet BCs.

    Supports dense arrays or FluxSparseMatrix and can either condense or enforce
    Dirichlet conditions before solving with the chosen backend.

    Parameters
    ----------
    method : str
        One of ``"cg"`` (cg_solve_jax), ``"cg_custom"`` (cg_solve),
        ``"spsolve"``, ``"spsolve_jax"`` or ``"spdirect_solve_gpu"``.
    tol, maxiter :
        Iterative-solver controls (used only by the CG methods).
    preconditioner :
        Optional CG preconditioner. It is forwarded only when not None, so
        existing calls behave exactly as before.
    """

    def __init__(self, method: str = "spsolve", tol: float = 1e-8, maxiter: int = 200, preconditioner=None):
        self.method = method
        self.tol = tol
        self.maxiter = maxiter
        self.preconditioner = preconditioner

    def _solve_free(self, A, b):
        """Solve ``A x = b`` with the configured backend; return ``(x, info_dict)``.

        Raises
        ------
        ValueError
            If ``self.method`` is unknown.
        """
        method = self.method
        if method in ("cg", "cg_custom"):
            cg = cg_solve_jax if method == "cg" else cg_solve
            kwargs = {"tol": self.tol, "maxiter": self.maxiter}
            if self.preconditioner is not None:
                kwargs["preconditioner"] = self.preconditioner
            x, info = cg(A, b, **kwargs)
            return np.asarray(x), {"iters": info.get("iters"), "converged": info.get("converged", True)}

        # Direct solvers: a single "iteration" that is considered converged.
        if method == "spsolve":
            x = spdirect_solve_cpu(A, b)
        elif method == "spsolve_jax":
            x = spdirect_solve_cpu(A, b, use_jax=True)
        elif method == "spdirect_solve_gpu":
            x = spdirect_solve_gpu(A, b)
        else:
            raise ValueError(f"Unknown linear method: {self.method}")
        return np.asarray(x), {"iters": 1, "converged": True}

    def solve(
        self,
        A,
        b,
        *,
        dirichlet=None,
        dirichlet_mode: str = "condense",
        n_total: int | None = None,
    ):
        """
        Solve ``A u = b``, optionally applying Dirichlet conditions.

        Parameters
        ----------
        A : FluxSparseMatrix or dense array
        b : array-like
        dirichlet : tuple | None
            ``(dofs, values)``. When None, the system is solved as given.
        dirichlet_mode : str
            ``"condense"`` solves the free-dof system and expands the result.
            ``"enforce"`` modifies the full system in place of condensation.
        n_total : int | None
            Full system size used when expanding a condensed solution. By
            default it is taken from ``A.n_dofs`` (sparse) or ``A.shape[0]``.

        Returns
        -------
        (u, info)

        Raises
        ------
        ValueError
            For an invalid ``dirichlet_mode`` or an unknown solver method.
        """
        if dirichlet is None:
            return self._solve_free(A, b)

        if dirichlet_mode not in ("condense", "enforce"):
            raise ValueError("dirichlet_mode must be 'condense' or 'enforce'.")

        dir_dofs, dir_vals = dirichlet
        if dirichlet_mode == "enforce":
            if isinstance(A, FluxSparseMatrix):
                A_bc, b_bc = enforce_dirichlet_sparse(A, b, dir_dofs, dir_vals)
            else:
                A_bc, b_bc = enforce_dirichlet_dense(A, b, dir_dofs, dir_vals)
            return self._solve_free(A_bc, b_bc)

        if isinstance(A, FluxSparseMatrix):
            K_ff, F_free, free, dir_arr, dir_vals_arr = condense_dirichlet_fluxsparse(A, b, dir_dofs, dir_vals)
        else:
            K_ff, F_free, free, dir_arr, dir_vals_arr = condense_dirichlet_dense(A, b, dir_dofs, dir_vals)
        u_free, info = self._solve_free(K_ff, F_free)
        if n_total is not None:
            n_total_use = int(n_total)
        elif isinstance(A, FluxSparseMatrix):
            n_total_use = int(A.n_dofs)
        else:
            n_total_use = int(getattr(A, "shape", [0])[0])
        u_full = expand_dirichlet_solution(u_free, free, dir_arr, dir_vals_arr, n_total=n_total_use)
        return u_full, info
88
+
89
+
90
class NonlinearSolver:
    """
    Backward-compatible Newton-based nonlinear solver.

    This is a thin wrapper around ``newton_solve`` kept for legacy code paths.
    Prefer ``NonlinearAnalysis`` + ``NewtonSolveRunner`` for new workflows.

    ``atol`` and ``linear_preconditioner`` are optional. They are forwarded
    to ``newton_solve`` only when given, so omitting them keeps
    ``newton_solve``'s own defaults.
    """

    def __init__(
        self,
        space: FESpace,
        res_form,
        params,
        *,
        tol: float = 1e-8,
        maxiter: int = 20,
        linear_method: str = "spsolve",
        line_search: bool = False,
        max_ls: int = 10,
        ls_c: float = 1e-4,
        linear_tol: float | None = None,
        dirichlet=None,
        external_vector=None,
        linear_maxiter: int | None = None,
        jacobian_pattern=None,
        atol: float | None = None,
        linear_preconditioner=None,
    ):
        self.space = space
        self.res_form = res_form
        self.params = params
        self.tol = tol
        self.maxiter = maxiter
        self.linear_method = linear_method
        self.line_search = line_search
        self.max_ls = max_ls
        self.ls_c = ls_c
        self.linear_tol = linear_tol
        self.dirichlet = dirichlet
        self.external_vector = external_vector
        self.linear_maxiter = linear_maxiter
        self.jacobian_pattern = jacobian_pattern
        self.atol = atol
        self.linear_preconditioner = linear_preconditioner

    def solve(self, u0):
        """Run ``newton_solve`` from initial guess ``u0``; returns its result unchanged."""
        # Only pass the optional controls when they were set explicitly.
        extra = {}
        if self.atol is not None:
            extra["atol"] = self.atol
        if self.linear_preconditioner is not None:
            extra["linear_preconditioner"] = self.linear_preconditioner
        return newton_solve(
            self.space,
            self.res_form,
            jnp.asarray(u0),
            self.params,
            tol=self.tol,
            maxiter=self.maxiter,
            linear_solver=self.linear_method,
            line_search=self.line_search,
            max_ls=self.max_ls,
            ls_c=self.ls_c,
            linear_tol=self.linear_tol,
            dirichlet=self.dirichlet,
            external_vector=self.external_vector,
            linear_maxiter=self.linear_maxiter,
            jacobian_pattern=self.jacobian_pattern,
            **extra,
        )