fluxfem 0.1.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. fluxfem/__init__.py +343 -0
  2. fluxfem/core/__init__.py +316 -0
  3. fluxfem/core/assembly.py +788 -0
  4. fluxfem/core/basis.py +996 -0
  5. fluxfem/core/data.py +64 -0
  6. fluxfem/core/dtypes.py +4 -0
  7. fluxfem/core/forms.py +234 -0
  8. fluxfem/core/interp.py +55 -0
  9. fluxfem/core/solver.py +113 -0
  10. fluxfem/core/space.py +419 -0
  11. fluxfem/core/weakform.py +818 -0
  12. fluxfem/helpers_num.py +11 -0
  13. fluxfem/helpers_wf.py +42 -0
  14. fluxfem/mesh/__init__.py +29 -0
  15. fluxfem/mesh/base.py +244 -0
  16. fluxfem/mesh/hex.py +327 -0
  17. fluxfem/mesh/io.py +87 -0
  18. fluxfem/mesh/predicate.py +45 -0
  19. fluxfem/mesh/surface.py +257 -0
  20. fluxfem/mesh/tet.py +246 -0
  21. fluxfem/physics/__init__.py +53 -0
  22. fluxfem/physics/diffusion.py +18 -0
  23. fluxfem/physics/elasticity/__init__.py +39 -0
  24. fluxfem/physics/elasticity/hyperelastic.py +99 -0
  25. fluxfem/physics/elasticity/linear.py +58 -0
  26. fluxfem/physics/elasticity/materials.py +32 -0
  27. fluxfem/physics/elasticity/stress.py +46 -0
  28. fluxfem/physics/operators.py +109 -0
  29. fluxfem/physics/postprocess.py +113 -0
  30. fluxfem/solver/__init__.py +47 -0
  31. fluxfem/solver/bc.py +439 -0
  32. fluxfem/solver/cg.py +326 -0
  33. fluxfem/solver/dirichlet.py +126 -0
  34. fluxfem/solver/history.py +31 -0
  35. fluxfem/solver/newton.py +400 -0
  36. fluxfem/solver/result.py +62 -0
  37. fluxfem/solver/solve_runner.py +534 -0
  38. fluxfem/solver/solver.py +148 -0
  39. fluxfem/solver/sparse.py +188 -0
  40. fluxfem/tools/__init__.py +7 -0
  41. fluxfem/tools/jit.py +51 -0
  42. fluxfem/tools/timer.py +659 -0
  43. fluxfem/tools/visualizer.py +101 -0
  44. fluxfem-0.1.1a0.dist-info/METADATA +111 -0
  45. fluxfem-0.1.1a0.dist-info/RECORD +47 -0
  46. fluxfem-0.1.1a0.dist-info/WHEEL +4 -0
  47. fluxfem-0.1.1a0.dist-info/licenses/LICENSE +201 -0
fluxfem/core/space.py ADDED
@@ -0,0 +1,419 @@
1
+ from __future__ import annotations
2
+ import operator
3
+ from dataclasses import dataclass, field
4
+ from typing import Protocol
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import numpy as np
8
+
9
+ from ..mesh import (
10
+ BaseMesh,
11
+ BaseMeshPytree,
12
+ HexMesh,
13
+ HexMeshPytree,
14
+ TetMesh,
15
+ TetMeshPytree,
16
+ )
17
+ from .basis import (
18
+ Basis3D,
19
+ HexTriLinearBasis,
20
+ HexTriLinearBasisPytree,
21
+ HexSerendipityBasis20,
22
+ HexSerendipityBasis20Pytree,
23
+ HexTriQuadraticBasis27,
24
+ HexTriQuadraticBasis27Pytree,
25
+ TetLinearBasis,
26
+ TetLinearBasisPytree,
27
+ TetQuadraticBasis10,
28
+ TetQuadraticBasis10Pytree,
29
+ make_hex20_basis,
30
+ make_hex20_basis_pytree,
31
+ make_hex27_basis,
32
+ make_hex27_basis_pytree,
33
+ make_hex_basis,
34
+ make_hex_basis_pytree,
35
+ make_tet10_basis,
36
+ make_tet10_basis_pytree,
37
+ make_tet_basis,
38
+ make_tet_basis_pytree,
39
+ )
40
+ from .forms import (
41
+ ElementVector,
42
+ FormContext,
43
+ FormFieldLike,
44
+ ScalarFormField,
45
+ VectorFormField,
46
+ )
47
+ from .data import SpaceData
48
+
49
+
50
class FESpaceBase(Protocol):
    """
    Protocol for FE space objects used by assembly.

    This defines the minimal interface required by the core assembly routines:
    element-to-DOF connectivity, value dimension, DOF counts, and the ability
    to build per-element FormContext objects (test/trial fields plus
    quadrature data). The protocol is structural: implementations such as
    FESpaceClosure satisfy it without inheriting from it.

    ``n_dofs`` and ``n_ldofs`` are declared as read-only properties rather
    than plain attributes. Implementations that expose them as properties
    (FESpaceClosure does) and implementations that store them as ordinary
    attributes therefore both match the protocol under static type checking.
    """

    elem_dofs: jnp.ndarray  # element-to-DOF map, one row per element
    value_dim: int  # field components per node (1 = scalar)

    @property
    def n_dofs(self) -> int:
        """Total number of global degrees of freedom."""
        ...

    @property
    def n_ldofs(self) -> int:
        """Number of local degrees of freedom per element."""
        ...

    def build_form_contexts(self, dep: jnp.ndarray | None = None) -> FormContext:
        """Build the per-element FormContext (test/trial fields, quadrature data)."""
        ...
64
+
65
+
66
@dataclass(eq=False)
class FESpaceClosure:
    """
    Finite element space built from a mesh, basis, and element DOF map.

    This is the standard space used by fluxfem. It bundles:
    - a mesh (geometry and connectivity),
    - a basis (shape functions + quadrature),
    - an element-to-DOF map (``elem_dofs``),
    - and metadata such as ``value_dim`` and cached sparsity patterns.

    The class provides thin wrappers around the functional assembly helpers in
    ``.assembly`` and constructs FormContext objects for element-level
    integration. ``eq=False`` keeps identity-based equality and hashing.

    Raises
    ------
    ValueError
        On construction, when ``elem_dofs.shape[1]`` differs from the number
        of nodes per element (``mesh.element_coords().shape[1]``) times
        ``value_dim``.
    TypeError
        On construction, when ``value_dim`` is not integer-like
        (``operator.index`` rejects it, e.g. for floats).
    """
    mesh: BaseMesh
    basis: Basis3D
    elem_dofs: jnp.ndarray  # (n_elems, n_ldofs) int32
    value_dim: int = 1  # 1=scalar, 3=vector, etc.
    # Optional explicit counts; inferred from elem_dofs in __post_init__ when None.
    _n_dofs: int | None = None
    _n_ldofs: int | None = None
    # Built with SpaceData.from_space(self) in __post_init__ when not supplied.
    data: SpaceData | None = None
    # Sparsity patterns keyed by the ``with_idx`` flag of get_sparsity_pattern.
    _pattern_cache: dict[bool, object] = field(default_factory=dict, repr=False)

    def __post_init__(self):
        # Ensure value_dim is a Python int (avoid tracers).
        self.value_dim = operator.index(self.value_dim)

        if self._n_ldofs is None:
            self._n_ldofs = int(self.elem_dofs.shape[1])
        if self._n_dofs is None:
            dofs = np.asarray(self.elem_dofs)
            # An empty DOF map (no elements) has no DOFs; .max() would raise on it.
            self._n_dofs = int(dofs.max()) + 1 if dofs.size else 0

        n_nodes = int(self.mesh.element_coords().shape[1])
        expected = n_nodes * self.value_dim
        if self._n_ldofs != expected:
            raise ValueError(
                f"n_ldofs mismatch: elem_dofs has {self._n_ldofs}, "
                f"but n_nodes({n_nodes})*value_dim({self.value_dim})={expected}"
            )

        if self.data is None:
            self.data = SpaceData.from_space(self)

    @property
    def n_dofs(self) -> int:
        """Total number of global DOFs (given explicitly or inferred as max(elem_dofs) + 1)."""
        assert self._n_dofs is not None
        return self._n_dofs

    @property
    def n_ldofs(self) -> int:
        """Number of local DOFs per element (``elem_dofs.shape[1]`` unless given)."""
        assert self._n_ldofs is not None
        return self._n_ldofs

    @staticmethod
    def _tie_in(dep, value):
        """
        Return ``value``, tied to ``dep`` with ``jax.lax.tie_in`` when available.

        When ``dep`` is None, or the installed JAX no longer provides
        ``jax.lax.tie_in``, ``value`` is returned unchanged. ``dep`` may have
        any shape (it is never combined arithmetically with ``value``), so this
        never raises a broadcasting error.
        """
        if dep is None:
            return value
        tie_in = getattr(jax.lax, "tie_in", None)
        if tie_in is None:
            return value
        return tie_in(dep, value)

    def build_form_contexts(self, dep: jnp.ndarray | None = None) -> FormContext:
        """
        Build the batched FormContext used for element-level integration.

        Parameters
        ----------
        dep : array, optional
            Value the geometric arrays are optionally tied to (see ``_tie_in``).
            It may have any shape; ``None`` skips tying.

        Returns
        -------
        FormContext
            ``test`` and ``trial`` fields vmapped over elements (ScalarFormField
            when ``value_dim == 1``, VectorFormField otherwise), physical
            quadrature points ``x_q`` of shape (n_elems, n_q, 3), weights ``w``
            of shape (n_elems, n_q), and ``elem_id = arange(n_elems)``.
        """
        vd = int(self.value_dim)
        mesh, basis = self.mesh, self.basis
        elem_coords = mesh.element_coords()  # (n_elems, n_nodes, 3)
        elem_coords = self._tie_in(dep, elem_coords)

        N_ref = basis.shape_functions()  # (n_q, n_nodes)
        w_ref = basis.quad_weights  # (n_q,)
        # Physical quadrature points: x_q[e, q] = sum_a N_ref[q, a] * X[e, a].
        x_q = jnp.einsum("qa,eai->eqi", N_ref, elem_coords)  # (n_elems, n_q, 3)
        # Reference weights replicated per element; no Jacobian scaling is applied here.
        w = jnp.broadcast_to(
            w_ref[None, :], (elem_coords.shape[0], w_ref.shape[0])
        )
        w = self._tie_in(dep, w)

        if vd == 1:
            def make_field(Xe):
                return ScalarFormField(N=N_ref, elem_coords=Xe, basis=basis)
        else:
            def make_field(Xe):
                return VectorFormField(
                    N=N_ref, elem_coords=Xe, basis=basis, value_dim=vd
                )

        test = jax.vmap(make_field)(elem_coords)
        trial = jax.vmap(make_field)(elem_coords)

        return FormContext(
            test=test, trial=trial, x_q=x_q,
            w=w, elem_id=jnp.arange(elem_coords.shape[0])
        )

    # --- Thin wrappers over functional assembly APIs (kept functional for JAX friendliness) ---
    def assemble_bilinear_form(self, form, params, *, chunk_size=None, dep=None, **kwargs):
        """Assemble a bilinear form; injects the cached sparsity pattern unless a non-None ``pattern`` is given."""
        from .assembly import assemble_bilinear_form
        if kwargs.get("pattern") is None:
            kwargs["pattern"] = self.get_sparsity_pattern(with_idx=True)
        return assemble_bilinear_form(self, form, params, chunk_size=chunk_size, dep=dep, **kwargs)

    def assemble_linear_form(self, form, params, *, chunk_size=None, dep=None, **kwargs):
        """Delegate to ``assembly.assemble_linear_form`` for this space."""
        from .assembly import assemble_linear_form
        return assemble_linear_form(self, form, params, chunk_size=chunk_size, dep=dep, **kwargs)

    def assemble_functional(self, form, params):
        """Delegate to ``assembly.assemble_functional`` for this space."""
        from .assembly import assemble_functional
        return assemble_functional(self, form, params)

    def assemble_mass_matrix(self, *, chunk_size=None, **kwargs):
        """Delegate to ``assembly.assemble_mass_matrix`` for this space."""
        from .assembly import assemble_mass_matrix
        return assemble_mass_matrix(self, chunk_size=chunk_size, **kwargs)

    def assemble_bilinear_dense(self, kernel, params, **kwargs):
        """Delegate to ``assembly.assemble_bilinear_dense`` for this space."""
        from .assembly import assemble_bilinear_dense
        return assemble_bilinear_dense(self, kernel, params, **kwargs)

    def assemble_residual(self, res_form, u, params, **kwargs):
        """Delegate to ``assembly.assemble_residual`` for this space."""
        from .assembly import assemble_residual
        return assemble_residual(self, res_form, u, params, **kwargs)

    def assemble_jacobian(self, res_form, u, params, **kwargs):
        """Delegate to ``assembly.assemble_jacobian`` for this space."""
        from .assembly import assemble_jacobian
        return assemble_jacobian(self, res_form, u, params, **kwargs)

    def get_sparsity_pattern(self, *, with_idx: bool = True):
        """Return ``make_sparsity_pattern(self, with_idx=...)``, memoized per ``with_idx`` on this instance."""
        cached = self._pattern_cache.get(with_idx)
        if cached is not None:
            return cached
        from .assembly import make_sparsity_pattern
        pat = make_sparsity_pattern(self, with_idx=with_idx)
        self._pattern_cache[with_idx] = pat
        return pat
197
+
198
+
199
@jax.tree_util.register_pytree_node_class
class FESpacePytree(FESpaceClosure):
    """
    FESpaceClosure that JAX can flatten and rebuild as a pytree.

    Use it when a space has to cross jit/vmap boundaries or live inside
    another pytree. The leaves are (mesh, basis, elem_dofs); ``value_dim``,
    ``_n_dofs`` and ``_n_ldofs`` ride along as auxiliary data.

    Note: ``data`` and the sparsity-pattern cache are not part of the
    flattened representation. ``tree_unflatten`` calls the constructor, so
    ``__post_init__`` runs again and, because ``data`` is None, rebuilds it
    with ``SpaceData.from_space``.
    """

    def tree_flatten(self):
        aux_data = dict(
            value_dim=int(self.value_dim),
            _n_dofs=self._n_dofs,
            _n_ldofs=self._n_ldofs,
        )
        leaves = (self.mesh, self.basis, self.elem_dofs)
        return leaves, aux_data

    @classmethod
    def tree_unflatten(cls, aux, children):
        mesh, basis, elem_dofs = children
        counts = {
            "_n_dofs": aux.get("_n_dofs"),
            "_n_ldofs": aux.get("_n_ldofs"),
        }
        return cls(
            mesh=mesh,
            basis=basis,
            elem_dofs=elem_dofs,
            value_dim=aux["value_dim"],
            **counts,
        )
228
+
229
+
230
# Public name of the default (non-pytree) FE space class. make_space and the
# make_*_space factories in this module construct instances through this alias.
FESpace = FESpaceClosure
231
+
232
+
233
def make_space(
    mesh: BaseMesh,
    basis: Basis3D,
    element: ElementVector | None = None,
) -> FESpace:
    """
    Construct an FESpace for ``mesh`` and ``basis``.

    Without ``element`` the space is scalar: the element DOF map is the mesh
    connectivity itself and ``value_dim`` is 1. With ``ElementVector(dim)``
    the connectivity is expanded via ``element.dof_map`` and ``value_dim`` is
    ``int(element.dim)``. The DOF map is stored as an int32 JAX array.
    """
    if element is not None:
        dof_map, n_components = element.dof_map(mesh.conn), int(element.dim)
    else:
        dof_map, n_components = mesh.conn, 1

    dof_map = jnp.asarray(dof_map, dtype=jnp.int32)
    return FESpace(mesh=mesh, basis=basis, elem_dofs=dof_map, value_dim=n_components)
257
+
258
+
259
def _mesh_to_pytree(mesh: BaseMesh) -> BaseMeshPytree:
    """
    Return a pytree-registered version of ``mesh``.

    HexMeshPytree/TetMeshPytree instances are returned unchanged. HexMesh and
    TetMesh (checked in that order) are rebuilt as their pytree counterparts
    from ``coords``, ``conn``, ``cell_tags`` and ``node_tags`` only.

    Raises
    ------
    TypeError
        If ``mesh`` is none of the supported mesh types.
    """
    if isinstance(mesh, (HexMeshPytree, TetMeshPytree)):
        return mesh

    for plain_cls, pytree_cls in ((HexMesh, HexMeshPytree), (TetMesh, TetMeshPytree)):
        if isinstance(mesh, plain_cls):
            return pytree_cls(
                coords=mesh.coords,
                conn=mesh.conn,
                cell_tags=mesh.cell_tags,
                node_tags=mesh.node_tags,
            )

    raise TypeError(f"Unsupported mesh type for pytree: {type(mesh)}")
277
+
278
+
279
def _basis_to_pytree(basis):
    """
    Return a pytree-registered version of ``basis``.

    Bases that already are one of the *Pytree basis classes are returned
    unchanged. Otherwise the first matching plain basis class (in the order
    listed below) is converted to its pytree counterpart, built from the
    basis's ``quad_points`` and ``quad_weights``.

    Raises
    ------
    TypeError
        If ``basis`` is none of the supported basis types.
    """
    conversions = (
        (HexTriLinearBasis, HexTriLinearBasisPytree),
        (HexSerendipityBasis20, HexSerendipityBasis20Pytree),
        (HexTriQuadraticBasis27, HexTriQuadraticBasis27Pytree),
        (TetLinearBasis, TetLinearBasisPytree),
        (TetQuadraticBasis10, TetQuadraticBasis10Pytree),
    )

    already_pytree = tuple(pytree_cls for _, pytree_cls in conversions)
    if isinstance(basis, already_pytree):
        return basis

    for plain_cls, pytree_cls in conversions:
        if isinstance(basis, plain_cls):
            return pytree_cls(basis.quad_points, basis.quad_weights)

    raise TypeError(f"Unsupported basis type for pytree: {type(basis)}")
302
+
303
+
304
def make_space_pytree(
    mesh: BaseMesh | BaseMeshPytree,
    basis: Basis3D,
    element: ElementVector | None = None,
) -> FESpacePytree:
    """
    Build a pytree-compatible FE space.

    ``mesh`` and ``basis`` are converted with ``_mesh_to_pytree`` and
    ``_basis_to_pytree``. Objects that already are pytrees pass through
    unchanged, so plain HexMesh/TetMesh meshes are accepted as well as their
    pytree variants.

    element=None → one scalar DOF per node (elem_dofs = mesh.conn), value_dim=1
    element=ElementVector(dim) → DOFs from element.dof_map(mesh.conn),
    value_dim=int(element.dim)

    Raises
    ------
    TypeError
        Propagated from the conversion helpers for unsupported mesh or basis types.
    """
    if element is None:
        elem_dofs = mesh.conn
        value_dim = 1
    else:
        elem_dofs = element.dof_map(mesh.conn)
        value_dim = int(element.dim)

    mesh_py = _mesh_to_pytree(mesh)
    basis_py = _basis_to_pytree(basis)

    return FESpacePytree(
        mesh=mesh_py,
        basis=basis_py,
        elem_dofs=jnp.asarray(elem_dofs, dtype=jnp.int32),
        value_dim=value_dim,
    )
326
+
327
+
328
def make_tet10_space(
    mesh: TetMesh, dim: int = 1, intorder: int = 2
) -> FESpace:
    """
    Build a quadratic (10-node) tetrahedral space on ``mesh``.

    ``dim == 1`` yields a scalar space; any other value uses
    ``ElementVector(dim)``. ``intorder`` is forwarded to ``make_tet10_basis``.
    """
    tet10_basis = make_tet10_basis(intorder)
    vector_element = ElementVector(dim) if dim != 1 else None
    return make_space(mesh, tet10_basis, element=vector_element)
335
+
336
+
337
def make_tet10_space_pytree(
    mesh: TetMesh, dim: int = 1, intorder: int = 2
) -> FESpacePytree:
    """
    Pytree variant of ``make_tet10_space`` (quadratic 10-node tetrahedra).

    The basis comes from ``make_tet10_basis_pytree(intorder)``; ``dim == 1``
    gives a scalar space, otherwise ``ElementVector(dim)`` is used.
    """
    pytree_basis = make_tet10_basis_pytree(intorder)
    vector_element = ElementVector(dim) if dim != 1 else None
    return make_space_pytree(mesh, pytree_basis, element=vector_element)
344
+
345
+
346
def make_hex_space(
    mesh: HexMesh, dim: int = 1, intorder: int = 2
) -> FESpace:
    """
    Build a trilinear (8-node) hexahedral space on ``mesh``.

    ``dim == 1`` yields a scalar space; any other value uses
    ``ElementVector(dim)``. ``intorder`` is forwarded to ``make_hex_basis``.
    """
    hex8_basis = make_hex_basis(intorder)
    vector_element = ElementVector(dim) if dim != 1 else None
    return make_space(mesh, hex8_basis, element=vector_element)
351
+
352
+
353
def make_hex_space_pytree(
    mesh: HexMesh, dim: int = 1, intorder: int = 2
) -> FESpacePytree:
    """
    Pytree variant of ``make_hex_space`` (trilinear 8-node hexahedra).

    The basis comes from ``make_hex_basis_pytree(intorder)``; ``dim == 1``
    gives a scalar space, otherwise ``ElementVector(dim)`` is used.
    """
    pytree_basis = make_hex_basis_pytree(intorder)
    vector_element = ElementVector(dim) if dim != 1 else None
    return make_space_pytree(mesh, pytree_basis, element=vector_element)
360
+
361
+
362
def make_hex20_space(
    mesh: HexMesh, dim: int = 1, intorder: int = 2
) -> FESpace:
    """
    Build a serendipity (20-node) hexahedral space on ``mesh``.

    ``dim == 1`` yields a scalar space; any other value uses
    ``ElementVector(dim)``. ``intorder`` is forwarded to ``make_hex20_basis``.
    """
    hex20_basis = make_hex20_basis(intorder)
    vector_element = ElementVector(dim) if dim != 1 else None
    return make_space(mesh, hex20_basis, element=vector_element)
369
+
370
+
371
def make_hex20_space_pytree(
    mesh: HexMesh, dim: int = 1, intorder: int = 2
) -> FESpacePytree:
    """
    Pytree variant of ``make_hex20_space`` (serendipity 20-node hexahedra).

    The basis comes from ``make_hex20_basis_pytree(intorder)``; ``dim == 1``
    gives a scalar space, otherwise ``ElementVector(dim)`` is used.
    """
    pytree_basis = make_hex20_basis_pytree(intorder)
    vector_element = ElementVector(dim) if dim != 1 else None
    return make_space_pytree(mesh, pytree_basis, element=vector_element)
378
+
379
+
380
def make_hex27_space(
    mesh: HexMesh, dim: int = 1, intorder: int = 3
) -> FESpace:
    """
    Build a triquadratic (27-node) hexahedral space on ``mesh``.

    ``dim == 1`` yields a scalar space; any other value uses
    ``ElementVector(dim)``. ``intorder`` (default 3) is forwarded to
    ``make_hex27_basis``.
    """
    hex27_basis = make_hex27_basis(intorder)
    vector_element = ElementVector(dim) if dim != 1 else None
    return make_space(mesh, hex27_basis, element=vector_element)
387
+
388
+
389
def make_hex27_space_pytree(
    mesh: HexMesh, dim: int = 1, intorder: int = 3
) -> FESpacePytree:
    """
    Pytree variant of ``make_hex27_space`` (triquadratic 27-node hexahedra).

    The basis comes from ``make_hex27_basis_pytree(intorder)``; ``dim == 1``
    gives a scalar space, otherwise ``ElementVector(dim)`` is used.
    """
    pytree_basis = make_hex27_basis_pytree(intorder)
    vector_element = ElementVector(dim) if dim != 1 else None
    return make_space_pytree(mesh, pytree_basis, element=vector_element)
396
+
397
+
398
def make_tet_space(
    mesh: TetMesh, dim: int = 1, intorder: int = 2
) -> FESpace:
    """
    Build a tetrahedral space whose basis matches the mesh's nodes per element.

    Connectivity with 10 nodes per element selects the quadratic basis; for
    it, an ``intorder`` of 1 or less is replaced by 2. Any other node count
    selects the linear basis built with ``intorder`` as given. ``dim == 1``
    yields a scalar space; any other value uses ``ElementVector(dim)``.
    """
    nodes_per_elem = mesh.conn.shape[1]
    if nodes_per_elem != 10:
        basis = make_tet_basis(intorder)
    else:
        quad_order = intorder if intorder > 1 else 2
        basis = make_tet10_basis(quad_order)

    vector_element = ElementVector(dim) if dim != 1 else None
    return make_space(mesh, basis, element=vector_element)
407
+
408
+
409
def make_tet_space_pytree(
    mesh: TetMesh, dim: int = 1, intorder: int = 2
) -> FESpacePytree:
    """
    Pytree variant of ``make_tet_space``.

    A 10-node connectivity selects ``make_tet10_basis_pytree`` (an ``intorder``
    of 1 or less is replaced by 2). Any other node count selects
    ``make_tet_basis_pytree(intorder)``. ``dim == 1`` yields a scalar space;
    any other value uses ``ElementVector(dim)``.
    """
    nodes_per_elem = mesh.conn.shape[1]
    if nodes_per_elem != 10:
        basis = make_tet_basis_pytree(intorder)
    else:
        quad_order = intorder if intorder > 1 else 2
        basis = make_tet10_basis_pytree(quad_order)

    vector_element = ElementVector(dim) if dim != 1 else None
    return make_space_pytree(mesh, basis, element=vector_element)