fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +69 -13
- fluxfem/core/__init__.py +140 -53
- fluxfem/core/assembly.py +691 -97
- fluxfem/core/basis.py +75 -54
- fluxfem/core/context_types.py +36 -12
- fluxfem/core/dtypes.py +9 -1
- fluxfem/core/forms.py +10 -0
- fluxfem/core/mixed_assembly.py +263 -0
- fluxfem/core/mixed_space.py +382 -0
- fluxfem/core/mixed_weakform.py +97 -0
- fluxfem/core/solver.py +2 -0
- fluxfem/core/space.py +315 -30
- fluxfem/core/weakform.py +821 -42
- fluxfem/helpers_wf.py +49 -0
- fluxfem/mesh/__init__.py +54 -2
- fluxfem/mesh/base.py +318 -9
- fluxfem/mesh/contact.py +841 -0
- fluxfem/mesh/dtypes.py +12 -0
- fluxfem/mesh/hex.py +17 -16
- fluxfem/mesh/io.py +9 -6
- fluxfem/mesh/mortar.py +3970 -0
- fluxfem/mesh/supermesh.py +318 -0
- fluxfem/mesh/surface.py +104 -26
- fluxfem/mesh/tet.py +16 -7
- fluxfem/physics/diffusion.py +3 -0
- fluxfem/physics/elasticity/hyperelastic.py +35 -3
- fluxfem/physics/elasticity/linear.py +22 -4
- fluxfem/physics/elasticity/stress.py +9 -5
- fluxfem/physics/operators.py +12 -5
- fluxfem/physics/postprocess.py +29 -3
- fluxfem/solver/__init__.py +47 -2
- fluxfem/solver/bc.py +38 -2
- fluxfem/solver/block_matrix.py +284 -0
- fluxfem/solver/block_system.py +477 -0
- fluxfem/solver/cg.py +150 -55
- fluxfem/solver/dirichlet.py +358 -5
- fluxfem/solver/history.py +15 -3
- fluxfem/solver/newton.py +260 -70
- fluxfem/solver/petsc.py +445 -0
- fluxfem/solver/preconditioner.py +109 -0
- fluxfem/solver/result.py +18 -0
- fluxfem/solver/solve_runner.py +208 -23
- fluxfem/solver/solver.py +35 -12
- fluxfem/solver/sparse.py +149 -15
- fluxfem/tools/jit.py +19 -7
- fluxfem/tools/timer.py +14 -12
- fluxfem/tools/visualizer.py +16 -4
- fluxfem-0.2.1.dist-info/METADATA +314 -0
- fluxfem-0.2.1.dist-info/RECORD +59 -0
- fluxfem-0.1.4.dist-info/METADATA +0 -127
- fluxfem-0.1.4.dist-info/RECORD +0 -48
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/mesh/dtypes.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def default_dtype() -> jnp.dtype:
    """Return the floating-point dtype matching JAX's current precision mode.

    The ``jax_enable_x64`` config flag is consulted on every call rather than
    cached, so the answer reflects the flag's value at call time.

    Returns:
        ``jnp.float64`` when ``jax_enable_x64`` is enabled, otherwise
        ``jnp.float32``.
    """
    x64_enabled = jax.config.read("jax_enable_x64")
    if x64_enabled:
        return jnp.float64
    return jnp.float32
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Float dtype captured once, when this module is imported. If
# ``jax_enable_x64`` is changed after import, this constant does NOT follow
# it; call ``default_dtype()`` instead when the current setting matters.
DEFAULT_DTYPE = default_dtype()

# Integer dtypes for index arrays: a JAX flavour and a NumPy flavour.
# NOTE(review): with ``jax_enable_x64`` disabled, JAX presumably canonicalizes
# an int64 request to int32 (emitting a truncation warning); the constant is
# kept as int64 so arrays become int64 if x64 is enabled later -- confirm.
INDEX_DTYPE, NP_INDEX_DTYPE = jnp.int64, np.int64
|
fluxfem/mesh/hex.py
CHANGED
|
@@ -5,9 +5,9 @@ from dataclasses import dataclass
|
|
|
5
5
|
from typing import Optional, Dict, Tuple, List, Callable
|
|
6
6
|
import jax
|
|
7
7
|
import jax.numpy as jnp
|
|
8
|
-
from ..core.dtypes import DEFAULT_DTYPE
|
|
9
8
|
import numpy as np
|
|
10
9
|
|
|
10
|
+
from .dtypes import INDEX_DTYPE, default_dtype
|
|
11
11
|
from .base import BaseMesh, BaseMeshPytree
|
|
12
12
|
|
|
13
13
|
|
|
@@ -17,10 +17,10 @@ class HexMesh(BaseMesh):
|
|
|
17
17
|
Structured / unstructured hex mesh (8-node linear hex elements).
|
|
18
18
|
|
|
19
19
|
coords: (n_nodes, 3) float32
|
|
20
|
-
conn: (n_elems, 8)
|
|
20
|
+
conn: (n_elems, 8) int64 # node indices of each element
|
|
21
21
|
"""
|
|
22
22
|
coords: jnp.ndarray # shape (n_nodes, 3)
|
|
23
|
-
conn: jnp.ndarray # shape (n_elems, 8),
|
|
23
|
+
conn: jnp.ndarray # shape (n_elems, 8), int64
|
|
24
24
|
|
|
25
25
|
def face_node_patterns(self):
|
|
26
26
|
return [
|
|
@@ -59,8 +59,8 @@ def tag_axis_minmax_facets(
|
|
|
59
59
|
Tag boundary facets on min/max of the given axis.
|
|
60
60
|
|
|
61
61
|
Returns:
|
|
62
|
-
facets: (n_facets, 4)
|
|
63
|
-
facet_tags: (n_facets,)
|
|
62
|
+
facets: (n_facets, 4) int64, quad node ids
|
|
63
|
+
facet_tags: (n_facets,) int64, dirichlet_tag on min side, neumann_tag on max side
|
|
64
64
|
"""
|
|
65
65
|
coords = np.asarray(mesh.coords)
|
|
66
66
|
conn = np.asarray(mesh.conn)
|
|
@@ -99,7 +99,7 @@ def tag_axis_minmax_facets(
|
|
|
99
99
|
facet_map[key] = (nodes, tag)
|
|
100
100
|
|
|
101
101
|
if not facet_map:
|
|
102
|
-
return jnp.empty((0, 4), dtype=
|
|
102
|
+
return jnp.empty((0, 4), dtype=INDEX_DTYPE), jnp.empty((0,), dtype=INDEX_DTYPE)
|
|
103
103
|
|
|
104
104
|
facets = []
|
|
105
105
|
tags = []
|
|
@@ -107,7 +107,7 @@ def tag_axis_minmax_facets(
|
|
|
107
107
|
facets.append(nodes)
|
|
108
108
|
tags.append(tag)
|
|
109
109
|
|
|
110
|
-
return jnp.array(facets, dtype=
|
|
110
|
+
return jnp.array(facets, dtype=INDEX_DTYPE), jnp.array(tags, dtype=INDEX_DTYPE)
|
|
111
111
|
|
|
112
112
|
|
|
113
113
|
@dataclass
|
|
@@ -135,9 +135,10 @@ class StructuredHexBox:
|
|
|
135
135
|
raise ValueError("order must be 1, 2, or 3")
|
|
136
136
|
|
|
137
137
|
ox, oy, oz = self.origin
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
138
|
+
dtype = default_dtype()
|
|
139
|
+
xs = jnp.linspace(ox, ox + self.lx, self.nx + 1, dtype=dtype)
|
|
140
|
+
ys = jnp.linspace(oy, oy + self.ly, self.ny + 1, dtype=dtype)
|
|
141
|
+
zs = jnp.linspace(oz, oz + self.lz, self.nz + 1, dtype=dtype)
|
|
141
142
|
|
|
142
143
|
if self.order == 1:
|
|
143
144
|
return self._build_hex8(xs, ys, zs)
|
|
@@ -151,7 +152,7 @@ class StructuredHexBox:
|
|
|
151
152
|
for j in range(self.ny + 1):
|
|
152
153
|
for i in range(self.nx + 1):
|
|
153
154
|
coords_list.append([xs[i], ys[j], zs[k]])
|
|
154
|
-
coords = jnp.array(coords_list, dtype=
|
|
155
|
+
coords = jnp.array(coords_list, dtype=default_dtype())
|
|
155
156
|
|
|
156
157
|
def node_id(i: int, j: int, k: int) -> int:
|
|
157
158
|
return k * (self.ny + 1) * (self.nx + 1) + j * (self.nx + 1) + i
|
|
@@ -170,7 +171,7 @@ class StructuredHexBox:
|
|
|
170
171
|
n011 = node_id(i, j + 1, k + 1)
|
|
171
172
|
conn_list.append([n000, n100, n110, n010, n001, n101, n111, n011])
|
|
172
173
|
|
|
173
|
-
conn = jnp.array(conn_list, dtype=
|
|
174
|
+
conn = jnp.array(conn_list, dtype=INDEX_DTYPE)
|
|
174
175
|
return HexMesh(coords=coords, conn=conn)
|
|
175
176
|
|
|
176
177
|
def _build_hex20(self, xs, ys, zs) -> HexMesh:
|
|
@@ -241,8 +242,8 @@ class StructuredHexBox:
|
|
|
241
242
|
]
|
|
242
243
|
)
|
|
243
244
|
|
|
244
|
-
coords = jnp.array(coords_list, dtype=
|
|
245
|
-
conn = jnp.array(conn_list, dtype=
|
|
245
|
+
coords = jnp.array(coords_list, dtype=default_dtype())
|
|
246
|
+
conn = jnp.array(conn_list, dtype=INDEX_DTYPE)
|
|
246
247
|
return HexMesh(coords=coords, conn=conn)
|
|
247
248
|
|
|
248
249
|
def _build_hex27(self, xs, ys, zs) -> HexMesh:
|
|
@@ -323,6 +324,6 @@ class StructuredHexBox:
|
|
|
323
324
|
# order in lexicographic k,j,i -> length 27
|
|
324
325
|
conn_list.append(nodes)
|
|
325
326
|
|
|
326
|
-
coords = jnp.array(coords_list, dtype=
|
|
327
|
-
conn = jnp.array(conn_list, dtype=
|
|
327
|
+
coords = jnp.array(coords_list, dtype=default_dtype())
|
|
328
|
+
conn = jnp.array(conn_list, dtype=INDEX_DTYPE)
|
|
328
329
|
return HexMesh(coords=coords, conn=conn)
|
fluxfem/mesh/io.py
CHANGED
|
@@ -3,6 +3,9 @@ from __future__ import annotations
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import jax
|
|
5
5
|
import jax.numpy as jnp
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
from .dtypes import NP_INDEX_DTYPE
|
|
6
9
|
|
|
7
10
|
DTYPE = jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32
|
|
8
11
|
|
|
@@ -10,7 +13,7 @@ try:
|
|
|
10
13
|
import meshio
|
|
11
14
|
except Exception as e: # pragma: no cover
|
|
12
15
|
meshio = None
|
|
13
|
-
meshio_import_error = e
|
|
16
|
+
meshio_import_error: Optional[Exception] = e
|
|
14
17
|
else:
|
|
15
18
|
meshio_import_error = None
|
|
16
19
|
|
|
@@ -34,12 +37,12 @@ def load_gmsh_mesh(path: str):
|
|
|
34
37
|
msh = meshio.read(path)
|
|
35
38
|
coords = np.asarray(msh.points[:, :3], dtype=DTYPE)
|
|
36
39
|
|
|
37
|
-
mesh = None
|
|
40
|
+
mesh: HexMesh | TetMesh | None = None
|
|
38
41
|
if "hexahedron" in msh.cells_dict:
|
|
39
|
-
conn = np.asarray(msh.cells_dict["hexahedron"], dtype=
|
|
42
|
+
conn = np.asarray(msh.cells_dict["hexahedron"], dtype=NP_INDEX_DTYPE)
|
|
40
43
|
mesh = HexMesh(jnp.asarray(coords), jnp.asarray(conn))
|
|
41
44
|
elif "tetra" in msh.cells_dict:
|
|
42
|
-
conn = np.asarray(msh.cells_dict["tetra"], dtype=
|
|
45
|
+
conn = np.asarray(msh.cells_dict["tetra"], dtype=NP_INDEX_DTYPE)
|
|
43
46
|
mesh = TetMesh(jnp.asarray(coords), jnp.asarray(conn))
|
|
44
47
|
else:
|
|
45
48
|
raise ValueError("gmsh mesh does not contain hexahedron or tetra cells")
|
|
@@ -50,7 +53,7 @@ def load_gmsh_mesh(path: str):
|
|
|
50
53
|
cell_dict = msh.cells_dict
|
|
51
54
|
for fkey in ("quad", "triangle", "tri"):
|
|
52
55
|
if fkey in cell_dict:
|
|
53
|
-
facets = np.asarray(cell_dict[fkey], dtype=
|
|
56
|
+
facets = np.asarray(cell_dict[fkey], dtype=NP_INDEX_DTYPE)
|
|
54
57
|
break
|
|
55
58
|
tags_raw = None
|
|
56
59
|
if facets is not None:
|
|
@@ -63,7 +66,7 @@ def load_gmsh_mesh(path: str):
|
|
|
63
66
|
tags_raw = data
|
|
64
67
|
break
|
|
65
68
|
if tags_raw is not None:
|
|
66
|
-
facet_tags = np.asarray(tags_raw, dtype=
|
|
69
|
+
facet_tags = np.asarray(tags_raw, dtype=NP_INDEX_DTYPE)
|
|
67
70
|
|
|
68
71
|
return mesh, facets, facet_tags
|
|
69
72
|
|