fluxfem 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +1 -13
- fluxfem/core/__init__.py +53 -71
- fluxfem/core/assembly.py +41 -32
- fluxfem/core/basis.py +2 -2
- fluxfem/core/context_types.py +36 -12
- fluxfem/core/mixed_space.py +42 -8
- fluxfem/core/mixed_weakform.py +1 -1
- fluxfem/core/space.py +68 -28
- fluxfem/core/weakform.py +95 -77
- fluxfem/mesh/base.py +3 -3
- fluxfem/mesh/contact.py +33 -17
- fluxfem/mesh/io.py +3 -2
- fluxfem/mesh/mortar.py +106 -43
- fluxfem/mesh/supermesh.py +2 -0
- fluxfem/mesh/surface.py +82 -22
- fluxfem/mesh/tet.py +7 -4
- fluxfem/physics/elasticity/hyperelastic.py +32 -3
- fluxfem/physics/elasticity/linear.py +13 -2
- fluxfem/physics/elasticity/stress.py +9 -5
- fluxfem/physics/operators.py +12 -5
- fluxfem/physics/postprocess.py +29 -3
- fluxfem/solver/__init__.py +6 -1
- fluxfem/solver/block_matrix.py +165 -13
- fluxfem/solver/block_system.py +52 -29
- fluxfem/solver/cg.py +43 -30
- fluxfem/solver/dirichlet.py +35 -12
- fluxfem/solver/history.py +15 -3
- fluxfem/solver/newton.py +25 -12
- fluxfem/solver/petsc.py +13 -7
- fluxfem/solver/preconditioner.py +7 -4
- fluxfem/solver/solve_runner.py +42 -24
- fluxfem/solver/solver.py +23 -11
- fluxfem/solver/sparse.py +32 -13
- fluxfem/tools/jit.py +19 -7
- fluxfem/tools/timer.py +14 -12
- fluxfem/tools/visualizer.py +16 -4
- {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/METADATA +18 -7
- fluxfem-0.2.1.dist-info/RECORD +59 -0
- fluxfem-0.2.0.dist-info/RECORD +0 -59
- {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
- {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/core/space.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
import operator
|
|
3
3
|
from dataclasses import dataclass, field
|
|
4
|
-
from typing import Any, Callable, Protocol, TYPE_CHECKING, TypeVar
|
|
4
|
+
from typing import Any, Callable, Protocol, TYPE_CHECKING, TypeAlias, TypeVar, cast
|
|
5
5
|
import jax
|
|
6
6
|
import jax.numpy as jnp
|
|
7
7
|
import numpy as np
|
|
@@ -16,15 +16,39 @@ if TYPE_CHECKING:
|
|
|
16
16
|
ElementLinearKernel,
|
|
17
17
|
ElementResidualKernel,
|
|
18
18
|
Kernel,
|
|
19
|
+
FormKernel,
|
|
20
|
+
BilinearReturn,
|
|
21
|
+
JacobianReturn,
|
|
22
|
+
LinearReturn,
|
|
23
|
+
MassReturn,
|
|
19
24
|
ResidualForm,
|
|
20
25
|
)
|
|
26
|
+
from .weakform import BilinearForm, LinearForm
|
|
27
|
+
from ..solver import FluxSparseMatrix, SparsityPattern
|
|
21
28
|
else:
|
|
22
29
|
Kernel = Callable[..., Any]
|
|
30
|
+
FormKernel = Callable[..., Any]
|
|
23
31
|
ResidualForm = Callable[..., Any]
|
|
24
32
|
ElementBilinearKernel = Callable[..., Any]
|
|
25
33
|
ElementLinearKernel = Callable[..., Any]
|
|
26
34
|
ElementResidualKernel = Callable[..., Any]
|
|
27
35
|
ElementJacobianKernel = Callable[..., Any]
|
|
36
|
+
BilinearReturn = Any
|
|
37
|
+
JacobianReturn = Any
|
|
38
|
+
LinearReturn = Any
|
|
39
|
+
MassReturn = Any
|
|
40
|
+
BilinearForm = Any
|
|
41
|
+
LinearForm = Any
|
|
42
|
+
FluxSparseMatrix = Any
|
|
43
|
+
SparsityPattern = Any
|
|
44
|
+
|
|
45
|
+
KernelCacheKey: TypeAlias = tuple[str, int, int, bool]
|
|
46
|
+
KernelCache: TypeAlias = dict[
|
|
47
|
+
KernelCacheKey,
|
|
48
|
+
ElementBilinearKernel | ElementLinearKernel | ElementResidualKernel | ElementJacobianKernel,
|
|
49
|
+
]
|
|
50
|
+
PatternCache: TypeAlias = dict[bool, SparsityPattern]
|
|
51
|
+
PatternLike: TypeAlias = str | SparsityPattern | None
|
|
28
52
|
|
|
29
53
|
_WARNED_UNTAGGED_KERNELS: set[int] = set()
|
|
30
54
|
|
|
@@ -93,8 +117,12 @@ class FESpaceBase(Protocol):
|
|
|
93
117
|
"""
|
|
94
118
|
elem_dofs: jnp.ndarray
|
|
95
119
|
value_dim: int
|
|
96
|
-
|
|
97
|
-
|
|
120
|
+
|
|
121
|
+
@property
|
|
122
|
+
def n_dofs(self) -> int: ...
|
|
123
|
+
|
|
124
|
+
@property
|
|
125
|
+
def n_ldofs(self) -> int: ...
|
|
98
126
|
|
|
99
127
|
def build_form_contexts(self, dep: jnp.ndarray | None = None) -> FormContext: ...
|
|
100
128
|
|
|
@@ -120,8 +148,8 @@ class FESpaceClosure:
|
|
|
120
148
|
_n_dofs: int | None = None
|
|
121
149
|
_n_ldofs: int | None = None
|
|
122
150
|
data: SpaceData | None = None
|
|
123
|
-
_pattern_cache:
|
|
124
|
-
_kernel_cache:
|
|
151
|
+
_pattern_cache: PatternCache = field(default_factory=dict, repr=False)
|
|
152
|
+
_kernel_cache: KernelCache = field(default_factory=dict, repr=False)
|
|
125
153
|
_elem_rows_cache: jnp.ndarray | None = field(default=None, repr=False)
|
|
126
154
|
|
|
127
155
|
def __post_init__(self):
|
|
@@ -232,7 +260,7 @@ class FESpaceClosure:
|
|
|
232
260
|
block_sizes=block_sizes,
|
|
233
261
|
)
|
|
234
262
|
|
|
235
|
-
def _kernel_cache_key(self, kind: str, form, params, *, jit: bool):
|
|
263
|
+
def _kernel_cache_key(self, kind: str, form, params, *, jit: bool) -> KernelCacheKey | None:
|
|
236
264
|
if not jit:
|
|
237
265
|
return None
|
|
238
266
|
form_key = getattr(form, "__wrapped__", form)
|
|
@@ -256,15 +284,15 @@ class FESpaceClosure:
|
|
|
256
284
|
|
|
257
285
|
def assemble(
|
|
258
286
|
self,
|
|
259
|
-
form,
|
|
260
|
-
params=None,
|
|
287
|
+
form: FormKernel[P] | BilinearForm | LinearForm,
|
|
288
|
+
params: P | None = None,
|
|
261
289
|
*,
|
|
262
290
|
kind: str | None = None,
|
|
263
291
|
n_chunks: int | None = None,
|
|
264
292
|
dep: jnp.ndarray | None = None,
|
|
265
293
|
jit: bool = True,
|
|
266
|
-
pattern:
|
|
267
|
-
kernel=None,
|
|
294
|
+
pattern: PatternLike = "auto",
|
|
295
|
+
kernel: ElementBilinearKernel | ElementLinearKernel | None = None,
|
|
268
296
|
**kwargs,
|
|
269
297
|
):
|
|
270
298
|
"""
|
|
@@ -324,13 +352,14 @@ class FESpaceClosure:
|
|
|
324
352
|
jit=jit,
|
|
325
353
|
maker=make_element_bilinear_kernel,
|
|
326
354
|
)
|
|
355
|
+
form_kernel = cast(FormKernel, form)
|
|
327
356
|
pattern_use = None
|
|
328
357
|
if pattern == "auto":
|
|
329
358
|
pattern_use = self.get_sparsity_pattern(with_idx=True)
|
|
330
359
|
else:
|
|
331
360
|
pattern_use = pattern
|
|
332
361
|
return self.assemble_bilinear_form(
|
|
333
|
-
|
|
362
|
+
form_kernel,
|
|
334
363
|
params,
|
|
335
364
|
n_chunks=n_chunks,
|
|
336
365
|
dep=dep,
|
|
@@ -347,8 +376,9 @@ class FESpaceClosure:
|
|
|
347
376
|
jit=jit,
|
|
348
377
|
maker=make_element_linear_kernel,
|
|
349
378
|
)
|
|
379
|
+
form_kernel = cast(FormKernel, form)
|
|
350
380
|
return self.assemble_linear_form(
|
|
351
|
-
|
|
381
|
+
form_kernel,
|
|
352
382
|
params,
|
|
353
383
|
n_chunks=n_chunks,
|
|
354
384
|
dep=dep,
|
|
@@ -360,14 +390,14 @@ class FESpaceClosure:
|
|
|
360
390
|
# --- Thin wrappers over functional assembly APIs (kept functional for JAX friendliness) ---
|
|
361
391
|
def assemble_bilinear_form(
|
|
362
392
|
self,
|
|
363
|
-
form:
|
|
393
|
+
form: FormKernel[P],
|
|
364
394
|
params: P,
|
|
365
395
|
*,
|
|
366
396
|
n_chunks: int | None = None,
|
|
367
397
|
dep: jnp.ndarray | None = None,
|
|
368
398
|
kernel: ElementBilinearKernel | None = None,
|
|
369
399
|
**kwargs,
|
|
370
|
-
):
|
|
400
|
+
) -> FluxSparseMatrix:
|
|
371
401
|
"""Assemble bilinear form; kernel(ctx) -> (n_ldofs, n_ldofs) if provided."""
|
|
372
402
|
from .assembly import assemble_bilinear_form
|
|
373
403
|
if "pattern" not in kwargs or kwargs.get("pattern") is None:
|
|
@@ -378,29 +408,39 @@ class FESpaceClosure:
|
|
|
378
408
|
|
|
379
409
|
def assemble_linear_form(
|
|
380
410
|
self,
|
|
381
|
-
form:
|
|
411
|
+
form: FormKernel[P],
|
|
382
412
|
params: P,
|
|
383
413
|
*,
|
|
384
414
|
n_chunks: int | None = None,
|
|
385
415
|
dep: jnp.ndarray | None = None,
|
|
386
416
|
kernel: ElementLinearKernel | None = None,
|
|
387
417
|
**kwargs,
|
|
388
|
-
):
|
|
418
|
+
) -> LinearReturn:
|
|
389
419
|
"""Assemble linear form; kernel(ctx) -> (n_ldofs,) if provided."""
|
|
390
420
|
from .assembly import assemble_linear_form
|
|
391
421
|
return assemble_linear_form(
|
|
392
422
|
self, form, params, n_chunks=n_chunks, dep=dep, kernel=kernel, **kwargs
|
|
393
423
|
)
|
|
394
424
|
|
|
395
|
-
def assemble_functional(self, form, params):
|
|
425
|
+
def assemble_functional(self, form: FormKernel[P], params: P) -> jnp.ndarray:
|
|
396
426
|
from .assembly import assemble_functional
|
|
397
427
|
return assemble_functional(self, form, params)
|
|
398
428
|
|
|
399
|
-
def assemble_mass_matrix(
|
|
429
|
+
def assemble_mass_matrix(
|
|
430
|
+
self,
|
|
431
|
+
*,
|
|
432
|
+
n_chunks: int | None = None,
|
|
433
|
+
**kwargs,
|
|
434
|
+
) -> MassReturn:
|
|
400
435
|
from .assembly import assemble_mass_matrix
|
|
401
436
|
return assemble_mass_matrix(self, n_chunks=n_chunks, **kwargs)
|
|
402
437
|
|
|
403
|
-
def assemble_bilinear_dense(
|
|
438
|
+
def assemble_bilinear_dense(
|
|
439
|
+
self,
|
|
440
|
+
kernel: FormKernel[P],
|
|
441
|
+
params: P,
|
|
442
|
+
**kwargs,
|
|
443
|
+
) -> BilinearReturn:
|
|
404
444
|
from .assembly import assemble_bilinear_dense
|
|
405
445
|
return assemble_bilinear_dense(self, kernel, params, **kwargs)
|
|
406
446
|
|
|
@@ -412,7 +452,7 @@ class FESpaceClosure:
|
|
|
412
452
|
*,
|
|
413
453
|
kernel: ElementResidualKernel | None = None,
|
|
414
454
|
**kwargs,
|
|
415
|
-
):
|
|
455
|
+
) -> LinearReturn:
|
|
416
456
|
"""Assemble residual; kernel(ctx, u_elem) -> (n_ldofs,) if provided."""
|
|
417
457
|
from .assembly import assemble_residual
|
|
418
458
|
return assemble_residual(self, res_form, u, params, kernel=kernel, **kwargs)
|
|
@@ -425,7 +465,7 @@ class FESpaceClosure:
|
|
|
425
465
|
*,
|
|
426
466
|
kernel: ElementJacobianKernel | None = None,
|
|
427
467
|
**kwargs,
|
|
428
|
-
):
|
|
468
|
+
) -> JacobianReturn:
|
|
429
469
|
"""Assemble Jacobian; kernel(u_elem, ctx) -> (n_ldofs, n_ldofs) if provided."""
|
|
430
470
|
from .assembly import assemble_jacobian
|
|
431
471
|
return assemble_jacobian(self, res_form, u, params, kernel=kernel, **kwargs)
|
|
@@ -585,7 +625,7 @@ def make_tet10_space_pytree(
|
|
|
585
625
|
"""Create a pytree quadratic tet space (10-node elements)."""
|
|
586
626
|
basis = make_tet10_basis_pytree(intorder)
|
|
587
627
|
element = None if dim == 1 else ElementVector(dim)
|
|
588
|
-
return make_space_pytree(mesh, basis, element)
|
|
628
|
+
return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)
|
|
589
629
|
|
|
590
630
|
|
|
591
631
|
def make_hex_space(mesh: HexMesh, dim: int = 1, intorder: int = 2) -> FESpace:
|
|
@@ -601,7 +641,7 @@ def make_hex_space_pytree(
|
|
|
601
641
|
"""Create a pytree trilinear hex space (8-node elements)."""
|
|
602
642
|
basis = make_hex_basis_pytree(intorder)
|
|
603
643
|
element = None if dim == 1 else ElementVector(dim)
|
|
604
|
-
return make_space_pytree(mesh, basis, element)
|
|
644
|
+
return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)
|
|
605
645
|
|
|
606
646
|
|
|
607
647
|
def make_hex20_space(
|
|
@@ -619,7 +659,7 @@ def make_hex20_space_pytree(
|
|
|
619
659
|
"""Create a pytree serendipity hex space (20-node elements)."""
|
|
620
660
|
basis = make_hex20_basis_pytree(intorder)
|
|
621
661
|
element = None if dim == 1 else ElementVector(dim)
|
|
622
|
-
return make_space_pytree(mesh, basis, element)
|
|
662
|
+
return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)
|
|
623
663
|
|
|
624
664
|
|
|
625
665
|
def make_hex27_space(
|
|
@@ -637,14 +677,14 @@ def make_hex27_space_pytree(
|
|
|
637
677
|
"""Create a pytree triquadratic hex space (27-node elements)."""
|
|
638
678
|
basis = make_hex27_basis_pytree(intorder)
|
|
639
679
|
element = None if dim == 1 else ElementVector(dim)
|
|
640
|
-
return make_space_pytree(mesh, basis, element)
|
|
680
|
+
return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)
|
|
641
681
|
|
|
642
682
|
|
|
643
683
|
def make_tet_space(mesh: TetMesh, dim: int = 1, intorder: int = 2) -> FESpace:
|
|
644
684
|
"""Create a linear or quadratic tet space based on mesh nodes."""
|
|
645
685
|
n_nodes = mesh.conn.shape[1]
|
|
646
686
|
if n_nodes == 10:
|
|
647
|
-
basis = make_tet10_basis(intorder if intorder > 1 else 2)
|
|
687
|
+
basis: Basis3D = make_tet10_basis(intorder if intorder > 1 else 2)
|
|
648
688
|
else:
|
|
649
689
|
basis = make_tet_basis(intorder)
|
|
650
690
|
element = None if dim == 1 else ElementVector(dim)
|
|
@@ -657,8 +697,8 @@ def make_tet_space_pytree(
|
|
|
657
697
|
"""Create a pytree linear or quadratic tet space based on mesh nodes."""
|
|
658
698
|
n_nodes = mesh.conn.shape[1]
|
|
659
699
|
if n_nodes == 10:
|
|
660
|
-
basis = make_tet10_basis_pytree(intorder if intorder > 1 else 2)
|
|
700
|
+
basis: Basis3D = make_tet10_basis_pytree(intorder if intorder > 1 else 2)
|
|
661
701
|
else:
|
|
662
702
|
basis = make_tet_basis_pytree(intorder)
|
|
663
703
|
element = None if dim == 1 else ElementVector(dim)
|
|
664
|
-
return make_space_pytree(mesh, basis, element)
|
|
704
|
+
return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)
|