fluxfem 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

Files changed (45)
  1. fluxfem/__init__.py +68 -0
  2. fluxfem/core/__init__.py +115 -10
  3. fluxfem/core/assembly.py +676 -91
  4. fluxfem/core/basis.py +73 -52
  5. fluxfem/core/dtypes.py +9 -1
  6. fluxfem/core/forms.py +10 -0
  7. fluxfem/core/mixed_assembly.py +263 -0
  8. fluxfem/core/mixed_space.py +348 -0
  9. fluxfem/core/mixed_weakform.py +97 -0
  10. fluxfem/core/solver.py +2 -0
  11. fluxfem/core/space.py +262 -17
  12. fluxfem/core/weakform.py +768 -7
  13. fluxfem/helpers_wf.py +49 -0
  14. fluxfem/mesh/__init__.py +54 -2
  15. fluxfem/mesh/base.py +316 -7
  16. fluxfem/mesh/contact.py +825 -0
  17. fluxfem/mesh/dtypes.py +12 -0
  18. fluxfem/mesh/hex.py +17 -16
  19. fluxfem/mesh/io.py +6 -4
  20. fluxfem/mesh/mortar.py +3907 -0
  21. fluxfem/mesh/supermesh.py +316 -0
  22. fluxfem/mesh/surface.py +22 -4
  23. fluxfem/mesh/tet.py +10 -4
  24. fluxfem/physics/diffusion.py +3 -0
  25. fluxfem/physics/elasticity/hyperelastic.py +3 -0
  26. fluxfem/physics/elasticity/linear.py +9 -2
  27. fluxfem/solver/__init__.py +42 -2
  28. fluxfem/solver/bc.py +38 -2
  29. fluxfem/solver/block_matrix.py +132 -0
  30. fluxfem/solver/block_system.py +454 -0
  31. fluxfem/solver/cg.py +115 -33
  32. fluxfem/solver/dirichlet.py +334 -4
  33. fluxfem/solver/newton.py +237 -60
  34. fluxfem/solver/petsc.py +439 -0
  35. fluxfem/solver/preconditioner.py +106 -0
  36. fluxfem/solver/result.py +18 -0
  37. fluxfem/solver/solve_runner.py +168 -1
  38. fluxfem/solver/solver.py +12 -1
  39. fluxfem/solver/sparse.py +124 -9
  40. fluxfem-0.2.0.dist-info/METADATA +303 -0
  41. fluxfem-0.2.0.dist-info/RECORD +59 -0
  42. fluxfem-0.1.4.dist-info/METADATA +0 -127
  43. fluxfem-0.1.4.dist-info/RECORD +0 -48
  44. {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/LICENSE +0 -0
  45. {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/WHEEL +0 -0
fluxfem/solver/solve_runner.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import time
 from dataclasses import dataclass, field
+import warnings
 from typing import Any, Callable, Iterable, List, Sequence
 
 import numpy as np
@@ -10,6 +11,7 @@ import jax.numpy as jnp
 from ..core.assembly import assemble_bilinear_form
 from ..core.solver import spdirect_solve_cpu, spdirect_solve_gpu
 from .cg import cg_solve, cg_solve_jax
+from .petsc import petsc_shell_solve
 from .sparse import FluxSparseMatrix
 from .dirichlet import expand_dirichlet_solution
 from .newton import newton_solve
@@ -35,6 +37,8 @@ class NonlinearAnalysis:
         Unscaled external load vector (scaled by load factor in `external_for_load`).
     dirichlet : tuple | None
         (dofs, values) for Dirichlet boundary conditions.
+    extra_terms : list[callable] | None
+        Optional extra term assemblers returning (K, f[, metrics]).
     jacobian_pattern : Any | None
         Optional sparsity pattern to reuse between load steps.
     dtype : Any
@@ -46,6 +50,7 @@ class NonlinearAnalysis:
     params: Any
     base_external_vector: Any | None = None
     dirichlet: tuple | None = None
+    extra_terms: list[Callable] | None = None
    jacobian_pattern: Any | None = None
    dtype: Any = jnp.float64
 
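The new `extra_terms` hook threads user-supplied assemblers from `NonlinearAnalysis` through the runner into `newton_solve` (see the hunks below). Only the return contract `(K, f[, metrics])` is pinned down by the docstring above, so the following is a minimal sketch: the assembler name and its argument list are hypothetical, and fields of `NonlinearAnalysis` not shown in this diff are elided.

    # Hypothetical extra-term assembler: contributes (K, f) on top of the
    # base system; a third `metrics` dict is optional per the docstring.
    def contact_penalty_terms(u, params):
        K_extra = ...  # stiffness contribution on the global pattern
        f_extra = ...  # residual contribution
        return K_extra, f_extra

    analysis = NonlinearAnalysis(
        space=space,    # real field: the runner reads analysis.space.n_dofs
        params=params,
        extra_terms=[contact_penalty_terms],
    )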
@@ -71,6 +76,7 @@ class NewtonLoopConfig:
     linear_maxiter: int | None = None
     linear_tol: float | None = None
     linear_preconditioner: Any | None = None
+    matfree_mode: str = "linearize"
     load_sequence: Sequence[float] | None = None
     n_steps: int = 1
 
@@ -95,6 +101,7 @@ class NewtonSolveRunner:
     def __init__(self, analysis: NonlinearAnalysis, config: NewtonLoopConfig):
         self.analysis = analysis
         self.config = config
+        self._matfree_cache: dict[str, Any] = {}
 
     def run(
         self,
@@ -154,6 +161,16 @@ class NewtonSolveRunner:
             schedule.append(lf_clamped)
             prev = lf_clamped
         history: List[LoadStepResult] = []
+        matfree_cache = None
+        if self.config.linear_preconditioner == "diag0":
+            n_free = self.analysis.space.n_dofs
+            if self.analysis.dirichlet is not None:
+                n_free -= len(self.analysis.dirichlet[0])
+            cached_free = self._matfree_cache.get("n_free_dofs")
+            if cached_free is not None and cached_free != n_free:
+                self._matfree_cache.clear()
+            self._matfree_cache["n_free_dofs"] = n_free
+            matfree_cache = self._matfree_cache
         for step_i, load_factor in enumerate(schedule, start=1):
             with timer.section("step"):
                 external = self.analysis.external_for_load(load_factor)
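The block above activates a per-runner matrix-free cache only for the "diag0" preconditioner and clears it whenever the free-DOF count (keyed as "n_free_dofs") changes between calls, for example under a different Dirichlet set. A minimal configuration that exercises it, using only fields visible in this diff and assuming the remaining `NewtonLoopConfig` fields keep their defaults:

    config = NewtonLoopConfig(
        linear_preconditioner="diag0",  # enables the matrix-free cache above
        matfree_mode="linearize",       # default; forwarded to newton_solve
        n_steps=4,
    )
    runner = NewtonSolveRunner(analysis, config)
    # runner.run(...) then reuses runner._matfree_cache across load steps.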
@@ -203,6 +220,8 @@ class NewtonSolveRunner:
                        linear_maxiter=self.config.linear_maxiter,
                        linear_tol=self.config.linear_tol,
                        linear_preconditioner=self.config.linear_preconditioner,
+                        matfree_cache=matfree_cache,
+                        matfree_mode=self.config.matfree_mode,
                        dirichlet=self.analysis.dirichlet,
                        line_search=self.config.line_search,
                        max_ls=self.config.max_ls,
@@ -210,6 +229,7 @@ class NewtonSolveRunner:
                        external_vector=external,
                        callback=cb,
                        jacobian_pattern=self.analysis.jacobian_pattern,
+                        extra_terms=self.analysis.extra_terms,
                    )
                    exception = None
                except Exception as e:  # pragma: no cover - defensive
@@ -279,6 +299,7 @@ def solve_nonlinear(
     *,
     dirichlet: tuple | None = None,
     base_external_vector=None,
+    extra_terms=None,
     dtype=jnp.float64,
     maxiter: int = 20,
     tol: float = 1e-8,
@@ -287,6 +308,7 @@ def solve_nonlinear(
     linear_maxiter: int | None = None,
     linear_tol: float | None = None,
     linear_preconditioner=None,
+    matfree_mode: str = "linearize",
     line_search: bool = False,
     max_ls: int = 10,
     ls_c: float = 1e-4,
@@ -303,6 +325,7 @@ def solve_nonlinear(
         params=params,
         base_external_vector=base_external_vector,
         dirichlet=dirichlet,
+        extra_terms=extra_terms,
         dtype=dtype,
         jacobian_pattern=jacobian_pattern,
     )
@@ -314,6 +337,7 @@ def solve_nonlinear(
         linear_maxiter=linear_maxiter,
         linear_tol=linear_tol,
         linear_preconditioner=linear_preconditioner,
+        matfree_mode=matfree_mode,
         line_search=line_search,
         max_ls=max_ls,
         ls_c=ls_c,
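Taken together, `solve_nonlinear` now forwards `extra_terms` into the `NonlinearAnalysis` it builds and `matfree_mode` into the `NewtonLoopConfig`. A hedged call sketch; the positional arguments before the keyword-only `*` and the return shape are not shown in this diff:

    result = solve_nonlinear(
        residual, space, params,        # assumed positionals; not in these hunks
        dirichlet=(dir_dofs, dir_vals),
        extra_terms=[contact_penalty_terms],
        linear_preconditioner="diag0",
        matfree_mode="linearize",
    )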
@@ -359,10 +383,42 @@ class LinearSolveConfig:
     Control parameters for the linear solve with optional load scaling.
     """
 
-    method: str = "spsolve"  # "spsolve" | "spdirect_solve_gpu" | "cg" | "cg_custom"
+    method: str = "spsolve"  # "spsolve" | "spdirect_solve_gpu" | "cg" | "cg_custom" | "petsc_shell"
     tol: float = 1e-8
     maxiter: int | None = None
     preconditioner: Any | None = None
+    ksp_type: str | None = None
+    pc_type: str | None = None
+    ksp_rtol: float | None = None
+    ksp_atol: float | None = None
+    ksp_max_it: int | None = None
+    petsc_ksp_norm_type: str | None = None
+    petsc_ksp_monitor_true_residual: bool = False
+    petsc_ksp_converged_reason: bool = False
+    petsc_ksp_monitor_short: bool = False
+    petsc_shell_pmat: bool = False
+    petsc_shell_pmat_mode: str = "full"
+    petsc_shell_pmat_rebuild_iters: int | None = None
+    petsc_shell_fallback: bool = False
+    petsc_shell_fallback_ksp_types: tuple[str, ...] = ("bcgs", "gmres")
+    petsc_shell_fallback_rebuild_pmat: bool = True
+
+    @classmethod
+    def from_preset(cls, name: str) -> "LinearSolveConfig":
+        preset = name.lower()
+        if preset == "contact":
+            return cls(
+                method="petsc_shell",
+                ksp_type="bcgs",
+                pc_type="ilu",
+                petsc_shell_pmat=True,
+                petsc_shell_pmat_mode="full",
+                petsc_ksp_norm_type="unpreconditioned",
+                petsc_ksp_monitor_true_residual=True,
+                petsc_ksp_converged_reason=True,
+                petsc_shell_fallback=True,
+            )
+        raise ValueError(f"Unknown LinearSolveConfig preset: {name}")
 
 
 @dataclass
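"contact" is currently the only preset `from_preset` accepts; any other name raises `ValueError` rather than silently falling back to defaults:

    cfg = LinearSolveConfig.from_preset("contact")
    assert cfg.method == "petsc_shell"
    assert cfg.ksp_type == "bcgs" and cfg.pc_type == "ilu"

    # LinearSolveConfig.from_preset("direct")  # would raise ValueError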
@@ -392,6 +448,9 @@ class LinearSolveRunner:
     def __init__(self, analysis: LinearAnalysis, config: LinearSolveConfig):
         self.analysis = analysis
         self.config = config
+        self._petsc_shell_pmat = None
+        self._petsc_shell_last_iters = None
+        self._petsc_shell_pmat_rebuilds = 0
 
     def run(
         self,
@@ -488,6 +547,114 @@ class LinearSolveRunner:
                 stop_reason=("converged" if lin_conv else "linfail"),
                 nan_detected=bool(np.isnan(lin_res)) if lin_res is not None else False,
             )
+        elif self.config.method == "petsc_shell":
+            base_ksp_type = self.config.ksp_type or "gmres"
+            pc_type = self.config.pc_type if self.config.pc_type is not None else "none"
+            ksp_rtol = self.config.ksp_rtol if self.config.ksp_rtol is not None else self.config.tol
+            ksp_atol = self.config.ksp_atol
+            ksp_max_it = self.config.ksp_max_it if self.config.ksp_max_it is not None else self.config.maxiter
+            petsc_options = {}
+            if self.config.petsc_ksp_norm_type:
+                petsc_options["fluxfem_ksp_norm_type"] = self.config.petsc_ksp_norm_type
+            if self.config.petsc_ksp_monitor_true_residual:
+                petsc_options["fluxfem_ksp_monitor_true_residual"] = ""
+            if self.config.petsc_ksp_converged_reason:
+                petsc_options["fluxfem_ksp_converged_reason"] = ""
+            if self.config.petsc_ksp_monitor_short:
+                petsc_options["fluxfem_ksp_monitor_short"] = ""
+            if not petsc_options:
+                petsc_options = None
+            use_pmat = bool(self.config.petsc_shell_pmat)
+            rebuild_thresh = self.config.petsc_shell_pmat_rebuild_iters
+            if use_pmat:
+                pmat_mode = (self.config.petsc_shell_pmat_mode or "full").lower()
+                if pmat_mode == "none":
+                    use_pmat = False
+                    pmat = None
+                elif pmat_mode == "full":
+                    pmat = K_ff
+                else:
+                    warnings.warn(
+                        f"petsc_shell_pmat_mode='{pmat_mode}' is not supported in runner; "
+                        "falling back to 'full'.",
+                        RuntimeWarning,
+                    )
+                    pmat = K_ff
+                if use_pmat:
+                    if self._petsc_shell_pmat is None:
+                        self._petsc_shell_pmat = pmat
+                        self._petsc_shell_pmat_rebuilds += 1
+                    elif rebuild_thresh is not None and self._petsc_shell_last_iters is not None:
+                        if self._petsc_shell_last_iters > rebuild_thresh:
+                            self._petsc_shell_pmat = pmat
+                            self._petsc_shell_pmat_rebuilds += 1
+                    pmat = self._petsc_shell_pmat
+            if not use_pmat:
+                pmat = None
+
+            def _attempt_solve(ksp_type: str):
+                return petsc_shell_solve(
+                    K_ff,
+                    F_free,
+                    preconditioner=self.config.preconditioner,
+                    ksp_type=ksp_type,
+                    pc_type=pc_type,
+                    rtol=ksp_rtol,
+                    atol=ksp_atol,
+                    max_it=ksp_max_it,
+                    pmat=pmat,
+                    options=petsc_options,
+                    return_info=True,
+                )
+
+            fallback_ksp = [base_ksp_type]
+            if self.config.petsc_shell_fallback:
+                for ksp in self.config.petsc_shell_fallback_ksp_types:
+                    if ksp not in fallback_ksp:
+                        fallback_ksp.append(ksp)
+            fallback_attempts = []
+            petsc_info = None
+            u_free = None
+            for ksp in fallback_ksp:
+                fallback_attempts.append(ksp)
+                u_free, petsc_info = _attempt_solve(ksp)
+                lin_conv = petsc_info.get("converged")
+                reason = petsc_info.get("reason")
+                if lin_conv is None and reason is not None:
+                    lin_conv = reason > 0
+                if lin_conv:
+                    break
+                if self.config.petsc_shell_fallback and use_pmat and self.config.petsc_shell_fallback_rebuild_pmat:
+                    self._petsc_shell_pmat = pmat
+                    self._petsc_shell_pmat_rebuilds += 1
+            lin_iters = petsc_info.get("iters")
+            lin_res = petsc_info.get("residual_norm")
+            lin_solve_dt = petsc_info.get("solve_time")
+            pc_setup_dt = petsc_info.get("pc_setup_time")
+            pmat_dt = petsc_info.get("pmat_build_time")
+            lin_conv = petsc_info.get("converged")
+            if lin_conv is None and petsc_info.get("reason") is not None:
+                lin_conv = petsc_info.get("reason") > 0
+            if lin_conv is None:
+                lin_conv = True
+            self._petsc_shell_last_iters = lin_iters
+            info = SolverResult(
+                converged=bool(lin_conv),
+                iters=int(lin_iters) if lin_iters is not None else 0,
+                linear_iters=int(lin_iters) if lin_iters is not None else None,
+                linear_converged=bool(lin_conv),
+                linear_residual=float(lin_res) if lin_res is not None else None,
+                linear_solve_time=float(lin_solve_dt) if lin_solve_dt is not None else None,
+                pc_setup_time=float(pc_setup_dt) if pc_setup_dt is not None else None,
+                pmat_build_time=float(pmat_dt) if pmat_dt is not None else None,
+                pmat_rebuilds=self._petsc_shell_pmat_rebuilds if use_pmat else None,
+                pmat_mode=self.config.petsc_shell_pmat_mode if use_pmat else None,
+                tol=self.config.tol,
+                stop_reason=("converged" if lin_conv else "linfail"),
+                nan_detected=bool(np.isnan(lin_res)) if lin_res is not None else False,
+            )
+            if len(fallback_attempts) > 1:
+                info.linear_fallbacks = fallback_attempts
         else:
             raise ValueError(f"Unknown linear solve method: {self.config.method}")
 
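For finer control than the preset, the same branch can be driven field by field. The values below are illustrative: this caches the assembled preconditioning matrix, counts a rebuild when the previous solve exceeded 200 Krylov iterations, and on non-convergence retries with the fallback KSP types ("bcgs", then "gmres"):

    cfg = LinearSolveConfig(
        method="petsc_shell",
        ksp_type="gmres",                    # base KSP for the first attempt
        pc_type="ilu",
        petsc_shell_pmat=True,               # build an explicit pmat for the PC
        petsc_shell_pmat_rebuild_iters=200,  # rebuild after a slow solve
        petsc_shell_fallback=True,
        petsc_shell_fallback_rebuild_pmat=True,
    )
    runner = LinearSolveRunner(analysis, cfg)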
fluxfem/solver/solver.py CHANGED
@@ -5,8 +5,10 @@ import jax.numpy as jnp
 
 from .cg import cg_solve, cg_solve_jax
 from .newton import newton_solve
+from .petsc import petsc_solve, petsc_shell_solve
 from ..core.solver import spdirect_solve_cpu, spdirect_solve_gpu
 from .dirichlet import (
+    DirichletBC,
     condense_dirichlet_dense,
     condense_dirichlet_fluxsparse,
     expand_dirichlet_solution,
@@ -46,6 +48,12 @@ class LinearSolver:
         elif self.method == "spdirect_solve_gpu":
             x = spdirect_solve_gpu(A, b)
             return np.asarray(x), {"iters": 1, "converged": True}
+        elif self.method == "petsc":
+            x = petsc_solve(A, b)
+            return np.asarray(x), {"iters": None, "converged": True}
+        elif self.method == "petsc_shell":
+            x = petsc_shell_solve(A, b)
+            return np.asarray(x), {"iters": None, "converged": True}
         else:
             raise ValueError(f"Unknown linear method: {self.method}")
 
@@ -64,7 +72,10 @@ class LinearSolver:
         if dirichlet_mode not in ("condense", "enforce"):
             raise ValueError("dirichlet_mode must be 'condense' or 'enforce'.")
 
-        dir_dofs, dir_vals = dirichlet
+        if isinstance(dirichlet, DirichletBC):
+            dir_dofs, dir_vals = dirichlet.as_tuple()
+        else:
+            dir_dofs, dir_vals = dirichlet
         if dirichlet_mode == "enforce":
             if isinstance(A, FluxSparseMatrix):
                 A_bc, b_bc = enforce_dirichlet_sparse(A, b, dir_dofs, dir_vals)
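`LinearSolver` now accepts a `DirichletBC` object anywhere the legacy `(dofs, values)` tuple was expected; only its `as_tuple()` method is relied on here. A sketch, assuming a `DirichletBC(dofs, values)` constructor and a `solve(A, b, dirichlet=...)` entry point, neither of which is shown in this diff:

    import numpy as np

    dofs = np.array([0, 1, 2])
    vals = np.zeros(3)

    # Both forms reach the same condense/enforce path:
    solver.solve(A, b, dirichlet=(dofs, vals))             # legacy tuple
    solver.solve(A, b, dirichlet=DirichletBC(dofs, vals))  # assumed constructor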
fluxfem/solver/sparse.py CHANGED
@@ -12,6 +12,79 @@ except Exception:  # pragma: no cover
     sp = None
 
 
+def coalesce_coo(rows, cols, data):
+    """
+    Sum duplicate COO entries by sorting (CPU-friendly).
+    Returns (rows_u, cols_u, data_u) as NumPy arrays.
+    """
+    r = np.asarray(rows, dtype=np.int64)
+    c = np.asarray(cols, dtype=np.int64)
+    d = np.asarray(data)
+    if r.size == 0:
+        return r, c, d
+    order = np.lexsort((c, r))
+    r_s = r[order]
+    c_s = c[order]
+    d_s = d[order]
+    new_group = np.ones(r_s.size, dtype=bool)
+    new_group[1:] = (r_s[1:] != r_s[:-1]) | (c_s[1:] != c_s[:-1])
+    starts = np.nonzero(new_group)[0]
+    r_u = r_s[starts]
+    c_u = c_s[starts]
+    d_u = np.add.reduceat(d_s, starts)
+    return r_u, c_u, d_u
+
+
+def _normalize_flux_mats(mats):
+    if len(mats) == 1 and isinstance(mats[0], (list, tuple)):
+        mats = tuple(mats[0])
+    if not mats:
+        raise ValueError("At least one FluxSparseMatrix is required.")
+    return mats
+
+
+def concat_flux(*mats, n_dofs: int | None = None):
+    """
+    Concatenate COO entries from multiple FluxSparseMatrix objects.
+    All matrices must share the same n_dofs unless n_dofs is provided.
+    """
+    mats = _normalize_flux_mats(mats)
+    if n_dofs is None:
+        n_dofs = int(mats[0].n_dofs)
+    for mat in mats[1:]:
+        if int(mat.n_dofs) != n_dofs:
+            raise ValueError("All matrices must share n_dofs for concat_flux.")
+    rows_list = [np.asarray(mat.pattern.rows, dtype=np.int32) for mat in mats]
+    cols_list = [np.asarray(mat.pattern.cols, dtype=np.int32) for mat in mats]
+    data_list = [np.asarray(mat.data) for mat in mats]
+    rows = np.concatenate(rows_list) if rows_list else np.asarray([], dtype=np.int32)
+    cols = np.concatenate(cols_list) if cols_list else np.asarray([], dtype=np.int32)
+    data = np.concatenate(data_list) if data_list else np.asarray([], dtype=float)
+    return FluxSparseMatrix(rows, cols, data, int(n_dofs))
+
+
+def block_diag_flux(*mats):
+    """Block-diagonal concatenation for FluxSparseMatrix objects."""
+    mats = _normalize_flux_mats(mats)
+    rows_out = []
+    cols_out = []
+    data_out = []
+    offset = 0
+    for mat in mats:
+        rows = np.asarray(mat.pattern.rows, dtype=np.int32)
+        cols = np.asarray(mat.pattern.cols, dtype=np.int32)
+        data = np.asarray(mat.data)
+        if rows.size:
+            rows_out.append(rows + offset)
+            cols_out.append(cols + offset)
+            data_out.append(data)
+        offset += int(mat.n_dofs)
+    rows = np.concatenate(rows_out) if rows_out else np.asarray([], dtype=np.int32)
+    cols = np.concatenate(cols_out) if cols_out else np.asarray([], dtype=np.int32)
+    data = np.concatenate(data_out) if data_out else np.asarray([], dtype=float)
+    return FluxSparseMatrix(rows, cols, data, int(offset))
+
+
 @jax.tree_util.register_pytree_node_class
 @dataclass(frozen=True)
 class SparsityPattern:
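These helpers compose: `concat_flux` stacks COO entries on a shared DOF space (duplicates are allowed and sum on later coalescing via `coalesce_coo`, which the `coalesce()` method added further down wraps), and `block_diag_flux` shifts each matrix onto its own DOF block. A small worked example using only names defined in this diff:

    import numpy as np
    from fluxfem.solver.sparse import FluxSparseMatrix, block_diag_flux, concat_flux

    rows = np.array([0, 0, 1])
    cols = np.array([0, 0, 1])
    data = np.array([1.0, 2.0, 3.0])
    A = FluxSparseMatrix(rows, cols, data, n_dofs=2)  # note the duplicate (0, 0)
    B = FluxSparseMatrix(rows, cols, data, n_dofs=2)

    C = concat_flux(A, B)      # 6 stored entries on the same 2-dof space
    C = C.coalesce()           # summed: (0, 0) -> 6.0, (1, 1) -> 6.0
    D = block_diag_flux(A, B)  # 4 dofs; B's entries offset by A.n_dofs == 2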
@@ -80,7 +153,7 @@ class FluxSparseMatrix:
     - data stores the numeric values for the current nonlinear iterate
     """
 
-    def __init__(self, rows_or_pattern, cols=None, data=None, n_dofs: int | None = None):
+    def __init__(self, rows_or_pattern, cols=None, data=None, n_dofs: int | None = None, meta: dict | None = None):
         # New signature: FluxSparseMatrix(pattern, data)
         if isinstance(rows_or_pattern, SparsityPattern):
             pattern = rows_or_pattern
@@ -88,15 +161,22 @@ class FluxSparseMatrix:
             values = jnp.asarray(values)
         else:
             # Legacy signature: FluxSparseMatrix(rows, cols, data, n_dofs)
-            r_np = np.asarray(rows_or_pattern, dtype=np.int32)
-            c_np = np.asarray(cols, dtype=np.int32)
-            diag_idx_np = np.nonzero(r_np == c_np)[0].astype(np.int32)
+            r_j = jnp.asarray(rows_or_pattern, dtype=jnp.int32)
+            c_j = jnp.asarray(cols, dtype=jnp.int32)
+            is_tracer = isinstance(rows_or_pattern, jax.core.Tracer) or isinstance(cols, jax.core.Tracer)
+            diag_idx_j = None
+            if not is_tracer:
+                diag_idx_j = jnp.nonzero(r_j == c_j)[0].astype(jnp.int32)
+            if n_dofs is None:
+                if is_tracer:
+                    raise ValueError("n_dofs must be provided when constructing FluxSparseMatrix under JIT.")
+                n_dofs = int(np.asarray(cols).max()) + 1
             pattern = SparsityPattern(
-                rows=jnp.asarray(r_np),
-                cols=jnp.asarray(c_np),
-                n_dofs=int(n_dofs) if n_dofs is not None else int(c_np.max()) + 1,
+                rows=r_j,
+                cols=c_j,
+                n_dofs=int(n_dofs) if n_dofs is not None else int(np.asarray(cols).max()) + 1,
                 idx=None,
-                diag_idx=jnp.asarray(diag_idx_np),
+                diag_idx=diag_idx_j,
             )
             values = jnp.asarray(data)
 
@@ -105,6 +185,7 @@ class FluxSparseMatrix:
         self.cols = pattern.cols
         self.n_dofs = int(pattern.n_dofs)
         self.data = values
+        self.meta = dict(meta) if meta is not None else None
 
     @classmethod
     def from_bilinear(cls, coo_tuple):
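The rewritten legacy branch matters under `jax.jit`: when `rows`/`cols` arrive as tracers their values cannot be inspected, so the eager `diag_idx` computation is skipped and `n_dofs` can no longer be inferred from `cols.max()`; it must be passed explicitly or the constructor raises. A minimal sketch:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def assemble(rows, cols, data):
        # rows/cols are tracers here: diag_idx is skipped and n_dofs is
        # mandatory, since cols.max() cannot be read at trace time.
        A = FluxSparseMatrix(rows, cols, data, n_dofs=4)
        return A.data.sum()  # placeholder use; returning A needs pytree support

    assemble(jnp.array([0, 1]), jnp.array([0, 1]), jnp.array([1.0, 2.0]))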
@@ -121,11 +202,25 @@ class FluxSparseMatrix:
 
     def with_data(self, data):
         """Return a new FluxSparseMatrix sharing the same pattern with updated data."""
-        return FluxSparseMatrix(self.pattern, data)
+        return FluxSparseMatrix(self.pattern, data, meta=self.meta)
+
+    def add_dense(self, dense):
+        """Return a new FluxSparseMatrix with dense entries added on the pattern."""
+        dense_vals = jnp.asarray(dense)[self.pattern.rows, self.pattern.cols]
+        return FluxSparseMatrix(self.pattern, self.data + dense_vals)
 
     def to_coo(self):
         return self.pattern.rows, self.pattern.cols, self.data, self.pattern.n_dofs
 
+    @property
+    def nnz(self) -> int:
+        return int(self.data.shape[0])
+
+    def coalesce(self):
+        """Return a new FluxSparseMatrix with duplicate entries summed."""
+        rows_u, cols_u, data_u = coalesce_coo(self.pattern.rows, self.pattern.cols, self.data)
+        return FluxSparseMatrix(rows_u, cols_u, data_u, self.pattern.n_dofs)
+
     def to_csr(self):
         if sp is None:
             raise ImportError("scipy is required for to_csr()")
@@ -167,6 +262,26 @@ class FluxSparseMatrix:
         out = jnp.zeros(self.pattern.n_dofs, dtype=contrib.dtype)
         return out.at[self.pattern.rows].add(contrib)
 
+    def as_cg_operator(
+        self,
+        *,
+        matvec: str = "flux",
+        preconditioner=None,
+        solver: str = "cg",
+        dof_per_node: int | None = None,
+        block_sizes=None,
+    ):
+        from .cg import build_cg_operator
+
+        return build_cg_operator(
+            self,
+            matvec=matvec,
+            preconditioner=preconditioner,
+            solver=solver,
+            dof_per_node=dof_per_node,
+            block_sizes=block_sizes,
+        )
+
     def diag(self):
         """Diagonal entries aggregated for Jacobi preconditioning."""
         if self.pattern.diag_idx is not None:
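`as_cg_operator` is a thin forwarding wrapper; the lazy `from .cg import build_cg_operator` inside the method presumably avoids a circular import between `sparse.py` and `cg.py`. Valid values for `matvec`, `solver`, and `preconditioner` beyond the defaults in the signature are not documented in this diff, so the call below is a hedged sketch:

    op = A.as_cg_operator(dof_per_node=3)  # dof_per_node value illustrative
    # Equivalent to:
    # build_cg_operator(A, matvec="flux", solver="cg", preconditioner=None,
    #                   dof_per_node=3, block_sizes=None)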