fluxfem 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +136 -161
- fluxfem/core/__init__.py +172 -41
- fluxfem/core/assembly.py +676 -91
- fluxfem/core/basis.py +73 -52
- fluxfem/core/context_types.py +36 -0
- fluxfem/core/dtypes.py +9 -1
- fluxfem/core/forms.py +15 -1
- fluxfem/core/mixed_assembly.py +263 -0
- fluxfem/core/mixed_space.py +348 -0
- fluxfem/core/mixed_weakform.py +97 -0
- fluxfem/core/solver.py +2 -0
- fluxfem/core/space.py +262 -17
- fluxfem/core/weakform.py +1503 -312
- fluxfem/helpers_wf.py +53 -0
- fluxfem/mesh/__init__.py +54 -2
- fluxfem/mesh/base.py +322 -8
- fluxfem/mesh/contact.py +825 -0
- fluxfem/mesh/dtypes.py +12 -0
- fluxfem/mesh/hex.py +18 -16
- fluxfem/mesh/io.py +8 -4
- fluxfem/mesh/mortar.py +3907 -0
- fluxfem/mesh/supermesh.py +316 -0
- fluxfem/mesh/surface.py +22 -4
- fluxfem/mesh/tet.py +10 -4
- fluxfem/physics/diffusion.py +3 -0
- fluxfem/physics/elasticity/hyperelastic.py +3 -0
- fluxfem/physics/elasticity/linear.py +9 -2
- fluxfem/solver/__init__.py +42 -2
- fluxfem/solver/bc.py +38 -2
- fluxfem/solver/block_matrix.py +132 -0
- fluxfem/solver/block_system.py +454 -0
- fluxfem/solver/cg.py +115 -33
- fluxfem/solver/dirichlet.py +334 -4
- fluxfem/solver/newton.py +237 -60
- fluxfem/solver/petsc.py +439 -0
- fluxfem/solver/preconditioner.py +106 -0
- fluxfem/solver/result.py +18 -0
- fluxfem/solver/solve_runner.py +168 -1
- fluxfem/solver/solver.py +12 -1
- fluxfem/solver/sparse.py +124 -9
- fluxfem-0.2.0.dist-info/METADATA +303 -0
- fluxfem-0.2.0.dist-info/RECORD +59 -0
- fluxfem-0.1.3.dist-info/METADATA +0 -125
- fluxfem-0.1.3.dist-info/RECORD +0 -47
- {fluxfem-0.1.3.dist-info → fluxfem-0.2.0.dist-info}/LICENSE +0 -0
- {fluxfem-0.1.3.dist-info → fluxfem-0.2.0.dist-info}/WHEEL +0 -0
fluxfem/mesh/dtypes.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def default_dtype() -> jnp.dtype:
|
|
7
|
+
return jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
DEFAULT_DTYPE = default_dtype()
|
|
11
|
+
INDEX_DTYPE = jnp.int64
|
|
12
|
+
NP_INDEX_DTYPE = np.int64
|
fluxfem/mesh/hex.py
CHANGED
|
@@ -5,9 +5,9 @@ from dataclasses import dataclass
|
|
|
5
5
|
from typing import Optional, Dict, Tuple, List, Callable
|
|
6
6
|
import jax
|
|
7
7
|
import jax.numpy as jnp
|
|
8
|
-
from ..core.dtypes import DEFAULT_DTYPE
|
|
9
8
|
import numpy as np
|
|
10
9
|
|
|
10
|
+
from .dtypes import INDEX_DTYPE, default_dtype
|
|
11
11
|
from .base import BaseMesh, BaseMeshPytree
|
|
12
12
|
|
|
13
13
|
|
|
@@ -17,10 +17,10 @@ class HexMesh(BaseMesh):
|
|
|
17
17
|
Structured / unstructured hex mesh (8-node linear hex elements).
|
|
18
18
|
|
|
19
19
|
coords: (n_nodes, 3) float32
|
|
20
|
-
conn: (n_elems, 8)
|
|
20
|
+
conn: (n_elems, 8) int64 # node indices of each element
|
|
21
21
|
"""
|
|
22
22
|
coords: jnp.ndarray # shape (n_nodes, 3)
|
|
23
|
-
conn: jnp.ndarray # shape (n_elems, 8),
|
|
23
|
+
conn: jnp.ndarray # shape (n_elems, 8), int64
|
|
24
24
|
|
|
25
25
|
def face_node_patterns(self):
|
|
26
26
|
return [
|
|
@@ -36,6 +36,7 @@ class HexMesh(BaseMesh):
|
|
|
36
36
|
@jax.tree_util.register_pytree_node_class
|
|
37
37
|
@dataclass(eq=False)
|
|
38
38
|
class HexMeshPytree(BaseMeshPytree):
|
|
39
|
+
"""Hex mesh registered as a JAX pytree."""
|
|
39
40
|
def face_node_patterns(self):
|
|
40
41
|
return [
|
|
41
42
|
(0, 1, 2, 3), # -z
|
|
@@ -58,8 +59,8 @@ def tag_axis_minmax_facets(
|
|
|
58
59
|
Tag boundary facets on min/max of the given axis.
|
|
59
60
|
|
|
60
61
|
Returns:
|
|
61
|
-
facets: (n_facets, 4)
|
|
62
|
-
facet_tags: (n_facets,)
|
|
62
|
+
facets: (n_facets, 4) int64, quad node ids
|
|
63
|
+
facet_tags: (n_facets,) int64, dirichlet_tag on min side, neumann_tag on max side
|
|
63
64
|
"""
|
|
64
65
|
coords = np.asarray(mesh.coords)
|
|
65
66
|
conn = np.asarray(mesh.conn)
|
|
@@ -98,7 +99,7 @@ def tag_axis_minmax_facets(
|
|
|
98
99
|
facet_map[key] = (nodes, tag)
|
|
99
100
|
|
|
100
101
|
if not facet_map:
|
|
101
|
-
return jnp.empty((0, 4), dtype=
|
|
102
|
+
return jnp.empty((0, 4), dtype=INDEX_DTYPE), jnp.empty((0,), dtype=INDEX_DTYPE)
|
|
102
103
|
|
|
103
104
|
facets = []
|
|
104
105
|
tags = []
|
|
@@ -106,7 +107,7 @@ def tag_axis_minmax_facets(
|
|
|
106
107
|
facets.append(nodes)
|
|
107
108
|
tags.append(tag)
|
|
108
109
|
|
|
109
|
-
return jnp.array(facets, dtype=
|
|
110
|
+
return jnp.array(facets, dtype=INDEX_DTYPE), jnp.array(tags, dtype=INDEX_DTYPE)
|
|
110
111
|
|
|
111
112
|
|
|
112
113
|
@dataclass
|
|
@@ -134,9 +135,10 @@ class StructuredHexBox:
|
|
|
134
135
|
raise ValueError("order must be 1, 2, or 3")
|
|
135
136
|
|
|
136
137
|
ox, oy, oz = self.origin
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
138
|
+
dtype = default_dtype()
|
|
139
|
+
xs = jnp.linspace(ox, ox + self.lx, self.nx + 1, dtype=dtype)
|
|
140
|
+
ys = jnp.linspace(oy, oy + self.ly, self.ny + 1, dtype=dtype)
|
|
141
|
+
zs = jnp.linspace(oz, oz + self.lz, self.nz + 1, dtype=dtype)
|
|
140
142
|
|
|
141
143
|
if self.order == 1:
|
|
142
144
|
return self._build_hex8(xs, ys, zs)
|
|
@@ -150,7 +152,7 @@ class StructuredHexBox:
|
|
|
150
152
|
for j in range(self.ny + 1):
|
|
151
153
|
for i in range(self.nx + 1):
|
|
152
154
|
coords_list.append([xs[i], ys[j], zs[k]])
|
|
153
|
-
coords = jnp.array(coords_list, dtype=
|
|
155
|
+
coords = jnp.array(coords_list, dtype=default_dtype())
|
|
154
156
|
|
|
155
157
|
def node_id(i: int, j: int, k: int) -> int:
|
|
156
158
|
return k * (self.ny + 1) * (self.nx + 1) + j * (self.nx + 1) + i
|
|
@@ -169,7 +171,7 @@ class StructuredHexBox:
|
|
|
169
171
|
n011 = node_id(i, j + 1, k + 1)
|
|
170
172
|
conn_list.append([n000, n100, n110, n010, n001, n101, n111, n011])
|
|
171
173
|
|
|
172
|
-
conn = jnp.array(conn_list, dtype=
|
|
174
|
+
conn = jnp.array(conn_list, dtype=INDEX_DTYPE)
|
|
173
175
|
return HexMesh(coords=coords, conn=conn)
|
|
174
176
|
|
|
175
177
|
def _build_hex20(self, xs, ys, zs) -> HexMesh:
|
|
@@ -240,8 +242,8 @@ class StructuredHexBox:
|
|
|
240
242
|
]
|
|
241
243
|
)
|
|
242
244
|
|
|
243
|
-
coords = jnp.array(coords_list, dtype=
|
|
244
|
-
conn = jnp.array(conn_list, dtype=
|
|
245
|
+
coords = jnp.array(coords_list, dtype=default_dtype())
|
|
246
|
+
conn = jnp.array(conn_list, dtype=INDEX_DTYPE)
|
|
245
247
|
return HexMesh(coords=coords, conn=conn)
|
|
246
248
|
|
|
247
249
|
def _build_hex27(self, xs, ys, zs) -> HexMesh:
|
|
@@ -322,6 +324,6 @@ class StructuredHexBox:
|
|
|
322
324
|
# order in lexicographic k,j,i -> length 27
|
|
323
325
|
conn_list.append(nodes)
|
|
324
326
|
|
|
325
|
-
coords = jnp.array(coords_list, dtype=
|
|
326
|
-
conn = jnp.array(conn_list, dtype=
|
|
327
|
+
coords = jnp.array(coords_list, dtype=default_dtype())
|
|
328
|
+
conn = jnp.array(conn_list, dtype=INDEX_DTYPE)
|
|
327
329
|
return HexMesh(coords=coords, conn=conn)
|
fluxfem/mesh/io.py
CHANGED
|
@@ -4,6 +4,8 @@ import numpy as np
|
|
|
4
4
|
import jax
|
|
5
5
|
import jax.numpy as jnp
|
|
6
6
|
|
|
7
|
+
from .dtypes import NP_INDEX_DTYPE
|
|
8
|
+
|
|
7
9
|
DTYPE = jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32
|
|
8
10
|
|
|
9
11
|
try:
|
|
@@ -36,10 +38,10 @@ def load_gmsh_mesh(path: str):
|
|
|
36
38
|
|
|
37
39
|
mesh = None
|
|
38
40
|
if "hexahedron" in msh.cells_dict:
|
|
39
|
-
conn = np.asarray(msh.cells_dict["hexahedron"], dtype=
|
|
41
|
+
conn = np.asarray(msh.cells_dict["hexahedron"], dtype=NP_INDEX_DTYPE)
|
|
40
42
|
mesh = HexMesh(jnp.asarray(coords), jnp.asarray(conn))
|
|
41
43
|
elif "tetra" in msh.cells_dict:
|
|
42
|
-
conn = np.asarray(msh.cells_dict["tetra"], dtype=
|
|
44
|
+
conn = np.asarray(msh.cells_dict["tetra"], dtype=NP_INDEX_DTYPE)
|
|
43
45
|
mesh = TetMesh(jnp.asarray(coords), jnp.asarray(conn))
|
|
44
46
|
else:
|
|
45
47
|
raise ValueError("gmsh mesh does not contain hexahedron or tetra cells")
|
|
@@ -50,7 +52,7 @@ def load_gmsh_mesh(path: str):
|
|
|
50
52
|
cell_dict = msh.cells_dict
|
|
51
53
|
for fkey in ("quad", "triangle", "tri"):
|
|
52
54
|
if fkey in cell_dict:
|
|
53
|
-
facets = np.asarray(cell_dict[fkey], dtype=
|
|
55
|
+
facets = np.asarray(cell_dict[fkey], dtype=NP_INDEX_DTYPE)
|
|
54
56
|
break
|
|
55
57
|
tags_raw = None
|
|
56
58
|
if facets is not None:
|
|
@@ -63,12 +65,13 @@ def load_gmsh_mesh(path: str):
|
|
|
63
65
|
tags_raw = data
|
|
64
66
|
break
|
|
65
67
|
if tags_raw is not None:
|
|
66
|
-
facet_tags = np.asarray(tags_raw, dtype=
|
|
68
|
+
facet_tags = np.asarray(tags_raw, dtype=NP_INDEX_DTYPE)
|
|
67
69
|
|
|
68
70
|
return mesh, facets, facet_tags
|
|
69
71
|
|
|
70
72
|
|
|
71
73
|
def load_gmsh_hex_mesh(path: str):
|
|
74
|
+
"""Load a Gmsh mesh and return a HexMesh with optional facets/tags."""
|
|
72
75
|
mesh, facets, tags = load_gmsh_mesh(path)
|
|
73
76
|
if not isinstance(mesh, HexMesh):
|
|
74
77
|
raise ValueError("gmsh mesh is not hexahedral")
|
|
@@ -76,6 +79,7 @@ def load_gmsh_hex_mesh(path: str):
|
|
|
76
79
|
|
|
77
80
|
|
|
78
81
|
def load_gmsh_tet_mesh(path: str):
|
|
82
|
+
"""Load a Gmsh mesh and return a TetMesh with optional facets/tags."""
|
|
79
83
|
mesh, facets, tags = load_gmsh_mesh(path)
|
|
80
84
|
if not isinstance(mesh, TetMesh):
|
|
81
85
|
raise ValueError("gmsh mesh is not tetrahedral")
|