fluxfem 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46):
  1. fluxfem/__init__.py +136 -161
  2. fluxfem/core/__init__.py +172 -41
  3. fluxfem/core/assembly.py +676 -91
  4. fluxfem/core/basis.py +73 -52
  5. fluxfem/core/context_types.py +36 -0
  6. fluxfem/core/dtypes.py +9 -1
  7. fluxfem/core/forms.py +15 -1
  8. fluxfem/core/mixed_assembly.py +263 -0
  9. fluxfem/core/mixed_space.py +348 -0
  10. fluxfem/core/mixed_weakform.py +97 -0
  11. fluxfem/core/solver.py +2 -0
  12. fluxfem/core/space.py +262 -17
  13. fluxfem/core/weakform.py +1503 -312
  14. fluxfem/helpers_wf.py +53 -0
  15. fluxfem/mesh/__init__.py +54 -2
  16. fluxfem/mesh/base.py +322 -8
  17. fluxfem/mesh/contact.py +825 -0
  18. fluxfem/mesh/dtypes.py +12 -0
  19. fluxfem/mesh/hex.py +18 -16
  20. fluxfem/mesh/io.py +8 -4
  21. fluxfem/mesh/mortar.py +3907 -0
  22. fluxfem/mesh/supermesh.py +316 -0
  23. fluxfem/mesh/surface.py +22 -4
  24. fluxfem/mesh/tet.py +10 -4
  25. fluxfem/physics/diffusion.py +3 -0
  26. fluxfem/physics/elasticity/hyperelastic.py +3 -0
  27. fluxfem/physics/elasticity/linear.py +9 -2
  28. fluxfem/solver/__init__.py +42 -2
  29. fluxfem/solver/bc.py +38 -2
  30. fluxfem/solver/block_matrix.py +132 -0
  31. fluxfem/solver/block_system.py +454 -0
  32. fluxfem/solver/cg.py +115 -33
  33. fluxfem/solver/dirichlet.py +334 -4
  34. fluxfem/solver/newton.py +237 -60
  35. fluxfem/solver/petsc.py +439 -0
  36. fluxfem/solver/preconditioner.py +106 -0
  37. fluxfem/solver/result.py +18 -0
  38. fluxfem/solver/solve_runner.py +168 -1
  39. fluxfem/solver/solver.py +12 -1
  40. fluxfem/solver/sparse.py +124 -9
  41. fluxfem-0.2.0.dist-info/METADATA +303 -0
  42. fluxfem-0.2.0.dist-info/RECORD +59 -0
  43. fluxfem-0.1.3.dist-info/METADATA +0 -125
  44. fluxfem-0.1.3.dist-info/RECORD +0 -47
  45. {fluxfem-0.1.3.dist-info → fluxfem-0.2.0.dist-info}/LICENSE +0 -0
  46. {fluxfem-0.1.3.dist-info → fluxfem-0.2.0.dist-info}/WHEEL +0 -0
fluxfem/mesh/dtypes.py ADDED
@@ -0,0 +1,12 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import numpy as np
4
+
5
+
6
+ def default_dtype() -> jnp.dtype:
7
+ return jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32
8
+
9
+
10
+ DEFAULT_DTYPE = default_dtype()
11
+ INDEX_DTYPE = jnp.int64
12
+ NP_INDEX_DTYPE = np.int64
fluxfem/mesh/hex.py CHANGED
@@ -5,9 +5,9 @@ from dataclasses import dataclass
5
5
  from typing import Optional, Dict, Tuple, List, Callable
6
6
  import jax
7
7
  import jax.numpy as jnp
8
- from ..core.dtypes import DEFAULT_DTYPE
9
8
  import numpy as np
10
9
 
10
+ from .dtypes import INDEX_DTYPE, default_dtype
11
11
  from .base import BaseMesh, BaseMeshPytree
12
12
 
13
13
 
@@ -17,10 +17,10 @@ class HexMesh(BaseMesh):
17
17
  Structured / unstructured hex mesh (8-node linear hex elements).
18
18
 
19
19
  coords: (n_nodes, 3) float32
20
- conn: (n_elems, 8) int32 # node indices of each element
20
+ conn: (n_elems, 8) int64 # node indices of each element
21
21
  """
22
22
  coords: jnp.ndarray # shape (n_nodes, 3)
23
- conn: jnp.ndarray # shape (n_elems, 8), int32
23
+ conn: jnp.ndarray # shape (n_elems, 8), int64
24
24
 
25
25
  def face_node_patterns(self):
26
26
  return [
@@ -36,6 +36,7 @@ class HexMesh(BaseMesh):
36
36
  @jax.tree_util.register_pytree_node_class
37
37
  @dataclass(eq=False)
38
38
  class HexMeshPytree(BaseMeshPytree):
39
+ """Hex mesh registered as a JAX pytree."""
39
40
  def face_node_patterns(self):
40
41
  return [
41
42
  (0, 1, 2, 3), # -z
@@ -58,8 +59,8 @@ def tag_axis_minmax_facets(
58
59
  Tag boundary facets on min/max of the given axis.
59
60
 
60
61
  Returns:
61
- facets: (n_facets, 4) int32, quad node ids
62
- facet_tags: (n_facets,) int32, dirichlet_tag on min side, neumann_tag on max side
62
+ facets: (n_facets, 4) int64, quad node ids
63
+ facet_tags: (n_facets,) int64, dirichlet_tag on min side, neumann_tag on max side
63
64
  """
64
65
  coords = np.asarray(mesh.coords)
65
66
  conn = np.asarray(mesh.conn)
@@ -98,7 +99,7 @@ def tag_axis_minmax_facets(
98
99
  facet_map[key] = (nodes, tag)
99
100
 
100
101
  if not facet_map:
101
- return jnp.empty((0, 4), dtype=jnp.int32), jnp.empty((0,), dtype=jnp.int32)
102
+ return jnp.empty((0, 4), dtype=INDEX_DTYPE), jnp.empty((0,), dtype=INDEX_DTYPE)
102
103
 
103
104
  facets = []
104
105
  tags = []
@@ -106,7 +107,7 @@ def tag_axis_minmax_facets(
106
107
  facets.append(nodes)
107
108
  tags.append(tag)
108
109
 
109
- return jnp.array(facets, dtype=jnp.int32), jnp.array(tags, dtype=jnp.int32)
110
+ return jnp.array(facets, dtype=INDEX_DTYPE), jnp.array(tags, dtype=INDEX_DTYPE)
110
111
 
111
112
 
112
113
  @dataclass
@@ -134,9 +135,10 @@ class StructuredHexBox:
134
135
  raise ValueError("order must be 1, 2, or 3")
135
136
 
136
137
  ox, oy, oz = self.origin
137
- xs = jnp.linspace(ox, ox + self.lx, self.nx + 1, dtype=DEFAULT_DTYPE)
138
- ys = jnp.linspace(oy, oy + self.ly, self.ny + 1, dtype=DEFAULT_DTYPE)
139
- zs = jnp.linspace(oz, oz + self.lz, self.nz + 1, dtype=DEFAULT_DTYPE)
138
+ dtype = default_dtype()
139
+ xs = jnp.linspace(ox, ox + self.lx, self.nx + 1, dtype=dtype)
140
+ ys = jnp.linspace(oy, oy + self.ly, self.ny + 1, dtype=dtype)
141
+ zs = jnp.linspace(oz, oz + self.lz, self.nz + 1, dtype=dtype)
140
142
 
141
143
  if self.order == 1:
142
144
  return self._build_hex8(xs, ys, zs)
@@ -150,7 +152,7 @@ class StructuredHexBox:
150
152
  for j in range(self.ny + 1):
151
153
  for i in range(self.nx + 1):
152
154
  coords_list.append([xs[i], ys[j], zs[k]])
153
- coords = jnp.array(coords_list, dtype=DEFAULT_DTYPE)
155
+ coords = jnp.array(coords_list, dtype=default_dtype())
154
156
 
155
157
  def node_id(i: int, j: int, k: int) -> int:
156
158
  return k * (self.ny + 1) * (self.nx + 1) + j * (self.nx + 1) + i
@@ -169,7 +171,7 @@ class StructuredHexBox:
169
171
  n011 = node_id(i, j + 1, k + 1)
170
172
  conn_list.append([n000, n100, n110, n010, n001, n101, n111, n011])
171
173
 
172
- conn = jnp.array(conn_list, dtype=jnp.int32)
174
+ conn = jnp.array(conn_list, dtype=INDEX_DTYPE)
173
175
  return HexMesh(coords=coords, conn=conn)
174
176
 
175
177
  def _build_hex20(self, xs, ys, zs) -> HexMesh:
@@ -240,8 +242,8 @@ class StructuredHexBox:
240
242
  ]
241
243
  )
242
244
 
243
- coords = jnp.array(coords_list, dtype=DEFAULT_DTYPE)
244
- conn = jnp.array(conn_list, dtype=jnp.int32)
245
+ coords = jnp.array(coords_list, dtype=default_dtype())
246
+ conn = jnp.array(conn_list, dtype=INDEX_DTYPE)
245
247
  return HexMesh(coords=coords, conn=conn)
246
248
 
247
249
  def _build_hex27(self, xs, ys, zs) -> HexMesh:
@@ -322,6 +324,6 @@ class StructuredHexBox:
322
324
  # order in lexicographic k,j,i -> length 27
323
325
  conn_list.append(nodes)
324
326
 
325
- coords = jnp.array(coords_list, dtype=DEFAULT_DTYPE)
326
- conn = jnp.array(conn_list, dtype=jnp.int32)
327
+ coords = jnp.array(coords_list, dtype=default_dtype())
328
+ conn = jnp.array(conn_list, dtype=INDEX_DTYPE)
327
329
  return HexMesh(coords=coords, conn=conn)
fluxfem/mesh/io.py CHANGED
@@ -4,6 +4,8 @@ import numpy as np
4
4
  import jax
5
5
  import jax.numpy as jnp
6
6
 
7
+ from .dtypes import NP_INDEX_DTYPE
8
+
7
9
  DTYPE = jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32
8
10
 
9
11
  try:
@@ -36,10 +38,10 @@ def load_gmsh_mesh(path: str):
36
38
 
37
39
  mesh = None
38
40
  if "hexahedron" in msh.cells_dict:
39
- conn = np.asarray(msh.cells_dict["hexahedron"], dtype=np.int32)
41
+ conn = np.asarray(msh.cells_dict["hexahedron"], dtype=NP_INDEX_DTYPE)
40
42
  mesh = HexMesh(jnp.asarray(coords), jnp.asarray(conn))
41
43
  elif "tetra" in msh.cells_dict:
42
- conn = np.asarray(msh.cells_dict["tetra"], dtype=np.int32)
44
+ conn = np.asarray(msh.cells_dict["tetra"], dtype=NP_INDEX_DTYPE)
43
45
  mesh = TetMesh(jnp.asarray(coords), jnp.asarray(conn))
44
46
  else:
45
47
  raise ValueError("gmsh mesh does not contain hexahedron or tetra cells")
@@ -50,7 +52,7 @@ def load_gmsh_mesh(path: str):
50
52
  cell_dict = msh.cells_dict
51
53
  for fkey in ("quad", "triangle", "tri"):
52
54
  if fkey in cell_dict:
53
- facets = np.asarray(cell_dict[fkey], dtype=np.int32)
55
+ facets = np.asarray(cell_dict[fkey], dtype=NP_INDEX_DTYPE)
54
56
  break
55
57
  tags_raw = None
56
58
  if facets is not None:
@@ -63,12 +65,13 @@ def load_gmsh_mesh(path: str):
63
65
  tags_raw = data
64
66
  break
65
67
  if tags_raw is not None:
66
- facet_tags = np.asarray(tags_raw, dtype=np.int32)
68
+ facet_tags = np.asarray(tags_raw, dtype=NP_INDEX_DTYPE)
67
69
 
68
70
  return mesh, facets, facet_tags
69
71
 
70
72
 
71
73
  def load_gmsh_hex_mesh(path: str):
74
+ """Load a Gmsh mesh and return a HexMesh with optional facets/tags."""
72
75
  mesh, facets, tags = load_gmsh_mesh(path)
73
76
  if not isinstance(mesh, HexMesh):
74
77
  raise ValueError("gmsh mesh is not hexahedral")
@@ -76,6 +79,7 @@ def load_gmsh_hex_mesh(path: str):
76
79
 
77
80
 
78
81
  def load_gmsh_tet_mesh(path: str):
82
+ """Load a Gmsh mesh and return a TetMesh with optional facets/tags."""
79
83
  mesh, facets, tags = load_gmsh_mesh(path)
80
84
  if not isinstance(mesh, TetMesh):
81
85
  raise ValueError("gmsh mesh is not tetrahedral")