fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. fluxfem/__init__.py +69 -13
  2. fluxfem/core/__init__.py +140 -53
  3. fluxfem/core/assembly.py +691 -97
  4. fluxfem/core/basis.py +75 -54
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/dtypes.py +9 -1
  7. fluxfem/core/forms.py +10 -0
  8. fluxfem/core/mixed_assembly.py +263 -0
  9. fluxfem/core/mixed_space.py +382 -0
  10. fluxfem/core/mixed_weakform.py +97 -0
  11. fluxfem/core/solver.py +2 -0
  12. fluxfem/core/space.py +315 -30
  13. fluxfem/core/weakform.py +821 -42
  14. fluxfem/helpers_wf.py +49 -0
  15. fluxfem/mesh/__init__.py +54 -2
  16. fluxfem/mesh/base.py +318 -9
  17. fluxfem/mesh/contact.py +841 -0
  18. fluxfem/mesh/dtypes.py +12 -0
  19. fluxfem/mesh/hex.py +17 -16
  20. fluxfem/mesh/io.py +9 -6
  21. fluxfem/mesh/mortar.py +3970 -0
  22. fluxfem/mesh/supermesh.py +318 -0
  23. fluxfem/mesh/surface.py +104 -26
  24. fluxfem/mesh/tet.py +16 -7
  25. fluxfem/physics/diffusion.py +3 -0
  26. fluxfem/physics/elasticity/hyperelastic.py +35 -3
  27. fluxfem/physics/elasticity/linear.py +22 -4
  28. fluxfem/physics/elasticity/stress.py +9 -5
  29. fluxfem/physics/operators.py +12 -5
  30. fluxfem/physics/postprocess.py +29 -3
  31. fluxfem/solver/__init__.py +47 -2
  32. fluxfem/solver/bc.py +38 -2
  33. fluxfem/solver/block_matrix.py +284 -0
  34. fluxfem/solver/block_system.py +477 -0
  35. fluxfem/solver/cg.py +150 -55
  36. fluxfem/solver/dirichlet.py +358 -5
  37. fluxfem/solver/history.py +15 -3
  38. fluxfem/solver/newton.py +260 -70
  39. fluxfem/solver/petsc.py +445 -0
  40. fluxfem/solver/preconditioner.py +109 -0
  41. fluxfem/solver/result.py +18 -0
  42. fluxfem/solver/solve_runner.py +208 -23
  43. fluxfem/solver/solver.py +35 -12
  44. fluxfem/solver/sparse.py +149 -15
  45. fluxfem/tools/jit.py +19 -7
  46. fluxfem/tools/timer.py +14 -12
  47. fluxfem/tools/visualizer.py +16 -4
  48. fluxfem-0.2.1.dist-info/METADATA +314 -0
  49. fluxfem-0.2.1.dist-info/RECORD +59 -0
  50. fluxfem-0.1.4.dist-info/METADATA +0 -127
  51. fluxfem-0.1.4.dist-info/RECORD +0 -48
  52. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  53. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/core/basis.py CHANGED
@@ -4,8 +4,7 @@ from typing import Protocol
4
4
  import jax
5
5
  import jax.numpy as jnp
6
6
  import numpy as np
7
- from .dtypes import DEFAULT_DTYPE as _FDTYPE
8
- # from .dtypes import DEFAULT_DTYPE
7
+ from .dtypes import default_dtype
9
8
 
10
9
 
11
10
  def build_B_matrices(dN_dx: jnp.ndarray) -> jnp.ndarray:
@@ -271,6 +270,7 @@ class TetLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
271
270
 
272
271
  @property
273
272
  def ref_node_coords(self) -> jnp.ndarray:
273
+ dtype = default_dtype()
274
274
  return jnp.array(
275
275
  [
276
276
  [0.0, 0.0, 0.0],
@@ -278,7 +278,7 @@ class TetLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
278
278
  [0.0, 1.0, 0.0],
279
279
  [0.0, 0.0, 1.0],
280
280
  ],
281
- dtype=_FDTYPE,
281
+ dtype=dtype,
282
282
  )
283
283
 
284
284
  def shape_functions(self) -> jnp.ndarray:
@@ -294,6 +294,7 @@ class TetLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
294
294
 
295
295
  def shape_grads_ref(self) -> jnp.ndarray:
296
296
  # constant gradients in reference tetra
297
+ dtype = default_dtype()
297
298
  dN = jnp.array(
298
299
  [
299
300
  [-1.0, -1.0, -1.0],
@@ -301,7 +302,7 @@ class TetLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
301
302
  [0.0, 1.0, 0.0],
302
303
  [0.0, 0.0, 1.0],
303
304
  ],
304
- dtype=_FDTYPE,
305
+ dtype=dtype,
305
306
  )
306
307
  dN = jnp.tile(dN[None, :, :], (self.n_q, 1, 1)) # (n_q,4,3)
307
308
  return dN
@@ -374,7 +375,7 @@ class TetQuadraticBasis10(SmallStrainBMixin, TotalLagrangeBMixin):
374
375
 
375
376
  grads = []
376
377
  for a, dLa in zip([L1, L2, L3, L4], [dL1, dL2, dL3, dL4]):
377
- grads.append((2 * a - 1)[..., None] * dLa[None, :])
378
+ grads.append((4 * a - 1)[..., None] * dLa[None, :])
378
379
 
379
380
  dN1 = grads[0]
380
381
  dN2 = grads[1]
@@ -420,7 +421,7 @@ class HexTriLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
420
421
 
421
422
  def tree_flatten(self):
422
423
  children = (self.quad_points, self.quad_weights)
423
- aux_data = {}
424
+ aux_data: dict[str, object] = {}
424
425
  return children, aux_data
425
426
 
426
427
  @classmethod
@@ -446,6 +447,7 @@ class HexTriLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
446
447
  6: ( 1, 1, 1)
447
448
  7: (-1, 1, 1)
448
449
  """
450
+ dtype = default_dtype()
449
451
  return jnp.array(
450
452
  [
451
453
  [-1.0, -1.0, -1.0],
@@ -457,7 +459,7 @@ class HexTriLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
457
459
  [ 1.0, 1.0, 1.0],
458
460
  [-1.0, 1.0, 1.0],
459
461
  ],
460
- dtype=_FDTYPE,
462
+ dtype=dtype,
461
463
  )
462
464
 
463
465
  # ---------- reference shape functions & gradients ----------
@@ -567,7 +569,7 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
567
569
 
568
570
  def tree_flatten(self):
569
571
  children = (self.quad_points, self.quad_weights)
570
- aux_data = {}
572
+ aux_data: dict[str, object] = {}
571
573
  return children, aux_data
572
574
 
573
575
  @classmethod
@@ -581,6 +583,7 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
581
583
 
582
584
  @property
583
585
  def ref_node_coords(self) -> jnp.ndarray:
586
+ dtype = default_dtype()
584
587
  corners = jnp.array(
585
588
  [
586
589
  [-1.0, -1.0, -1.0],
@@ -592,7 +595,7 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
592
595
  [ 1.0, 1.0, 1.0],
593
596
  [-1.0, 1.0, 1.0],
594
597
  ],
595
- dtype=_FDTYPE,
598
+ dtype=dtype,
596
599
  )
597
600
  edges = jnp.array(
598
601
  [
@@ -609,7 +612,7 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
609
612
  [ 1.0, 1.0, 0.0], # 2-6
610
613
  [-1.0, 1.0, 0.0], # 3-7
611
614
  ],
612
- dtype=_FDTYPE,
615
+ dtype=dtype,
613
616
  )
614
617
  return jnp.concatenate([corners, edges], axis=0) # (20,3)
615
618
 
@@ -620,6 +623,7 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
620
623
  zeta = qp[:, 2:3]
621
624
 
622
625
  # corners
626
+ dtype = default_dtype()
623
627
  s = jnp.array(
624
628
  [
625
629
  [-1, -1, -1],
@@ -631,7 +635,7 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
631
635
  [ 1, 1, 1],
632
636
  [-1, 1, 1],
633
637
  ],
634
- dtype=_FDTYPE,
638
+ dtype=dtype,
635
639
  )
636
640
  sx = s[:, 0]
637
641
  sy = s[:, 1]
@@ -639,22 +643,30 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
639
643
  term = xi * sx + eta * sy + zeta * sz - 2.0 # (n_q,8)
640
644
  N_corner = 0.125 * (1 + sx * xi) * (1 + sy * eta) * (1 + sz * zeta) * term # (n_q,8)
641
645
 
642
- # edges
643
- edges_x = [(-1, -1), (1, -1), (1, 1), (-1, 1)] # eta, zeta fixed
644
- edges_y = [(-1, -1), (1, -1), (1, 1), (-1, 1)] # xi fixed
645
- edges_z = [(-1, -1), (1, -1), (1, 1), (-1, 1)] # xi, eta fixed
646
-
647
- N_edges = []
648
- # along xi (1 - xi^2)
649
- for sy, sz in edges_x:
650
- N_edges.append(0.25 * (1 - xi * xi) * (1 + sy * eta) * (1 + sz * zeta))
651
- # along eta
652
- for sx, sz in edges_y:
653
- N_edges.append(0.25 * (1 - eta * eta) * (1 + sx * xi) * (1 + sz * zeta))
654
- # along zeta
655
- for sx, sy in edges_z:
656
- N_edges.append(0.25 * (1 - zeta * zeta) * (1 + sx * xi) * (1 + sy * eta))
657
-
646
+ # edges: order matches hex20 connectivity (e01, e12, e23, e30, e45, e56, e67, e74, e04, e15, e26, e37)
647
+ def edge_x(sy, sz):
648
+ return 0.25 * (1 - xi * xi) * (1 + sy * eta) * (1 + sz * zeta)
649
+
650
+ def edge_y(sx, sz):
651
+ return 0.25 * (1 - eta * eta) * (1 + sx * xi) * (1 + sz * zeta)
652
+
653
+ def edge_z(sx, sy):
654
+ return 0.25 * (1 - zeta * zeta) * (1 + sx * xi) * (1 + sy * eta)
655
+
656
+ N_edges = [
657
+ edge_x(-1, -1),
658
+ edge_y(1, -1),
659
+ edge_x(1, -1),
660
+ edge_y(-1, -1),
661
+ edge_x(-1, 1),
662
+ edge_y(1, 1),
663
+ edge_x(1, 1),
664
+ edge_y(-1, 1),
665
+ edge_z(-1, -1),
666
+ edge_z(1, -1),
667
+ edge_z(1, 1),
668
+ edge_z(-1, 1),
669
+ ]
658
670
  N_edges = jnp.concatenate(N_edges, axis=1) # (n_q, 12)
659
671
  return jnp.concatenate([N_corner, N_edges], axis=1) # (n_q,20)
660
672
 
@@ -664,6 +676,7 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
664
676
  eta = qp[:, 1:2]
665
677
  zeta = qp[:, 2:3]
666
678
 
679
+ dtype = default_dtype()
667
680
  s = jnp.array(
668
681
  [
669
682
  [-1, -1, -1],
@@ -675,7 +688,7 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
675
688
  [ 1, 1, 1],
676
689
  [-1, 1, 1],
677
690
  ],
678
- dtype=_FDTYPE,
691
+ dtype=dtype,
679
692
  )
680
693
  sx = s[:, 0]
681
694
  sy = s[:, 1]
@@ -696,34 +709,38 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
696
709
  ) # (n_q,8,3)
697
710
 
698
711
  # edges derivatives
699
- d_list = []
700
- # along xi
701
- edges_x = [(-1, -1), (1, -1), (1, 1), (-1, 1)]
702
- for sy_val, sz_val in edges_x:
703
- sy_ = sy_val
704
- sz_ = sz_val
712
+ def d_edge_x(sy_, sz_):
705
713
  dxi = -0.5 * xi * (1 + sy_ * eta) * (1 + sz_ * zeta)
706
714
  deta = 0.25 * (1 - xi * xi) * sy_ * (1 + sz_ * zeta)
707
715
  dzeta = 0.25 * (1 - xi * xi) * (1 + sy_ * eta) * sz_
708
- d_list.append(jnp.stack([dxi, deta, dzeta], axis=2))
709
- # along eta
710
- edges_y = [(-1, -1), (1, -1), (1, 1), (-1, 1)]
711
- for sx_val, sz_val in edges_y:
712
- sx_ = sx_val
713
- sz_ = sz_val
716
+ return jnp.stack([dxi, deta, dzeta], axis=2)
717
+
718
+ def d_edge_y(sx_, sz_):
714
719
  dxi = 0.25 * (1 - eta * eta) * sx_ * (1 + sz_ * zeta)
715
720
  deta = -0.5 * eta * (1 + sx_ * xi) * (1 + sz_ * zeta)
716
721
  dzeta = 0.25 * (1 - eta * eta) * (1 + sx_ * xi) * sz_
717
- d_list.append(jnp.stack([dxi, deta, dzeta], axis=2))
718
- # along zeta
719
- edges_z = [(-1, -1), (1, -1), (1, 1), (-1, 1)]
720
- for sx_val, sy_val in edges_z:
721
- sx_ = sx_val
722
- sy_ = sy_val
722
+ return jnp.stack([dxi, deta, dzeta], axis=2)
723
+
724
+ def d_edge_z(sx_, sy_):
723
725
  dxi = 0.25 * (1 - zeta * zeta) * sx_ * (1 + sy_ * eta)
724
726
  deta = 0.25 * (1 - zeta * zeta) * (1 + sx_ * xi) * sy_
725
727
  dzeta = -0.5 * zeta * (1 + sx_ * xi) * (1 + sy_ * eta)
726
- d_list.append(jnp.stack([dxi, deta, dzeta], axis=2))
728
+ return jnp.stack([dxi, deta, dzeta], axis=2)
729
+
730
+ d_list = [
731
+ d_edge_x(-1, -1),
732
+ d_edge_y(1, -1),
733
+ d_edge_x(1, -1),
734
+ d_edge_y(-1, -1),
735
+ d_edge_x(-1, 1),
736
+ d_edge_y(1, 1),
737
+ d_edge_x(1, 1),
738
+ d_edge_y(-1, 1),
739
+ d_edge_z(-1, -1),
740
+ d_edge_z(1, -1),
741
+ d_edge_z(1, 1),
742
+ d_edge_z(-1, 1),
743
+ ]
727
744
 
728
745
  d_edges = jnp.concatenate(d_list, axis=1) # (n_q,12,3)
729
746
  return jnp.concatenate([d_corner, d_edges], axis=1) # (n_q,20,3)
@@ -846,7 +863,8 @@ def _gauss_legendre_1d(order: int) -> tuple[jnp.ndarray, jnp.ndarray]:
846
863
  if order <= 0:
847
864
  raise ValueError("quadrature order must be positive")
848
865
  pts, wts = np.polynomial.legendre.leggauss(order)
849
- return jnp.array(pts, dtype=_FDTYPE), jnp.array(wts, dtype=_FDTYPE)
866
+ dtype = default_dtype()
867
+ return jnp.array(pts, dtype=dtype), jnp.array(wts, dtype=dtype)
850
868
 
851
869
 
852
870
  def _gl_points_for_degree(degree: int) -> int:
@@ -865,10 +883,12 @@ def _tet_quadrature(degree: int) -> tuple[jnp.ndarray, jnp.ndarray]:
865
883
  degree<=1: 1-point; degree<=2: 4-point; degree>=3: 5-point (Stroud T3-5).
866
884
  """
867
885
  if degree <= 1:
868
- qp = jnp.array([[0.25, 0.25, 0.25]], dtype=_FDTYPE)
869
- qw = jnp.array([1.0 / 6.0], dtype=_FDTYPE)
886
+ dtype = default_dtype()
887
+ qp = jnp.array([[0.25, 0.25, 0.25]], dtype=dtype)
888
+ qw = jnp.array([1.0 / 6.0], dtype=dtype)
870
889
  return qp, qw
871
890
  if degree <= 2:
891
+ dtype = default_dtype()
872
892
  qp = jnp.array(
873
893
  [
874
894
  [0.58541020, 0.13819660, 0.13819660],
@@ -876,11 +896,12 @@ def _tet_quadrature(degree: int) -> tuple[jnp.ndarray, jnp.ndarray]:
876
896
  [0.13819660, 0.13819660, 0.58541020],
877
897
  [0.13819660, 0.13819660, 0.13819660],
878
898
  ],
879
- dtype=_FDTYPE,
899
+ dtype=dtype,
880
900
  )
881
- qw = jnp.full((4,), (1.0 / 24.0), dtype=_FDTYPE)
901
+ qw = jnp.full((4,), (1.0 / 24.0), dtype=dtype)
882
902
  return qp, qw
883
903
  # degree 3 rule: centroid + 4 symmetric points
904
+ dtype = default_dtype()
884
905
  qp = jnp.array(
885
906
  [
886
907
  [0.25, 0.25, 0.25],
@@ -889,11 +910,11 @@ def _tet_quadrature(degree: int) -> tuple[jnp.ndarray, jnp.ndarray]:
889
910
  [1.0 / 6.0, 1.0 / 6.0, 0.50],
890
911
  [1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0],
891
912
  ],
892
- dtype=_FDTYPE,
913
+ dtype=dtype,
893
914
  )
894
915
  qw = jnp.array(
895
916
  [-2.0 / 15.0, 3.0 / 40.0, 3.0 / 40.0, 3.0 / 40.0, 3.0 / 40.0],
896
- dtype=_FDTYPE,
917
+ dtype=dtype,
897
918
  )
898
919
  return qp, qw
899
920
 
fluxfem/core/context_types.py CHANGED
@@ -1,36 +1,60 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Any, Protocol, TypeAlias, runtime_checkable
3
+ from typing import Any, Mapping, Protocol, TYPE_CHECKING, TypeAlias, runtime_checkable
4
+
5
+ import numpy as np
6
+
7
if TYPE_CHECKING:
    # Static type checkers see both NumPy and JAX arrays. jax.Array is only
    # imported here, presumably so the alias does not need a runtime import.
    from jax import Array as JaxArray

    ArrayLike: TypeAlias = np.ndarray | JaxArray
else:
    # At runtime the alias must be a real object because it is evaluated in
    # other module-level aliases (e.g. UElement); fall back to np.ndarray.
    ArrayLike: TypeAlias = np.ndarray
4
13
 
5
14
 
6
15
@runtime_checkable
class VolumeContext(Protocol):
    """Minimum interface for volume weak-form evaluation.

    Because the protocol is ``runtime_checkable``, ``isinstance`` only checks
    that the attributes below exist; their types are not validated.
    """

    test: "FormFieldLike"  # test-function field
    trial: "FormFieldLike"  # trial-function field
    w: ArrayLike  # presumably quadrature weights -- TODO confirm
13
22
 
14
23
 
15
24
@runtime_checkable
class SurfaceContext(Protocol):
    """Minimum interface for surface weak-form evaluation.

    Because the protocol is ``runtime_checkable``, ``isinstance`` only checks
    that the attributes below exist; their types are not validated.
    """

    v: "FormFieldLike"  # field evaluated on the surface
    w: ArrayLike  # presumably surface quadrature weights -- TODO confirm
    detJ: ArrayLike  # presumably surface Jacobian determinants -- TODO confirm
    normal: ArrayLike  # presumably outward unit normals at quadrature points -- confirm
23
32
 
24
33
 
25
34
  @runtime_checkable
26
35
  class FormFieldLike(Protocol):
27
36
  """Minimum interface for form fields used in weak-form evaluation."""
28
37
 
29
- N: Any
30
- gradN: Any
31
- detJ: Any
38
+ N: ArrayLike
39
+ gradN: ArrayLike
40
+ detJ: ArrayLike
32
41
  value_dim: int
42
+ basis: Any
43
+
44
+
45
@runtime_checkable
class WeakFormContext(Protocol):
    """Context interface used when resolving field references.

    Provides single-field ``test``/``trial``/``v`` form fields plus optional
    name-keyed mappings (presumably for multi-field / mixed forms).

    Note: members typed ``... | None`` must still exist as attributes for an
    ``isinstance`` check against this runtime-checkable protocol to pass.
    """

    test: FormFieldLike
    trial: FormFieldLike
    v: FormFieldLike
    unknown: FormFieldLike | None  # current-solution field, if any
    fields: Mapping[str, Any] | None  # generic named fields, untyped values
    test_fields: Mapping[str, FormFieldLike] | None  # test fields by name
    trial_fields: Mapping[str, FormFieldLike] | None  # trial fields by name
    unknown_fields: Mapping[str, FormFieldLike] | None  # unknown fields by name
33
57
 
34
58
 
35
# Element-level unknowns: a flat array, or a mapping of field name -> array
# (mixed assembly passes such a per-field dict to element_residual).
# The right-hand side is evaluated at runtime (PEP 604 ``|`` on types),
# which requires Python 3.10+.
UElement: TypeAlias = ArrayLike | Mapping[str, ArrayLike]
# Parameters forwarded to weak forms; deliberately left as Any.
ParamsLike: TypeAlias = Any
fluxfem/core/dtypes.py CHANGED
@@ -1,4 +1,12 @@
1
1
  import jax
2
2
  import jax.numpy as jnp
3
+ import numpy as np
3
4
 
4
- DEFAULT_DTYPE = jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32
5
+
6
def default_dtype() -> jnp.dtype:
    """Return the floating-point dtype that matches JAX's x64 setting.

    The ``jax_enable_x64`` flag is read on every call, so the result follows
    configuration changes made after this module was imported.

    Returns:
        ``jnp.float64`` when x64 is enabled, otherwise ``jnp.float32``.
    """
    x64_enabled = jax.config.read("jax_enable_x64")
    if x64_enabled:
        return jnp.float64
    return jnp.float32


# Import-time snapshot of default_dtype(); prefer calling default_dtype()
# when the x64 flag may be toggled after import.
DEFAULT_DTYPE = default_dtype()

# Integer dtypes for index arrays (JAX and NumPy variants).
# NOTE(review): with jax_enable_x64 disabled, JAX downgrades int64 requests
# to int32 (with a warning); NP_INDEX_DTYPE is unaffected.
INDEX_DTYPE = jnp.int64
NP_INDEX_DTYPE = np.int64
fluxfem/core/forms.py CHANGED
@@ -189,6 +189,7 @@ class FormContext:
189
189
  )
190
190
 
191
191
 
192
+ @jax.tree_util.register_pytree_node_class
192
193
  @dataclass(eq=False)
193
194
  class FieldPair:
194
195
  """Named test/trial/unknown grouping for mixed formulations."""
@@ -196,6 +197,15 @@ class FieldPair:
196
197
  trial: FormFieldLike
197
198
  unknown: FormFieldLike | None = None
198
199
 
200
+ def tree_flatten(self):
201
+ children = (self.test, self.trial, self.unknown)
202
+ return children, {}
203
+
204
+ @classmethod
205
+ def tree_unflatten(cls, aux_data, children):
206
+ test, trial, unknown = children
207
+ return cls(test=test, trial=trial, unknown=unknown)
208
+
199
209
 
200
210
  @jax.tree_util.register_pytree_node_class
201
211
  @dataclass(eq=False)
fluxfem/core/mixed_assembly.py ADDED
@@ -0,0 +1,263 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Mapping
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+
8
+ from .dtypes import INDEX_DTYPE
9
+ from .assembly import element_residual, make_sparsity_pattern, chunk_pad_stats, _maybe_trace_pad
10
+
11
+
12
+ def _coerce_mixed_u(space, u):
13
+ if isinstance(u, Mapping):
14
+ return space.pack_fields(u)
15
+ return jnp.asarray(u)
16
+
17
+
18
+ def _split_elem_vec(field_names, elem_slices, u_elem_vec):
19
+ return {name: u_elem_vec[elem_slices[name]] for name in field_names}
20
+
21
+
22
+ def _concat_residuals(field_names, res_dict):
23
+ return jnp.concatenate([res_dict[name] for name in field_names], axis=0)
24
+
25
+
26
def make_element_mixed_residual_kernel(res_form, params, field_names, elem_slices):
    """Build a jitted per-element residual kernel for mixed systems.

    The returned callable has signature ``kernel(ctx, u_elem_vec)``:

    1. It splits the flat element vector into per-field pieces using
       ``elem_slices``.
    2. It evaluates ``element_residual(res_form, ctx, pieces, params)``.
    3. It concatenates the per-field residuals back in ``field_names`` order.
    """

    def _element_kernel(ctx, u_elem_vec):
        per_field_u = _split_elem_vec(field_names, elem_slices, u_elem_vec)
        per_field_res = element_residual(res_form, ctx, per_field_u, params)
        return _concat_residuals(field_names, per_field_res)

    return jax.jit(_element_kernel)
35
+
36
+
37
def make_element_mixed_jacobian_kernel(res_form, params, field_names, elem_slices):
    """Build a jitted per-element Jacobian kernel for mixed systems.

    The returned callable has signature ``kernel(u_elem_vec, ctx)``. It returns
    the reverse-mode Jacobian of the element residual
    (see ``make_element_mixed_residual_kernel``) with respect to
    ``u_elem_vec``.
    """
    residual_kernel = make_element_mixed_residual_kernel(
        res_form, params, field_names, elem_slices
    )

    def _residual_of_u(u_elem_vec, ctx):
        # The unknowns come first so jacrev's default argnums=0 differentiates w.r.t. them.
        return residual_kernel(ctx, u_elem_vec)

    return jax.jit(jax.jacrev(_residual_of_u))
45
+
46
+
47
def assemble_mixed_residual_scatter(
    space,
    res_form,
    u,
    params,
    *,
    sparse: bool = False,
    kernel=None,
    n_chunks: int | None = None,
    pad_trace: bool = False,
):
    """Assemble the mixed residual using jitted element kernels + scatter_add.

    Parameters
    ----------
    space
        Mixed space. This function reads ``build_form_contexts()``,
        ``field_names``, ``elem_slices``, ``elem_dofs``, ``n_ldofs`` and
        ``n_dofs``, plus ``pack_fields`` when ``u`` is a mapping.
    res_form
        Residual form forwarded to the element kernel.
    u
        Global unknown vector, or a ``Mapping`` of per-field values that is
        packed with ``space.pack_fields``.
    params
        Parameters forwarded to the element kernel.
    sparse
        If True, return ``(rows, data, n_dofs)`` without summing duplicates.
    kernel
        Optional prebuilt kernel ``(ctx, u_elem_vec) -> element residual``.
        Defaults to ``make_element_mixed_residual_kernel(...)``.
    n_chunks
        If given, split the elements into this many equally sized chunks.
        The last chunk is padded by repeating the final element.
    pad_trace
        Forwarded to ``_maybe_trace_pad``.

    Returns
    -------
    The dense residual of length ``space.n_dofs``, or ``(rows, data, n_dofs)``
    when ``sparse`` is True.

    Raises
    ------
    ValueError
        If ``n_chunks`` is given and is not positive.
    """
    u_vec = _coerce_mixed_u(space, u)
    ctxs = space.build_form_contexts()
    ker = kernel if kernel is not None else make_element_mixed_residual_kernel(
        res_form, params, space.field_names, space.elem_slices
    )

    u_elems = u_vec[space.elem_dofs]
    n_elems = int(u_elems.shape[0])
    if n_chunks is not None and n_chunks <= 0:
        raise ValueError("n_chunks must be a positive integer.")

    if n_chunks is None or n_elems == 0:
        # Unchunked path. An empty mesh also lands here: chunking it would use
        # min(n_chunks, 0) == 0 chunks and divide by zero below.
        elem_res = jax.vmap(ker)(ctxs, u_elems)
    else:
        n_chunks = min(int(n_chunks), n_elems)
        chunk_size = (n_elems + n_chunks - 1) // n_chunks  # ceil division
        stats = chunk_pad_stats(n_elems, n_chunks)
        _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
        pad = (-n_elems) % chunk_size

        def _pad_last(x):
            # Repeat the last element ``pad`` times so the element count is a
            # multiple of chunk_size; padded rows are sliced off below.
            return jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0)

        if pad:
            ctxs_pad = jax.tree_util.tree_map(_pad_last, ctxs)
            u_elems_pad = _pad_last(u_elems)
        else:
            ctxs_pad, u_elems_pad = ctxs, u_elems
        n_chunks = (n_elems + pad) // chunk_size

        def _slice_first_dim(x, start):
            start_idx = (start,) + (0,) * (x.ndim - 1)
            return jax.lax.dynamic_slice(x, start_idx, (chunk_size,) + x.shape[1:])

        def chunk_fn(i):
            start = i * chunk_size
            ctx_chunk = jax.tree_util.tree_map(lambda x: _slice_first_dim(x, start), ctxs_pad)
            u_chunk = _slice_first_dim(u_elems_pad, start)
            return jax.vmap(ker)(ctx_chunk, u_chunk).reshape(-1)

        # NOTE(review): vmap over chunk indices evaluates all chunks as one
        # batched computation, so peak memory is not bounded the way a
        # sequential jax.lax.map would bound it. Confirm which is intended.
        data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
        elem_res = data_chunks.reshape(-1)[: n_elems * int(space.n_ldofs)].reshape(n_elems, -1)

    rows = space.elem_dofs.reshape(-1)
    data = elem_res.reshape(-1)

    if sparse:
        return rows, data, space.n_dofs

    # Scatter-add each element contribution into its global dof; duplicate
    # dof indices are summed.
    sdn = jax.lax.ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    )
    F = jnp.zeros((space.n_dofs,), dtype=data.dtype)
    return jax.lax.scatter_add(F, rows[:, None], data, sdn)
121
+
122
+
123
def assemble_mixed_jacobian_values(
    space, res_form, u, params, *, kernel=None, n_chunks: int | None = None, pad_trace: bool = False
):
    """Assemble numeric values for the mixed Jacobian (pattern-free).

    Each element's kernel output is flattened row-major. The outputs are
    concatenated element by element into one 1-D array.

    Parameters
    ----------
    space
        Mixed space. This function reads ``build_form_contexts()``,
        ``field_names``, ``elem_slices``, ``elem_dofs`` and ``n_ldofs``, plus
        ``pack_fields`` when ``u`` is a mapping.
    res_form, params
        Forwarded to the element Jacobian kernel.
    u
        Global unknown vector, or a ``Mapping`` of per-field values.
    kernel
        Optional prebuilt kernel ``(u_elem_vec, ctx) -> element Jacobian``.
        Defaults to ``make_element_mixed_jacobian_kernel(...)``.
    n_chunks
        If given, evaluate elements in this many equally sized chunks. The
        last chunk is padded by repeating the final element.
    pad_trace
        Forwarded to ``_maybe_trace_pad``.

    Returns
    -------
    A 1-D array of Jacobian values. In the chunked path it is truncated to
    ``n_elems * n_ldofs**2`` entries (presumably each element block is
    ``(n_ldofs, n_ldofs)`` -- confirm).

    Raises
    ------
    ValueError
        If ``n_chunks`` is given and is not positive.
    """
    u_vec = _coerce_mixed_u(space, u)
    ctxs = space.build_form_contexts()
    ker = kernel if kernel is not None else make_element_mixed_jacobian_kernel(
        res_form, params, space.field_names, space.elem_slices
    )

    u_elems = u_vec[space.elem_dofs]
    n_elems = int(u_elems.shape[0])
    if n_chunks is not None and n_chunks <= 0:
        raise ValueError("n_chunks must be a positive integer.")

    if n_chunks is None or n_elems == 0:
        # Unchunked path. An empty mesh also lands here: chunking it would use
        # min(n_chunks, 0) == 0 chunks and divide by zero below.
        J_e_all = jax.vmap(ker)(u_elems, ctxs)
        return J_e_all.reshape(-1)

    n_chunks = min(int(n_chunks), n_elems)
    chunk_size = (n_elems + n_chunks - 1) // n_chunks  # ceil division
    stats = chunk_pad_stats(n_elems, n_chunks)
    _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
    pad = (-n_elems) % chunk_size

    def _pad_last(x):
        # Repeat the last element ``pad`` times so the element count is a
        # multiple of chunk_size; padded entries are sliced off at the end.
        return jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0)

    if pad:
        ctxs_pad = jax.tree_util.tree_map(_pad_last, ctxs)
        u_elems_pad = _pad_last(u_elems)
    else:
        ctxs_pad, u_elems_pad = ctxs, u_elems
    n_chunks = (n_elems + pad) // chunk_size
    m = int(space.n_ldofs)

    def _slice_first_dim(x, start):
        start_idx = (start,) + (0,) * (x.ndim - 1)
        return jax.lax.dynamic_slice(x, start_idx, (chunk_size,) + x.shape[1:])

    def chunk_fn(i):
        start = i * chunk_size
        ctx_chunk = jax.tree_util.tree_map(lambda x: _slice_first_dim(x, start), ctxs_pad)
        u_chunk = _slice_first_dim(u_elems_pad, start)
        return jax.vmap(ker)(u_chunk, ctx_chunk).reshape(-1)

    # NOTE(review): vmap over chunk indices evaluates all chunks as one
    # batched computation, so peak memory is not bounded the way a
    # sequential jax.lax.map would bound it. Confirm which is intended.
    data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
    return data_chunks.reshape(-1)[: n_elems * m * m]
177
+
178
+
179
def assemble_mixed_jacobian_scatter(
    space,
    res_form,
    u,
    params,
    *,
    kernel=None,
    sparse: bool = True,
    return_flux_matrix: bool = False,
    pattern=None,
    n_chunks: int | None = None,
    pad_trace: bool = False,
):
    """Assemble the mixed Jacobian using jitted element kernels + scatter_add.

    Parameters
    ----------
    space, res_form, u, params, kernel, n_chunks, pad_trace
        Forwarded to ``assemble_mixed_jacobian_values``.
    sparse
        If True, return COO data instead of a dense matrix.
    return_flux_matrix
        With ``sparse=True``, wrap the result in ``FluxSparseMatrix``.
    pattern
        Optional precomputed sparsity pattern with ``rows``, ``cols``,
        ``n_dofs`` and ``idx`` attributes. Built with
        ``make_sparsity_pattern`` when omitted.

    Returns
    -------
    The result depends on the flags:

    - ``sparse`` and ``return_flux_matrix``: ``FluxSparseMatrix(pattern, data)``.
    - ``sparse`` only: ``(rows, cols, data, n_dofs)``.
    - otherwise: a dense ``(n_dofs, n_dofs)`` array with duplicate entries
      summed.
    """
    pat = pattern if pattern is not None else make_sparsity_pattern(space, with_idx=not sparse)
    data = assemble_mixed_jacobian_values(
        space, res_form, u, params, kernel=kernel, n_chunks=n_chunks, pad_trace=pad_trace
    )

    if sparse:
        if return_flux_matrix:
            from ..solver import FluxSparseMatrix  # local import to avoid circular

            return FluxSparseMatrix(pat, data)
        return pat.rows, pat.cols, data, pat.n_dofs

    idx = pat.idx
    if idx is None:
        # Row-major flat position of each (row, col) entry in the dense matrix,
        # computed directly in the module's index dtype.
        n = int(pat.n_dofs)
        idx = pat.rows.astype(INDEX_DTYPE) * n + pat.cols.astype(INDEX_DTYPE)

    n_entries = pat.n_dofs * pat.n_dofs
    sdn = jax.lax.ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    )
    K_flat = jnp.zeros(n_entries, dtype=data.dtype)
    K_flat = jax.lax.scatter_add(K_flat, idx[:, None], data, sdn)
    return K_flat.reshape(pat.n_dofs, pat.n_dofs)
218
+
219
+
220
def assemble_mixed_residual(
    space, res_form, u, params, *, sparse: bool = False, n_chunks: int | None = None, pad_trace: bool = False
):
    """Assemble the global mixed residual vector.

    Thin wrapper over ``assemble_mixed_residual_scatter`` that always builds
    the default element kernel.
    """
    options = {"sparse": sparse, "n_chunks": n_chunks, "pad_trace": pad_trace}
    return assemble_mixed_residual_scatter(space, res_form, u, params, **options)
227
+
228
+
229
def assemble_mixed_jacobian(
    space,
    res_form,
    u,
    params,
    *,
    sparse: bool = True,
    return_flux_matrix: bool = False,
    pattern=None,
    n_chunks: int | None = None,
    pad_trace: bool = False,
):
    """Assemble the global mixed Jacobian.

    Thin wrapper over ``assemble_mixed_jacobian_scatter`` that always builds
    the default element kernel.
    """
    options = {
        "sparse": sparse,
        "return_flux_matrix": return_flux_matrix,
        "pattern": pattern,
        "n_chunks": n_chunks,
        "pad_trace": pad_trace,
    }
    return assemble_mixed_jacobian_scatter(space, res_form, u, params, **options)
253
+
254
+
255
# Public API of this module; the private helpers (_coerce_mixed_u,
# _split_elem_vec, _concat_residuals) are intentionally not exported.
__all__ = [
    "make_element_mixed_residual_kernel",
    "make_element_mixed_jacobian_kernel",
    "assemble_mixed_residual",
    "assemble_mixed_jacobian",
    "assemble_mixed_residual_scatter",
    "assemble_mixed_jacobian_scatter",
    "assemble_mixed_jacobian_values",
]