fluxfem 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +136 -161
- fluxfem/core/__init__.py +172 -41
- fluxfem/core/assembly.py +676 -91
- fluxfem/core/basis.py +73 -52
- fluxfem/core/context_types.py +36 -0
- fluxfem/core/dtypes.py +9 -1
- fluxfem/core/forms.py +15 -1
- fluxfem/core/mixed_assembly.py +263 -0
- fluxfem/core/mixed_space.py +348 -0
- fluxfem/core/mixed_weakform.py +97 -0
- fluxfem/core/solver.py +2 -0
- fluxfem/core/space.py +262 -17
- fluxfem/core/weakform.py +1503 -312
- fluxfem/helpers_wf.py +53 -0
- fluxfem/mesh/__init__.py +54 -2
- fluxfem/mesh/base.py +322 -8
- fluxfem/mesh/contact.py +825 -0
- fluxfem/mesh/dtypes.py +12 -0
- fluxfem/mesh/hex.py +18 -16
- fluxfem/mesh/io.py +8 -4
- fluxfem/mesh/mortar.py +3907 -0
- fluxfem/mesh/supermesh.py +316 -0
- fluxfem/mesh/surface.py +22 -4
- fluxfem/mesh/tet.py +10 -4
- fluxfem/physics/diffusion.py +3 -0
- fluxfem/physics/elasticity/hyperelastic.py +3 -0
- fluxfem/physics/elasticity/linear.py +9 -2
- fluxfem/solver/__init__.py +42 -2
- fluxfem/solver/bc.py +38 -2
- fluxfem/solver/block_matrix.py +132 -0
- fluxfem/solver/block_system.py +454 -0
- fluxfem/solver/cg.py +115 -33
- fluxfem/solver/dirichlet.py +334 -4
- fluxfem/solver/newton.py +237 -60
- fluxfem/solver/petsc.py +439 -0
- fluxfem/solver/preconditioner.py +106 -0
- fluxfem/solver/result.py +18 -0
- fluxfem/solver/solve_runner.py +168 -1
- fluxfem/solver/solver.py +12 -1
- fluxfem/solver/sparse.py +124 -9
- fluxfem-0.2.0.dist-info/METADATA +303 -0
- fluxfem-0.2.0.dist-info/RECORD +59 -0
- fluxfem-0.1.3.dist-info/METADATA +0 -125
- fluxfem-0.1.3.dist-info/RECORD +0 -47
- {fluxfem-0.1.3.dist-info → fluxfem-0.2.0.dist-info}/LICENSE +0 -0
- {fluxfem-0.1.3.dist-info → fluxfem-0.2.0.dist-info}/WHEEL +0 -0
fluxfem/core/basis.py
CHANGED
|
@@ -4,8 +4,7 @@ from typing import Protocol
|
|
|
4
4
|
import jax
|
|
5
5
|
import jax.numpy as jnp
|
|
6
6
|
import numpy as np
|
|
7
|
-
from .dtypes import
|
|
8
|
-
# from .dtypes import DEFAULT_DTYPE
|
|
7
|
+
from .dtypes import default_dtype
|
|
9
8
|
|
|
10
9
|
|
|
11
10
|
def build_B_matrices(dN_dx: jnp.ndarray) -> jnp.ndarray:
|
|
@@ -271,6 +270,7 @@ class TetLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
|
|
|
271
270
|
|
|
272
271
|
@property
|
|
273
272
|
def ref_node_coords(self) -> jnp.ndarray:
|
|
273
|
+
dtype = default_dtype()
|
|
274
274
|
return jnp.array(
|
|
275
275
|
[
|
|
276
276
|
[0.0, 0.0, 0.0],
|
|
@@ -278,7 +278,7 @@ class TetLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
|
|
|
278
278
|
[0.0, 1.0, 0.0],
|
|
279
279
|
[0.0, 0.0, 1.0],
|
|
280
280
|
],
|
|
281
|
-
dtype=
|
|
281
|
+
dtype=dtype,
|
|
282
282
|
)
|
|
283
283
|
|
|
284
284
|
def shape_functions(self) -> jnp.ndarray:
|
|
@@ -294,6 +294,7 @@ class TetLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
|
|
|
294
294
|
|
|
295
295
|
def shape_grads_ref(self) -> jnp.ndarray:
|
|
296
296
|
# constant gradients in reference tetra
|
|
297
|
+
dtype = default_dtype()
|
|
297
298
|
dN = jnp.array(
|
|
298
299
|
[
|
|
299
300
|
[-1.0, -1.0, -1.0],
|
|
@@ -301,7 +302,7 @@ class TetLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
|
|
|
301
302
|
[0.0, 1.0, 0.0],
|
|
302
303
|
[0.0, 0.0, 1.0],
|
|
303
304
|
],
|
|
304
|
-
dtype=
|
|
305
|
+
dtype=dtype,
|
|
305
306
|
)
|
|
306
307
|
dN = jnp.tile(dN[None, :, :], (self.n_q, 1, 1)) # (n_q,4,3)
|
|
307
308
|
return dN
|
|
@@ -374,7 +375,7 @@ class TetQuadraticBasis10(SmallStrainBMixin, TotalLagrangeBMixin):
|
|
|
374
375
|
|
|
375
376
|
grads = []
|
|
376
377
|
for a, dLa in zip([L1, L2, L3, L4], [dL1, dL2, dL3, dL4]):
|
|
377
|
-
grads.append((
|
|
378
|
+
grads.append((4 * a - 1)[..., None] * dLa[None, :])
|
|
378
379
|
|
|
379
380
|
dN1 = grads[0]
|
|
380
381
|
dN2 = grads[1]
|
|
@@ -446,6 +447,7 @@ class HexTriLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
|
|
|
446
447
|
6: ( 1, 1, 1)
|
|
447
448
|
7: (-1, 1, 1)
|
|
448
449
|
"""
|
|
450
|
+
dtype = default_dtype()
|
|
449
451
|
return jnp.array(
|
|
450
452
|
[
|
|
451
453
|
[-1.0, -1.0, -1.0],
|
|
@@ -457,7 +459,7 @@ class HexTriLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
|
|
|
457
459
|
[ 1.0, 1.0, 1.0],
|
|
458
460
|
[-1.0, 1.0, 1.0],
|
|
459
461
|
],
|
|
460
|
-
dtype=
|
|
462
|
+
dtype=dtype,
|
|
461
463
|
)
|
|
462
464
|
|
|
463
465
|
# ---------- reference shape functions & gradients ----------
|
|
@@ -581,6 +583,7 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
|
|
|
581
583
|
|
|
582
584
|
@property
|
|
583
585
|
def ref_node_coords(self) -> jnp.ndarray:
|
|
586
|
+
dtype = default_dtype()
|
|
584
587
|
corners = jnp.array(
|
|
585
588
|
[
|
|
586
589
|
[-1.0, -1.0, -1.0],
|
|
@@ -592,7 +595,7 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
|
|
|
592
595
|
[ 1.0, 1.0, 1.0],
|
|
593
596
|
[-1.0, 1.0, 1.0],
|
|
594
597
|
],
|
|
595
|
-
dtype=
|
|
598
|
+
dtype=dtype,
|
|
596
599
|
)
|
|
597
600
|
edges = jnp.array(
|
|
598
601
|
[
|
|
@@ -609,7 +612,7 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
|
|
|
609
612
|
[ 1.0, 1.0, 0.0], # 2-6
|
|
610
613
|
[-1.0, 1.0, 0.0], # 3-7
|
|
611
614
|
],
|
|
612
|
-
dtype=
|
|
615
|
+
dtype=dtype,
|
|
613
616
|
)
|
|
614
617
|
return jnp.concatenate([corners, edges], axis=0) # (20,3)
|
|
615
618
|
|
|
@@ -620,6 +623,7 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
|
|
|
620
623
|
zeta = qp[:, 2:3]
|
|
621
624
|
|
|
622
625
|
# corners
|
|
626
|
+
dtype = default_dtype()
|
|
623
627
|
s = jnp.array(
|
|
624
628
|
[
|
|
625
629
|
[-1, -1, -1],
|
|
@@ -631,7 +635,7 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
|
|
|
631
635
|
[ 1, 1, 1],
|
|
632
636
|
[-1, 1, 1],
|
|
633
637
|
],
|
|
634
|
-
dtype=
|
|
638
|
+
dtype=dtype,
|
|
635
639
|
)
|
|
636
640
|
sx = s[:, 0]
|
|
637
641
|
sy = s[:, 1]
|
|
@@ -639,22 +643,30 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
|
|
|
639
643
|
term = xi * sx + eta * sy + zeta * sz - 2.0 # (n_q,8)
|
|
640
644
|
N_corner = 0.125 * (1 + sx * xi) * (1 + sy * eta) * (1 + sz * zeta) * term # (n_q,8)
|
|
641
645
|
|
|
642
|
-
# edges
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
646
|
+
# edges: order matches hex20 connectivity (e01, e12, e23, e30, e45, e56, e67, e74, e04, e15, e26, e37)
|
|
647
|
+
def edge_x(sy, sz):
|
|
648
|
+
return 0.25 * (1 - xi * xi) * (1 + sy * eta) * (1 + sz * zeta)
|
|
649
|
+
|
|
650
|
+
def edge_y(sx, sz):
|
|
651
|
+
return 0.25 * (1 - eta * eta) * (1 + sx * xi) * (1 + sz * zeta)
|
|
652
|
+
|
|
653
|
+
def edge_z(sx, sy):
|
|
654
|
+
return 0.25 * (1 - zeta * zeta) * (1 + sx * xi) * (1 + sy * eta)
|
|
655
|
+
|
|
656
|
+
N_edges = [
|
|
657
|
+
edge_x(-1, -1),
|
|
658
|
+
edge_y(1, -1),
|
|
659
|
+
edge_x(1, -1),
|
|
660
|
+
edge_y(-1, -1),
|
|
661
|
+
edge_x(-1, 1),
|
|
662
|
+
edge_y(1, 1),
|
|
663
|
+
edge_x(1, 1),
|
|
664
|
+
edge_y(-1, 1),
|
|
665
|
+
edge_z(-1, -1),
|
|
666
|
+
edge_z(1, -1),
|
|
667
|
+
edge_z(1, 1),
|
|
668
|
+
edge_z(-1, 1),
|
|
669
|
+
]
|
|
658
670
|
N_edges = jnp.concatenate(N_edges, axis=1) # (n_q, 12)
|
|
659
671
|
return jnp.concatenate([N_corner, N_edges], axis=1) # (n_q,20)
|
|
660
672
|
|
|
@@ -664,6 +676,7 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
|
|
|
664
676
|
eta = qp[:, 1:2]
|
|
665
677
|
zeta = qp[:, 2:3]
|
|
666
678
|
|
|
679
|
+
dtype = default_dtype()
|
|
667
680
|
s = jnp.array(
|
|
668
681
|
[
|
|
669
682
|
[-1, -1, -1],
|
|
@@ -675,7 +688,7 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
|
|
|
675
688
|
[ 1, 1, 1],
|
|
676
689
|
[-1, 1, 1],
|
|
677
690
|
],
|
|
678
|
-
dtype=
|
|
691
|
+
dtype=dtype,
|
|
679
692
|
)
|
|
680
693
|
sx = s[:, 0]
|
|
681
694
|
sy = s[:, 1]
|
|
@@ -696,34 +709,38 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
|
|
|
696
709
|
) # (n_q,8,3)
|
|
697
710
|
|
|
698
711
|
# edges derivatives
|
|
699
|
-
|
|
700
|
-
# along xi
|
|
701
|
-
edges_x = [(-1, -1), (1, -1), (1, 1), (-1, 1)]
|
|
702
|
-
for sy_val, sz_val in edges_x:
|
|
703
|
-
sy_ = sy_val
|
|
704
|
-
sz_ = sz_val
|
|
712
|
+
def d_edge_x(sy_, sz_):
|
|
705
713
|
dxi = -0.5 * xi * (1 + sy_ * eta) * (1 + sz_ * zeta)
|
|
706
714
|
deta = 0.25 * (1 - xi * xi) * sy_ * (1 + sz_ * zeta)
|
|
707
715
|
dzeta = 0.25 * (1 - xi * xi) * (1 + sy_ * eta) * sz_
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
for sx_val, sz_val in edges_y:
|
|
712
|
-
sx_ = sx_val
|
|
713
|
-
sz_ = sz_val
|
|
716
|
+
return jnp.stack([dxi, deta, dzeta], axis=2)
|
|
717
|
+
|
|
718
|
+
def d_edge_y(sx_, sz_):
|
|
714
719
|
dxi = 0.25 * (1 - eta * eta) * sx_ * (1 + sz_ * zeta)
|
|
715
720
|
deta = -0.5 * eta * (1 + sx_ * xi) * (1 + sz_ * zeta)
|
|
716
721
|
dzeta = 0.25 * (1 - eta * eta) * (1 + sx_ * xi) * sz_
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
for sx_val, sy_val in edges_z:
|
|
721
|
-
sx_ = sx_val
|
|
722
|
-
sy_ = sy_val
|
|
722
|
+
return jnp.stack([dxi, deta, dzeta], axis=2)
|
|
723
|
+
|
|
724
|
+
def d_edge_z(sx_, sy_):
|
|
723
725
|
dxi = 0.25 * (1 - zeta * zeta) * sx_ * (1 + sy_ * eta)
|
|
724
726
|
deta = 0.25 * (1 - zeta * zeta) * (1 + sx_ * xi) * sy_
|
|
725
727
|
dzeta = -0.5 * zeta * (1 + sx_ * xi) * (1 + sy_ * eta)
|
|
726
|
-
|
|
728
|
+
return jnp.stack([dxi, deta, dzeta], axis=2)
|
|
729
|
+
|
|
730
|
+
d_list = [
|
|
731
|
+
d_edge_x(-1, -1),
|
|
732
|
+
d_edge_y(1, -1),
|
|
733
|
+
d_edge_x(1, -1),
|
|
734
|
+
d_edge_y(-1, -1),
|
|
735
|
+
d_edge_x(-1, 1),
|
|
736
|
+
d_edge_y(1, 1),
|
|
737
|
+
d_edge_x(1, 1),
|
|
738
|
+
d_edge_y(-1, 1),
|
|
739
|
+
d_edge_z(-1, -1),
|
|
740
|
+
d_edge_z(1, -1),
|
|
741
|
+
d_edge_z(1, 1),
|
|
742
|
+
d_edge_z(-1, 1),
|
|
743
|
+
]
|
|
727
744
|
|
|
728
745
|
d_edges = jnp.concatenate(d_list, axis=1) # (n_q,12,3)
|
|
729
746
|
return jnp.concatenate([d_corner, d_edges], axis=1) # (n_q,20,3)
|
|
@@ -846,7 +863,8 @@ def _gauss_legendre_1d(order: int) -> tuple[jnp.ndarray, jnp.ndarray]:
|
|
|
846
863
|
if order <= 0:
|
|
847
864
|
raise ValueError("quadrature order must be positive")
|
|
848
865
|
pts, wts = np.polynomial.legendre.leggauss(order)
|
|
849
|
-
|
|
866
|
+
dtype = default_dtype()
|
|
867
|
+
return jnp.array(pts, dtype=dtype), jnp.array(wts, dtype=dtype)
|
|
850
868
|
|
|
851
869
|
|
|
852
870
|
def _gl_points_for_degree(degree: int) -> int:
|
|
@@ -865,10 +883,12 @@ def _tet_quadrature(degree: int) -> tuple[jnp.ndarray, jnp.ndarray]:
|
|
|
865
883
|
degree<=1: 1-point; degree<=2: 4-point; degree>=3: 5-point (Stroud T3-5).
|
|
866
884
|
"""
|
|
867
885
|
if degree <= 1:
|
|
868
|
-
|
|
869
|
-
|
|
886
|
+
dtype = default_dtype()
|
|
887
|
+
qp = jnp.array([[0.25, 0.25, 0.25]], dtype=dtype)
|
|
888
|
+
qw = jnp.array([1.0 / 6.0], dtype=dtype)
|
|
870
889
|
return qp, qw
|
|
871
890
|
if degree <= 2:
|
|
891
|
+
dtype = default_dtype()
|
|
872
892
|
qp = jnp.array(
|
|
873
893
|
[
|
|
874
894
|
[0.58541020, 0.13819660, 0.13819660],
|
|
@@ -876,11 +896,12 @@ def _tet_quadrature(degree: int) -> tuple[jnp.ndarray, jnp.ndarray]:
|
|
|
876
896
|
[0.13819660, 0.13819660, 0.58541020],
|
|
877
897
|
[0.13819660, 0.13819660, 0.13819660],
|
|
878
898
|
],
|
|
879
|
-
dtype=
|
|
899
|
+
dtype=dtype,
|
|
880
900
|
)
|
|
881
|
-
qw = jnp.full((4,), (1.0 / 24.0), dtype=
|
|
901
|
+
qw = jnp.full((4,), (1.0 / 24.0), dtype=dtype)
|
|
882
902
|
return qp, qw
|
|
883
903
|
# degree 3 rule: centroid + 4 symmetric points
|
|
904
|
+
dtype = default_dtype()
|
|
884
905
|
qp = jnp.array(
|
|
885
906
|
[
|
|
886
907
|
[0.25, 0.25, 0.25],
|
|
@@ -889,11 +910,11 @@ def _tet_quadrature(degree: int) -> tuple[jnp.ndarray, jnp.ndarray]:
|
|
|
889
910
|
[1.0 / 6.0, 1.0 / 6.0, 0.50],
|
|
890
911
|
[1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0],
|
|
891
912
|
],
|
|
892
|
-
dtype=
|
|
913
|
+
dtype=dtype,
|
|
893
914
|
)
|
|
894
915
|
qw = jnp.array(
|
|
895
916
|
[-2.0 / 15.0, 3.0 / 40.0, 3.0 / 40.0, 3.0 / 40.0, 3.0 / 40.0],
|
|
896
|
-
dtype=
|
|
917
|
+
dtype=dtype,
|
|
897
918
|
)
|
|
898
919
|
return qp, qw
|
|
899
920
|
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Protocol, TypeAlias, runtime_checkable
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@runtime_checkable
class VolumeContext(Protocol):
    """Structural type describing what a volume weak form needs from its context.

    Any object exposing ``test``, ``trial`` and ``w`` attributes satisfies it.
    Since the protocol is ``runtime_checkable``, ``isinstance`` only verifies
    that the attributes are present; their types are not checked.
    """

    # Every member is typed ``Any``: this protocol constrains names, not types.
    test: Any
    trial: Any
    w: Any
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@runtime_checkable
class SurfaceContext(Protocol):
    """Structural type describing what a surface weak form needs from its context.

    Objects providing ``v``, ``w``, ``detJ`` and ``normal`` attributes match;
    the runtime ``isinstance`` check tests attribute presence only.
    """

    # Types are intentionally left open (``Any``).
    v: Any
    w: Any
    detJ: Any
    normal: Any
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@runtime_checkable
class FormFieldLike(Protocol):
    """Structural type for the field objects consumed during weak-form evaluation.

    A matching object must expose ``N``, ``gradN``, ``detJ`` and ``value_dim``.
    ``isinstance`` checks presence of these attributes, not their types.
    """

    # Array-like members are left as ``Any``; only ``value_dim`` is typed.
    N: Any
    gradN: Any
    detJ: Any
    value_dim: int
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
UElement: TypeAlias = Any
|
|
36
|
+
ParamsLike: TypeAlias = Any
|
fluxfem/core/dtypes.py
CHANGED
|
@@ -1,4 +1,12 @@
|
|
|
1
1
|
import jax
|
|
2
2
|
import jax.numpy as jnp
|
|
3
|
+
import numpy as np
|
|
3
4
|
|
|
4
|
-
|
|
5
|
+
|
|
6
|
+
def default_dtype() -> jnp.dtype:
    """Return the floating-point dtype matching JAX's ``jax_enable_x64`` flag.

    The flag is read on every call, so the result tracks a change of the
    setting made after this module was imported.

    Returns:
        ``jnp.float64`` when x64 mode is enabled, otherwise ``jnp.float32``.
    """
    x64_enabled = jax.config.read("jax_enable_x64")
    if x64_enabled:
        return jnp.float64
    return jnp.float32
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Snapshot taken once at import time: enabling x64 afterwards does NOT update
# this constant -- call ``default_dtype()`` where the current setting matters.
DEFAULT_DTYPE = default_dtype()

# Integer dtypes for index arrays, in JAX and NumPy flavours.
# NOTE(review): JAX silently uses int32 for int64 requests unless x64 is
# enabled; confirm callers tolerate 32-bit indices in that mode.
INDEX_DTYPE, NP_INDEX_DTYPE = jnp.int64, np.int64
|
fluxfem/core/forms.py
CHANGED
|
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
|
3
3
|
import jax
|
|
4
4
|
import jax.numpy as jnp
|
|
5
5
|
from dataclasses import dataclass
|
|
6
|
+
from typing import TYPE_CHECKING, TypeAlias
|
|
6
7
|
|
|
7
8
|
from .basis import Basis3D
|
|
8
9
|
|
|
@@ -119,7 +120,10 @@ class VectorFormField:
|
|
|
119
120
|
return cls(N, elem_coords, aux["basis"], aux["value_dim"], gradN, detJ)
|
|
120
121
|
|
|
121
122
|
|
|
122
|
-
|
|
123
|
+
if TYPE_CHECKING:
|
|
124
|
+
FormFieldLike: TypeAlias = ScalarFormField | VectorFormField
|
|
125
|
+
else:
|
|
126
|
+
FormFieldLike = object
|
|
123
127
|
|
|
124
128
|
|
|
125
129
|
def vector_load_form(field: FormFieldLike, load_vec: jnp.ndarray) -> jnp.ndarray:
|
|
@@ -185,6 +189,7 @@ class FormContext:
|
|
|
185
189
|
)
|
|
186
190
|
|
|
187
191
|
|
|
192
|
+
@jax.tree_util.register_pytree_node_class
|
|
188
193
|
@dataclass(eq=False)
|
|
189
194
|
class FieldPair:
|
|
190
195
|
"""Named test/trial/unknown grouping for mixed formulations."""
|
|
@@ -192,6 +197,15 @@ class FieldPair:
|
|
|
192
197
|
trial: FormFieldLike
|
|
193
198
|
unknown: FormFieldLike | None = None
|
|
194
199
|
|
|
200
|
+
def tree_flatten(self):
|
|
201
|
+
children = (self.test, self.trial, self.unknown)
|
|
202
|
+
return children, {}
|
|
203
|
+
|
|
204
|
+
@classmethod
|
|
205
|
+
def tree_unflatten(cls, aux_data, children):
|
|
206
|
+
test, trial, unknown = children
|
|
207
|
+
return cls(test=test, trial=trial, unknown=unknown)
|
|
208
|
+
|
|
195
209
|
|
|
196
210
|
@jax.tree_util.register_pytree_node_class
|
|
197
211
|
@dataclass(eq=False)
|
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Mapping
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
|
|
8
|
+
from .dtypes import INDEX_DTYPE
|
|
9
|
+
from .assembly import element_residual, make_sparsity_pattern, chunk_pad_stats, _maybe_trace_pad
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _coerce_mixed_u(space, u):
    """Normalise ``u`` to the global mixed DOF vector.

    A mapping of per-field arrays is packed with ``space.pack_fields``; any
    other input is converted with ``jnp.asarray``.
    """
    if not isinstance(u, Mapping):
        return jnp.asarray(u)
    return space.pack_fields(u)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _split_elem_vec(field_names, elem_slices, u_elem_vec):
    """Split a concatenated element vector into per-field pieces.

    Returns a dict mapping each name in ``field_names`` (in that order) to
    ``u_elem_vec[elem_slices[name]]``.
    """
    pieces = {}
    for field in field_names:
        pieces[field] = u_elem_vec[elem_slices[field]]
    return pieces
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _concat_residuals(field_names, res_dict):
    """Concatenate per-field residual blocks along axis 0 in ``field_names`` order."""
    ordered_blocks = list(map(res_dict.__getitem__, field_names))
    return jnp.concatenate(ordered_blocks, axis=0)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def make_element_mixed_residual_kernel(res_form, params, field_names, elem_slices):
    """Build a jitted per-element residual kernel for a mixed system.

    The returned callable is ``kernel(ctx, u_elem_vec)``. It:

    1. splits the concatenated element vector into per-field pieces using
       ``elem_slices``;
    2. evaluates ``element_residual(res_form, ctx, u_elem, params)``;
    3. concatenates the per-field residuals back in ``field_names`` order.

    ``params`` is closed over rather than passed as a traced argument.
    """

    def _kernel(ctx, u_elem_vec):
        u_by_field = {name: u_elem_vec[elem_slices[name]] for name in field_names}
        res_by_field = element_residual(res_form, ctx, u_by_field, params)
        return jnp.concatenate([res_by_field[name] for name in field_names], axis=0)

    return jax.jit(_kernel)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def make_element_mixed_jacobian_kernel(res_form, params, field_names, elem_slices):
    """Build a jitted per-element Jacobian kernel for a mixed system.

    The returned callable is ``kernel(u_elem_vec, ctx)``. Note that the
    argument order is swapped relative to the residual kernel. It returns the
    derivative of the element residual with respect to ``u_elem_vec``,
    computed with reverse-mode ``jax.jacrev``.
    """
    residual = make_element_mixed_residual_kernel(res_form, params, field_names, elem_slices)
    return jax.jit(jax.jacrev(lambda u_elem_vec, ctx: residual(ctx, u_elem_vec), argnums=0))
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def assemble_mixed_residual_scatter(
    space,
    res_form,
    u,
    params,
    *,
    sparse: bool = False,
    kernel=None,
    n_chunks: int | None = None,
    pad_trace: bool = False,
):
    """Assemble the mixed residual using jitted element kernels + scatter_add.

    Parameters
    ----------
    space
        Mixed space providing ``build_form_contexts()``, ``elem_dofs``,
        ``n_dofs`` and ``n_ldofs``. It must also provide ``field_names`` and
        ``elem_slices`` when no ``kernel`` is given, and ``pack_fields`` when
        ``u`` is a mapping.
    res_form
        Mixed residual weak form.
    u
        Global DOF vector, or a mapping of per-field arrays.
    params
        Form parameters (captured by the element kernel).
    sparse
        If True, return the unsummed ``(rows, data, n_dofs)`` triplet.
    kernel
        Optional prebuilt kernel from ``make_element_mixed_residual_kernel``.
    n_chunks
        If given, evaluate elements in this many equal-size (padded) chunks.
    pad_trace
        Forwarded to ``_maybe_trace_pad`` to report chunk padding.

    Returns
    -------
    ``(rows, data, n_dofs)`` if ``sparse``, else the residual vector of length
    ``space.n_dofs`` with duplicate DOF contributions summed.

    Raises
    ------
    ValueError
        If ``n_chunks`` is given and is not positive.
    """
    u_vec = _coerce_mixed_u(space, u)
    ctxs = space.build_form_contexts()
    ker = kernel if kernel is not None else make_element_mixed_residual_kernel(
        res_form, params, space.field_names, space.elem_slices
    )

    u_elems = u_vec[space.elem_dofs]
    n_elems = int(u_elems.shape[0])
    if n_chunks is not None and n_chunks <= 0:
        raise ValueError("n_chunks must be a positive integer.")

    if n_chunks is None or n_elems == 0:
        # An empty mesh must bypass chunking: min(n_chunks, 0) == 0 would make
        # the chunk-size computation below divide by zero.
        elem_res = jax.vmap(ker)(ctxs, u_elems)
    else:
        n_chunks = min(int(n_chunks), n_elems)
        chunk_size = (n_elems + n_chunks - 1) // n_chunks
        stats = chunk_pad_stats(n_elems, n_chunks)
        _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
        pad = (-n_elems) % chunk_size
        if pad:
            # Replicate the last element so every chunk has the same static
            # shape; the padded results are trimmed off below.
            ctxs_pad = jax.tree_util.tree_map(
                lambda x: jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0),
                ctxs,
            )
            u_elems_pad = jnp.concatenate([u_elems, jnp.repeat(u_elems[-1:], pad, axis=0)], axis=0)
        else:
            ctxs_pad = ctxs
            u_elems_pad = u_elems

        n_chunks = (n_elems + pad) // chunk_size

        def _slice_first_dim(x, start, size):
            # Static-size slice along axis 0 starting at a (possibly traced) index.
            start_idx = (start,) + (0,) * (x.ndim - 1)
            slice_sizes = (size,) + x.shape[1:]
            return jax.lax.dynamic_slice(x, start_idx, slice_sizes)

        def chunk_fn(i):
            start = i * chunk_size
            ctx_chunk = jax.tree_util.tree_map(
                lambda x: _slice_first_dim(x, start, chunk_size),
                ctxs_pad,
            )
            u_chunk = _slice_first_dim(u_elems_pad, start, chunk_size)
            res_chunk = jax.vmap(ker)(ctx_chunk, u_chunk)
            return res_chunk.reshape(-1)

        # NOTE(review): vmapping over the chunk index evaluates all chunks as one
        # batched computation; if chunking is meant to cap peak memory,
        # ``jax.lax.map`` would evaluate them sequentially -- confirm intent.
        data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
        elem_res = data_chunks.reshape(-1)[: n_elems * int(space.n_ldofs)].reshape(n_elems, -1)

    rows = space.elem_dofs.reshape(-1)
    data = elem_res.reshape(-1)
    if sparse:
        return rows, data, space.n_dofs

    sdn = jax.lax.ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    )
    F = jnp.zeros((space.n_dofs,), dtype=data.dtype)
    # scatter_add sums contributions from elements sharing a DOF.
    F = jax.lax.scatter_add(F, rows[:, None], data, sdn)
    return F
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def assemble_mixed_jacobian_values(
    space, res_form, u, params, *, kernel=None, n_chunks: int | None = None, pad_trace: bool = False
):
    """Assemble the numeric values of the mixed Jacobian (pattern-free).

    Each element's Jacobian block is flattened row-major, and the blocks are
    concatenated in element order. The result is intended to be paired with a
    sparsity pattern built from the same ``space``.

    Parameters
    ----------
    space
        Mixed space providing ``build_form_contexts()``, ``elem_dofs`` and
        ``n_ldofs``. It must also provide ``field_names`` and ``elem_slices``
        when no ``kernel`` is given, and ``pack_fields`` when ``u`` is a
        mapping.
    res_form
        Mixed residual weak form.
    u
        Global DOF vector, or a mapping of per-field arrays.
    params
        Form parameters (captured by the element kernel).
    kernel
        Optional prebuilt kernel from ``make_element_mixed_jacobian_kernel``,
        called as ``kernel(u_elem_vec, ctx)``.
    n_chunks
        If given, evaluate elements in this many equal-size (padded) chunks.
    pad_trace
        Forwarded to ``_maybe_trace_pad`` to report chunk padding.

    Returns
    -------
    A 1-D array of concatenated element blocks. In the chunked path it is
    trimmed to ``n_elems * n_ldofs**2`` entries.

    Raises
    ------
    ValueError
        If ``n_chunks`` is given and is not positive.
    """
    u_vec = _coerce_mixed_u(space, u)
    ctxs = space.build_form_contexts()
    ker = kernel if kernel is not None else make_element_mixed_jacobian_kernel(
        res_form, params, space.field_names, space.elem_slices
    )

    u_elems = u_vec[space.elem_dofs]
    if n_chunks is None:
        J_e_all = jax.vmap(ker)(u_elems, ctxs)
        return J_e_all.reshape(-1)

    n_elems = int(u_elems.shape[0])
    if n_chunks <= 0:
        raise ValueError("n_chunks must be a positive integer.")
    if n_elems == 0:
        # Nothing to chunk: min(n_chunks, 0) == 0 would make the chunk-size
        # computation below divide by zero.
        return jax.vmap(ker)(u_elems, ctxs).reshape(-1)

    n_chunks = min(int(n_chunks), n_elems)
    chunk_size = (n_elems + n_chunks - 1) // n_chunks
    stats = chunk_pad_stats(n_elems, n_chunks)
    _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
    pad = (-n_elems) % chunk_size
    if pad:
        # Replicate the last element so every chunk has the same static shape;
        # the padded results are trimmed off at the end.
        ctxs_pad = jax.tree_util.tree_map(
            lambda x: jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0),
            ctxs,
        )
        u_elems_pad = jnp.concatenate([u_elems, jnp.repeat(u_elems[-1:], pad, axis=0)], axis=0)
    else:
        ctxs_pad = ctxs
        u_elems_pad = u_elems

    n_chunks = (n_elems + pad) // chunk_size
    m = int(space.n_ldofs)

    def _slice_first_dim(x, start, size):
        # Static-size slice along axis 0 starting at a (possibly traced) index.
        start_idx = (start,) + (0,) * (x.ndim - 1)
        slice_sizes = (size,) + x.shape[1:]
        return jax.lax.dynamic_slice(x, start_idx, slice_sizes)

    def chunk_fn(i):
        start = i * chunk_size
        ctx_chunk = jax.tree_util.tree_map(
            lambda x: _slice_first_dim(x, start, chunk_size),
            ctxs_pad,
        )
        u_chunk = _slice_first_dim(u_elems_pad, start, chunk_size)
        J_e = jax.vmap(ker)(u_chunk, ctx_chunk)
        return J_e.reshape(-1)

    # NOTE(review): vmapping over the chunk index evaluates all chunks as one
    # batched computation; if chunking is meant to cap peak memory,
    # ``jax.lax.map`` would evaluate them sequentially -- confirm intent.
    data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
    return data_chunks.reshape(-1)[: n_elems * m * m]
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def assemble_mixed_jacobian_scatter(
    space,
    res_form,
    u,
    params,
    *,
    kernel=None,
    sparse: bool = True,
    return_flux_matrix: bool = False,
    pattern=None,
    n_chunks: int | None = None,
    pad_trace: bool = False,
):
    """Assemble the mixed Jacobian using jitted element kernels + scatter_add.

    The return form depends on the flags:

    * ``sparse`` and ``return_flux_matrix``: a ``FluxSparseMatrix(pattern, data)``.
    * ``sparse`` only: the COO-style tuple ``(rows, cols, data, n_dofs)``.
    * otherwise: a dense ``(n_dofs, n_dofs)`` matrix, with duplicate entries
      summed. ``return_flux_matrix`` is ignored in this case.

    If ``pattern`` is omitted, it is built with ``make_sparsity_pattern``.
    The flat index is requested only for the dense path.
    """
    from ..solver import FluxSparseMatrix  # local import to avoid circular

    if pattern is None:
        pattern = make_sparsity_pattern(space, with_idx=not sparse)
    values = assemble_mixed_jacobian_values(
        space, res_form, u, params, kernel=kernel, n_chunks=n_chunks, pad_trace=pad_trace
    )

    if not sparse:
        flat_idx = pattern.idx
        if flat_idx is None:
            # Row-major flat position of each (row, col) entry in the dense matrix.
            n = int(pattern.n_dofs)
            flat_idx = (
                pattern.rows.astype(jnp.int64) * n + pattern.cols.astype(jnp.int64)
            ).astype(INDEX_DTYPE)
        dims = jax.lax.ScatterDimensionNumbers(
            update_window_dims=(),
            inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,),
        )
        dense_flat = jnp.zeros(pattern.n_dofs * pattern.n_dofs, dtype=values.dtype)
        dense_flat = jax.lax.scatter_add(dense_flat, flat_idx[:, None], values, dims)
        return dense_flat.reshape(pattern.n_dofs, pattern.n_dofs)

    if return_flux_matrix:
        return FluxSparseMatrix(pattern, values)
    return pattern.rows, pattern.cols, values, pattern.n_dofs
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def assemble_mixed_residual(
    space,
    res_form,
    u,
    params,
    *,
    sparse: bool = False,
    n_chunks: int | None = None,
    pad_trace: bool = False,
    kernel=None,
):
    """Assemble the global mixed residual vector.

    This is a thin wrapper over ``assemble_mixed_residual_scatter``. It returns
    the residual vector, or ``(rows, data, n_dofs)`` when ``sparse`` is True.

    ``kernel`` (optional, keyword-only) accepts a prebuilt
    ``make_element_mixed_residual_kernel`` result. Reusing one kernel across
    repeated calls (e.g. Newton iterations) avoids building a fresh jitted
    closure, which JAX would have to trace again, on every call. ``None``
    keeps the previous behaviour.
    """
    return assemble_mixed_residual_scatter(
        space,
        res_form,
        u,
        params,
        sparse=sparse,
        kernel=kernel,
        n_chunks=n_chunks,
        pad_trace=pad_trace,
    )
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def assemble_mixed_jacobian(
    space,
    res_form,
    u,
    params,
    *,
    sparse: bool = True,
    return_flux_matrix: bool = False,
    pattern=None,
    n_chunks: int | None = None,
    pad_trace: bool = False,
    kernel=None,
):
    """Assemble the global mixed Jacobian.

    This is a thin wrapper over ``assemble_mixed_jacobian_scatter``. It returns:

    * a ``FluxSparseMatrix`` when ``sparse`` and ``return_flux_matrix`` are both set;
    * ``(rows, cols, data, n_dofs)`` when only ``sparse`` is set;
    * a dense ``(n_dofs, n_dofs)`` matrix otherwise.

    ``kernel`` (optional, keyword-only) accepts a prebuilt
    ``make_element_mixed_jacobian_kernel`` result, so repeated calls reuse one
    compiled kernel instead of building and tracing a fresh closure each
    time. ``None`` keeps the previous behaviour.
    """
    return assemble_mixed_jacobian_scatter(
        space,
        res_form,
        u,
        params,
        kernel=kernel,
        sparse=sparse,
        return_flux_matrix=return_flux_matrix,
        pattern=pattern,
        n_chunks=n_chunks,
        pad_trace=pad_trace,
    )
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
# Public API of the mixed-assembly module.
__all__ = [
    # element kernel factories
    "make_element_mixed_residual_kernel",
    "make_element_mixed_jacobian_kernel",
    # high-level assembly entry points
    "assemble_mixed_residual",
    "assemble_mixed_jacobian",
    # lower-level scatter / values variants
    "assemble_mixed_residual_scatter",
    "assemble_mixed_jacobian_scatter",
    "assemble_mixed_jacobian_values",
]
|