fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. fluxfem/__init__.py +69 -13
  2. fluxfem/core/__init__.py +140 -53
  3. fluxfem/core/assembly.py +691 -97
  4. fluxfem/core/basis.py +75 -54
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/dtypes.py +9 -1
  7. fluxfem/core/forms.py +10 -0
  8. fluxfem/core/mixed_assembly.py +263 -0
  9. fluxfem/core/mixed_space.py +382 -0
  10. fluxfem/core/mixed_weakform.py +97 -0
  11. fluxfem/core/solver.py +2 -0
  12. fluxfem/core/space.py +315 -30
  13. fluxfem/core/weakform.py +821 -42
  14. fluxfem/helpers_wf.py +49 -0
  15. fluxfem/mesh/__init__.py +54 -2
  16. fluxfem/mesh/base.py +318 -9
  17. fluxfem/mesh/contact.py +841 -0
  18. fluxfem/mesh/dtypes.py +12 -0
  19. fluxfem/mesh/hex.py +17 -16
  20. fluxfem/mesh/io.py +9 -6
  21. fluxfem/mesh/mortar.py +3970 -0
  22. fluxfem/mesh/supermesh.py +318 -0
  23. fluxfem/mesh/surface.py +104 -26
  24. fluxfem/mesh/tet.py +16 -7
  25. fluxfem/physics/diffusion.py +3 -0
  26. fluxfem/physics/elasticity/hyperelastic.py +35 -3
  27. fluxfem/physics/elasticity/linear.py +22 -4
  28. fluxfem/physics/elasticity/stress.py +9 -5
  29. fluxfem/physics/operators.py +12 -5
  30. fluxfem/physics/postprocess.py +29 -3
  31. fluxfem/solver/__init__.py +47 -2
  32. fluxfem/solver/bc.py +38 -2
  33. fluxfem/solver/block_matrix.py +284 -0
  34. fluxfem/solver/block_system.py +477 -0
  35. fluxfem/solver/cg.py +150 -55
  36. fluxfem/solver/dirichlet.py +358 -5
  37. fluxfem/solver/history.py +15 -3
  38. fluxfem/solver/newton.py +260 -70
  39. fluxfem/solver/petsc.py +445 -0
  40. fluxfem/solver/preconditioner.py +109 -0
  41. fluxfem/solver/result.py +18 -0
  42. fluxfem/solver/solve_runner.py +208 -23
  43. fluxfem/solver/solver.py +35 -12
  44. fluxfem/solver/sparse.py +149 -15
  45. fluxfem/tools/jit.py +19 -7
  46. fluxfem/tools/timer.py +14 -12
  47. fluxfem/tools/visualizer.py +16 -4
  48. fluxfem-0.2.1.dist-info/METADATA +314 -0
  49. fluxfem-0.2.1.dist-info/RECORD +59 -0
  50. fluxfem-0.1.4.dist-info/METADATA +0 -127
  51. fluxfem-0.1.4.dist-info/RECORD +0 -48
  52. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  53. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/mesh/dtypes.py ADDED
@@ -0,0 +1,12 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import numpy as np
4
+
5
+
6
+ def default_dtype() -> jnp.dtype:
7
+ return jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32
8
+
9
+
10
+ DEFAULT_DTYPE = default_dtype()
11
+ INDEX_DTYPE = jnp.int64
12
+ NP_INDEX_DTYPE = np.int64
fluxfem/mesh/hex.py CHANGED
@@ -5,9 +5,9 @@ from dataclasses import dataclass
5
5
  from typing import Optional, Dict, Tuple, List, Callable
6
6
  import jax
7
7
  import jax.numpy as jnp
8
- from ..core.dtypes import DEFAULT_DTYPE
9
8
  import numpy as np
10
9
 
10
+ from .dtypes import INDEX_DTYPE, default_dtype
11
11
  from .base import BaseMesh, BaseMeshPytree
12
12
 
13
13
 
@@ -17,10 +17,10 @@ class HexMesh(BaseMesh):
17
17
  Structured / unstructured hex mesh (8-node linear hex elements).
18
18
 
19
19
  coords: (n_nodes, 3) float32
20
- conn: (n_elems, 8) int32 # node indices of each element
20
+ conn: (n_elems, 8) int64 # node indices of each element
21
21
  """
22
22
  coords: jnp.ndarray # shape (n_nodes, 3)
23
- conn: jnp.ndarray # shape (n_elems, 8), int32
23
+ conn: jnp.ndarray # shape (n_elems, 8), int64
24
24
 
25
25
  def face_node_patterns(self):
26
26
  return [
@@ -59,8 +59,8 @@ def tag_axis_minmax_facets(
59
59
  Tag boundary facets on min/max of the given axis.
60
60
 
61
61
  Returns:
62
- facets: (n_facets, 4) int32, quad node ids
63
- facet_tags: (n_facets,) int32, dirichlet_tag on min side, neumann_tag on max side
62
+ facets: (n_facets, 4) int64, quad node ids
63
+ facet_tags: (n_facets,) int64, dirichlet_tag on min side, neumann_tag on max side
64
64
  """
65
65
  coords = np.asarray(mesh.coords)
66
66
  conn = np.asarray(mesh.conn)
@@ -99,7 +99,7 @@ def tag_axis_minmax_facets(
99
99
  facet_map[key] = (nodes, tag)
100
100
 
101
101
  if not facet_map:
102
- return jnp.empty((0, 4), dtype=jnp.int32), jnp.empty((0,), dtype=jnp.int32)
102
+ return jnp.empty((0, 4), dtype=INDEX_DTYPE), jnp.empty((0,), dtype=INDEX_DTYPE)
103
103
 
104
104
  facets = []
105
105
  tags = []
@@ -107,7 +107,7 @@ def tag_axis_minmax_facets(
107
107
  facets.append(nodes)
108
108
  tags.append(tag)
109
109
 
110
- return jnp.array(facets, dtype=jnp.int32), jnp.array(tags, dtype=jnp.int32)
110
+ return jnp.array(facets, dtype=INDEX_DTYPE), jnp.array(tags, dtype=INDEX_DTYPE)
111
111
 
112
112
 
113
113
  @dataclass
@@ -135,9 +135,10 @@ class StructuredHexBox:
135
135
  raise ValueError("order must be 1, 2, or 3")
136
136
 
137
137
  ox, oy, oz = self.origin
138
- xs = jnp.linspace(ox, ox + self.lx, self.nx + 1, dtype=DEFAULT_DTYPE)
139
- ys = jnp.linspace(oy, oy + self.ly, self.ny + 1, dtype=DEFAULT_DTYPE)
140
- zs = jnp.linspace(oz, oz + self.lz, self.nz + 1, dtype=DEFAULT_DTYPE)
138
+ dtype = default_dtype()
139
+ xs = jnp.linspace(ox, ox + self.lx, self.nx + 1, dtype=dtype)
140
+ ys = jnp.linspace(oy, oy + self.ly, self.ny + 1, dtype=dtype)
141
+ zs = jnp.linspace(oz, oz + self.lz, self.nz + 1, dtype=dtype)
141
142
 
142
143
  if self.order == 1:
143
144
  return self._build_hex8(xs, ys, zs)
@@ -151,7 +152,7 @@ class StructuredHexBox:
151
152
  for j in range(self.ny + 1):
152
153
  for i in range(self.nx + 1):
153
154
  coords_list.append([xs[i], ys[j], zs[k]])
154
- coords = jnp.array(coords_list, dtype=DEFAULT_DTYPE)
155
+ coords = jnp.array(coords_list, dtype=default_dtype())
155
156
 
156
157
  def node_id(i: int, j: int, k: int) -> int:
157
158
  return k * (self.ny + 1) * (self.nx + 1) + j * (self.nx + 1) + i
@@ -170,7 +171,7 @@ class StructuredHexBox:
170
171
  n011 = node_id(i, j + 1, k + 1)
171
172
  conn_list.append([n000, n100, n110, n010, n001, n101, n111, n011])
172
173
 
173
- conn = jnp.array(conn_list, dtype=jnp.int32)
174
+ conn = jnp.array(conn_list, dtype=INDEX_DTYPE)
174
175
  return HexMesh(coords=coords, conn=conn)
175
176
 
176
177
  def _build_hex20(self, xs, ys, zs) -> HexMesh:
@@ -241,8 +242,8 @@ class StructuredHexBox:
241
242
  ]
242
243
  )
243
244
 
244
- coords = jnp.array(coords_list, dtype=DEFAULT_DTYPE)
245
- conn = jnp.array(conn_list, dtype=jnp.int32)
245
+ coords = jnp.array(coords_list, dtype=default_dtype())
246
+ conn = jnp.array(conn_list, dtype=INDEX_DTYPE)
246
247
  return HexMesh(coords=coords, conn=conn)
247
248
 
248
249
  def _build_hex27(self, xs, ys, zs) -> HexMesh:
@@ -323,6 +324,6 @@ class StructuredHexBox:
323
324
  # order in lexicographic k,j,i -> length 27
324
325
  conn_list.append(nodes)
325
326
 
326
- coords = jnp.array(coords_list, dtype=DEFAULT_DTYPE)
327
- conn = jnp.array(conn_list, dtype=jnp.int32)
327
+ coords = jnp.array(coords_list, dtype=default_dtype())
328
+ conn = jnp.array(conn_list, dtype=INDEX_DTYPE)
328
329
  return HexMesh(coords=coords, conn=conn)
fluxfem/mesh/io.py CHANGED
@@ -3,6 +3,9 @@ from __future__ import annotations
3
3
  import numpy as np
4
4
  import jax
5
5
  import jax.numpy as jnp
6
+ from typing import Optional
7
+
8
+ from .dtypes import NP_INDEX_DTYPE
6
9
 
7
10
  DTYPE = jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32
8
11
 
@@ -10,7 +13,7 @@ try:
10
13
  import meshio
11
14
  except Exception as e: # pragma: no cover
12
15
  meshio = None
13
- meshio_import_error = e
16
+ meshio_import_error: Optional[Exception] = e
14
17
  else:
15
18
  meshio_import_error = None
16
19
 
@@ -34,12 +37,12 @@ def load_gmsh_mesh(path: str):
34
37
  msh = meshio.read(path)
35
38
  coords = np.asarray(msh.points[:, :3], dtype=DTYPE)
36
39
 
37
- mesh = None
40
+ mesh: HexMesh | TetMesh | None = None
38
41
  if "hexahedron" in msh.cells_dict:
39
- conn = np.asarray(msh.cells_dict["hexahedron"], dtype=np.int32)
42
+ conn = np.asarray(msh.cells_dict["hexahedron"], dtype=NP_INDEX_DTYPE)
40
43
  mesh = HexMesh(jnp.asarray(coords), jnp.asarray(conn))
41
44
  elif "tetra" in msh.cells_dict:
42
- conn = np.asarray(msh.cells_dict["tetra"], dtype=np.int32)
45
+ conn = np.asarray(msh.cells_dict["tetra"], dtype=NP_INDEX_DTYPE)
43
46
  mesh = TetMesh(jnp.asarray(coords), jnp.asarray(conn))
44
47
  else:
45
48
  raise ValueError("gmsh mesh does not contain hexahedron or tetra cells")
@@ -50,7 +53,7 @@ def load_gmsh_mesh(path: str):
50
53
  cell_dict = msh.cells_dict
51
54
  for fkey in ("quad", "triangle", "tri"):
52
55
  if fkey in cell_dict:
53
- facets = np.asarray(cell_dict[fkey], dtype=np.int32)
56
+ facets = np.asarray(cell_dict[fkey], dtype=NP_INDEX_DTYPE)
54
57
  break
55
58
  tags_raw = None
56
59
  if facets is not None:
@@ -63,7 +66,7 @@ def load_gmsh_mesh(path: str):
63
66
  tags_raw = data
64
67
  break
65
68
  if tags_raw is not None:
66
- facet_tags = np.asarray(tags_raw, dtype=np.int32)
69
+ facet_tags = np.asarray(tags_raw, dtype=NP_INDEX_DTYPE)
67
70
 
68
71
  return mesh, facets, facet_tags
69
72