fluxfem 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +1 -13
- fluxfem/core/__init__.py +53 -71
- fluxfem/core/assembly.py +41 -32
- fluxfem/core/basis.py +2 -2
- fluxfem/core/context_types.py +36 -12
- fluxfem/core/mixed_space.py +42 -8
- fluxfem/core/mixed_weakform.py +1 -1
- fluxfem/core/space.py +68 -28
- fluxfem/core/weakform.py +95 -77
- fluxfem/mesh/base.py +3 -3
- fluxfem/mesh/contact.py +33 -17
- fluxfem/mesh/io.py +3 -2
- fluxfem/mesh/mortar.py +106 -43
- fluxfem/mesh/supermesh.py +2 -0
- fluxfem/mesh/surface.py +82 -22
- fluxfem/mesh/tet.py +7 -4
- fluxfem/physics/elasticity/hyperelastic.py +32 -3
- fluxfem/physics/elasticity/linear.py +13 -2
- fluxfem/physics/elasticity/stress.py +9 -5
- fluxfem/physics/operators.py +12 -5
- fluxfem/physics/postprocess.py +29 -3
- fluxfem/solver/__init__.py +6 -1
- fluxfem/solver/block_matrix.py +165 -13
- fluxfem/solver/block_system.py +52 -29
- fluxfem/solver/cg.py +43 -30
- fluxfem/solver/dirichlet.py +35 -12
- fluxfem/solver/history.py +15 -3
- fluxfem/solver/newton.py +25 -12
- fluxfem/solver/petsc.py +13 -7
- fluxfem/solver/preconditioner.py +7 -4
- fluxfem/solver/solve_runner.py +42 -24
- fluxfem/solver/solver.py +23 -11
- fluxfem/solver/sparse.py +32 -13
- fluxfem/tools/jit.py +19 -7
- fluxfem/tools/timer.py +14 -12
- fluxfem/tools/visualizer.py +16 -4
- {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/METADATA +18 -7
- fluxfem-0.2.1.dist-info/RECORD +59 -0
- fluxfem-0.2.0.dist-info/RECORD +0 -59
- {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
- {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/__init__.py
CHANGED
|
@@ -31,9 +31,6 @@ __all__ = [
|
|
|
31
31
|
"compile_surface_bilinear",
|
|
32
32
|
"compile_mixed_surface_residual",
|
|
33
33
|
"compile_mixed_residual",
|
|
34
|
-
"outer",
|
|
35
|
-
"sdot",
|
|
36
|
-
"dOmega",
|
|
37
34
|
"FormContext",
|
|
38
35
|
"MixedFormContext",
|
|
39
36
|
"VolumeContext",
|
|
@@ -85,11 +82,6 @@ __all__ = [
|
|
|
85
82
|
"constant_body_force_form",
|
|
86
83
|
"constant_body_force_vector_form",
|
|
87
84
|
"diffusion_form",
|
|
88
|
-
"dot",
|
|
89
|
-
"ddot",
|
|
90
|
-
"transpose_last2",
|
|
91
|
-
"sym_grad",
|
|
92
|
-
"sym_grad_u",
|
|
93
85
|
"right_cauchy_green",
|
|
94
86
|
"green_lagrange_strain",
|
|
95
87
|
"deformation_gradient",
|
|
@@ -172,6 +164,7 @@ __all__ = [
|
|
|
172
164
|
"coo_to_csr",
|
|
173
165
|
"SparsityPattern",
|
|
174
166
|
"FluxSparseMatrix",
|
|
167
|
+
"FluxBlockMatrix",
|
|
175
168
|
"coalesce_coo",
|
|
176
169
|
"concat_flux",
|
|
177
170
|
"block_diag_flux",
|
|
@@ -233,11 +226,6 @@ _PHYSICS_EXPORTS = {
|
|
|
233
226
|
"vector_body_force_form",
|
|
234
227
|
"constant_body_force_vector_form",
|
|
235
228
|
"diffusion_form",
|
|
236
|
-
"dot",
|
|
237
|
-
"ddot",
|
|
238
|
-
"transpose_last2",
|
|
239
|
-
"sym_grad",
|
|
240
|
-
"sym_grad_u",
|
|
241
229
|
"right_cauchy_green",
|
|
242
230
|
"green_lagrange_strain",
|
|
243
231
|
"deformation_gradient",
|
fluxfem/core/__init__.py
CHANGED
|
@@ -80,76 +80,6 @@ from .weakform import (
|
|
|
80
80
|
compile_linear,
|
|
81
81
|
compile_residual,
|
|
82
82
|
compile_mixed_residual,
|
|
83
|
-
grad,
|
|
84
|
-
sym_grad,
|
|
85
|
-
outer,
|
|
86
|
-
dot,
|
|
87
|
-
sdot,
|
|
88
|
-
ddot,
|
|
89
|
-
inner,
|
|
90
|
-
action,
|
|
91
|
-
gaction,
|
|
92
|
-
normal,
|
|
93
|
-
ds,
|
|
94
|
-
dOmega,
|
|
95
|
-
I,
|
|
96
|
-
det,
|
|
97
|
-
inv,
|
|
98
|
-
transpose,
|
|
99
|
-
log,
|
|
100
|
-
transpose_last2,
|
|
101
|
-
matmul,
|
|
102
|
-
matmul_std,
|
|
103
|
-
einsum,
|
|
104
|
-
)
|
|
105
|
-
from ..mesh import (
|
|
106
|
-
BaseMeshPytree,
|
|
107
|
-
SurfaceWithFacetMap,
|
|
108
|
-
bbox_predicate,
|
|
109
|
-
plane_predicate,
|
|
110
|
-
axis_plane_predicate,
|
|
111
|
-
slab_predicate,
|
|
112
|
-
HexMesh,
|
|
113
|
-
HexMeshPytree,
|
|
114
|
-
StructuredHexBox,
|
|
115
|
-
SurfaceMesh,
|
|
116
|
-
SurfaceMeshPytree,
|
|
117
|
-
SurfaceWithElemConn,
|
|
118
|
-
surface_with_elem_conn,
|
|
119
|
-
SurfaceSupermesh,
|
|
120
|
-
build_surface_supermesh,
|
|
121
|
-
tag_axis_minmax_facets,
|
|
122
|
-
TetMesh,
|
|
123
|
-
TetMeshPytree,
|
|
124
|
-
StructuredTetBox,
|
|
125
|
-
StructuredTetTensorBox,
|
|
126
|
-
load_gmsh_mesh,
|
|
127
|
-
load_gmsh_hex_mesh,
|
|
128
|
-
load_gmsh_tet_mesh,
|
|
129
|
-
make_surface_from_facets,
|
|
130
|
-
SurfaceSupermesh,
|
|
131
|
-
build_surface_supermesh,
|
|
132
|
-
MortarMatrix,
|
|
133
|
-
assemble_mortar_matrices,
|
|
134
|
-
assemble_contact_onesided_floor,
|
|
135
|
-
assemble_mixed_surface_residual,
|
|
136
|
-
assemble_mixed_surface_jacobian,
|
|
137
|
-
map_surface_facets_to_tet_elements,
|
|
138
|
-
map_surface_facets_to_hex_elements,
|
|
139
|
-
tri_area,
|
|
140
|
-
tri_quadrature,
|
|
141
|
-
facet_triangles,
|
|
142
|
-
facet_shape_values,
|
|
143
|
-
volume_shape_values_at_points,
|
|
144
|
-
quad_shape_and_local,
|
|
145
|
-
quad9_shape_values,
|
|
146
|
-
hex27_gradN,
|
|
147
|
-
ContactSurfaceSpace,
|
|
148
|
-
ContactSide,
|
|
149
|
-
OneSidedContact,
|
|
150
|
-
OneSidedContactSurfaceSpace,
|
|
151
|
-
facet_gap_values,
|
|
152
|
-
active_contact_facets,
|
|
153
83
|
)
|
|
154
84
|
from .basis import (
|
|
155
85
|
HexTriLinearBasis,
|
|
@@ -167,9 +97,58 @@ import importlib
|
|
|
167
97
|
|
|
168
98
|
from .solver import spdirect_solve_cpu, spdirect_solve_gpu, spdirect_solve_jax, coo_to_csr
|
|
169
99
|
|
|
100
|
+
_MESH_EXPORTS = {
|
|
101
|
+
"BaseMeshPytree",
|
|
102
|
+
"SurfaceWithFacetMap",
|
|
103
|
+
"bbox_predicate",
|
|
104
|
+
"plane_predicate",
|
|
105
|
+
"axis_plane_predicate",
|
|
106
|
+
"slab_predicate",
|
|
107
|
+
"HexMesh",
|
|
108
|
+
"HexMeshPytree",
|
|
109
|
+
"StructuredHexBox",
|
|
110
|
+
"SurfaceMesh",
|
|
111
|
+
"SurfaceMeshPytree",
|
|
112
|
+
"SurfaceWithElemConn",
|
|
113
|
+
"surface_with_elem_conn",
|
|
114
|
+
"SurfaceSupermesh",
|
|
115
|
+
"build_surface_supermesh",
|
|
116
|
+
"tag_axis_minmax_facets",
|
|
117
|
+
"TetMesh",
|
|
118
|
+
"TetMeshPytree",
|
|
119
|
+
"StructuredTetBox",
|
|
120
|
+
"StructuredTetTensorBox",
|
|
121
|
+
"load_gmsh_mesh",
|
|
122
|
+
"load_gmsh_hex_mesh",
|
|
123
|
+
"load_gmsh_tet_mesh",
|
|
124
|
+
"make_surface_from_facets",
|
|
125
|
+
"MortarMatrix",
|
|
126
|
+
"assemble_mortar_matrices",
|
|
127
|
+
"assemble_contact_onesided_floor",
|
|
128
|
+
"assemble_mixed_surface_residual",
|
|
129
|
+
"assemble_mixed_surface_jacobian",
|
|
130
|
+
"map_surface_facets_to_tet_elements",
|
|
131
|
+
"map_surface_facets_to_hex_elements",
|
|
132
|
+
"tri_area",
|
|
133
|
+
"tri_quadrature",
|
|
134
|
+
"facet_triangles",
|
|
135
|
+
"facet_shape_values",
|
|
136
|
+
"volume_shape_values_at_points",
|
|
137
|
+
"quad_shape_and_local",
|
|
138
|
+
"quad9_shape_values",
|
|
139
|
+
"hex27_gradN",
|
|
140
|
+
"ContactSurfaceSpace",
|
|
141
|
+
"ContactSide",
|
|
142
|
+
"OneSidedContact",
|
|
143
|
+
"OneSidedContactSurfaceSpace",
|
|
144
|
+
"facet_gap_values",
|
|
145
|
+
"active_contact_facets",
|
|
146
|
+
}
|
|
147
|
+
|
|
170
148
|
_SOLVER_EXPORTS = {
|
|
171
149
|
"SparsityPattern",
|
|
172
150
|
"FluxSparseMatrix",
|
|
151
|
+
"FluxBlockMatrix",
|
|
173
152
|
"coalesce_coo",
|
|
174
153
|
"concat_flux",
|
|
175
154
|
"block_diag_flux",
|
|
@@ -227,7 +206,9 @@ _SOLVER_BC_EXPORTS = {
|
|
|
227
206
|
|
|
228
207
|
|
|
229
208
|
def __getattr__(name: str):
|
|
230
|
-
if name in
|
|
209
|
+
if name in _MESH_EXPORTS:
|
|
210
|
+
module = importlib.import_module("..mesh", __name__)
|
|
211
|
+
elif name in _SOLVER_EXPORTS:
|
|
231
212
|
module = importlib.import_module("..solver", __name__)
|
|
232
213
|
elif name in _SOLVER_BC_EXPORTS:
|
|
233
214
|
module = importlib.import_module("..solver.bc", __name__)
|
|
@@ -410,6 +391,7 @@ __all__ = [
|
|
|
410
391
|
"coo_to_csr",
|
|
411
392
|
"SparsityPattern",
|
|
412
393
|
"FluxSparseMatrix",
|
|
394
|
+
"FluxBlockMatrix",
|
|
413
395
|
"coalesce_coo",
|
|
414
396
|
"concat_flux",
|
|
415
397
|
"block_diag_flux",
|
fluxfem/core/assembly.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
from typing import Any, Callable, Literal, Optional, Protocol, TYPE_CHECKING, TypeAlias, TypeVar, Union
|
|
2
|
+
from typing import Any, Callable, Literal, Mapping, Optional, Protocol, TYPE_CHECKING, TypeAlias, TypeVar, Union, cast
|
|
3
3
|
import numpy as np
|
|
4
4
|
import jax
|
|
5
5
|
import jax.numpy as jnp
|
|
@@ -10,11 +10,16 @@ from .forms import FormContext
|
|
|
10
10
|
from .space import FESpaceBase
|
|
11
11
|
|
|
12
12
|
# Shared call signatures for kernels/forms
|
|
13
|
-
Array = jnp.ndarray
|
|
13
|
+
Array: TypeAlias = jnp.ndarray
|
|
14
14
|
P = TypeVar("P")
|
|
15
15
|
|
|
16
|
-
|
|
16
|
+
FormKernel: TypeAlias = Callable[[FormContext, P], Array]
|
|
17
|
+
# Form kernels return integrands; element kernels return integrated element arrays.
|
|
18
|
+
Kernel: TypeAlias = Callable[[FormContext, P], Array]
|
|
19
|
+
ResidualInput: TypeAlias = Array | Mapping[str, Array]
|
|
20
|
+
ResidualValue: TypeAlias = Array | Mapping[str, Array]
|
|
17
21
|
ResidualForm = Callable[[FormContext, Array, P], Array]
|
|
22
|
+
ResidualFormLike = Callable[[FormContext, ResidualInput, P], ResidualValue]
|
|
18
23
|
ElementDofMapper = Callable[[Array], Array]
|
|
19
24
|
|
|
20
25
|
if TYPE_CHECKING:
|
|
@@ -107,7 +112,7 @@ class BatchedAssembler:
|
|
|
107
112
|
def __init__(
|
|
108
113
|
self,
|
|
109
114
|
space: SpaceLike,
|
|
110
|
-
elem_data:
|
|
115
|
+
elem_data: FormContext,
|
|
111
116
|
elem_dofs: Array,
|
|
112
117
|
*,
|
|
113
118
|
pattern: SparsityPattern | None = None,
|
|
@@ -119,8 +124,8 @@ class BatchedAssembler:
|
|
|
119
124
|
self.n_ldofs = int(space.n_ldofs)
|
|
120
125
|
self.n_dofs = int(space.n_dofs)
|
|
121
126
|
self.pattern = pattern
|
|
122
|
-
self._rows = None
|
|
123
|
-
self._cols = None
|
|
127
|
+
self._rows: Array | None = None
|
|
128
|
+
self._cols: Array | None = None
|
|
124
129
|
|
|
125
130
|
@classmethod
|
|
126
131
|
def from_space(
|
|
@@ -135,7 +140,7 @@ class BatchedAssembler:
|
|
|
135
140
|
|
|
136
141
|
def make_mask(self, n_active: int) -> Array:
|
|
137
142
|
n_active = max(0, min(int(n_active), self.n_elems))
|
|
138
|
-
mask = np.zeros((self.n_elems,), dtype=float)
|
|
143
|
+
mask: np.ndarray = np.zeros((self.n_elems,), dtype=float)
|
|
139
144
|
if n_active:
|
|
140
145
|
mask[:n_active] = 1.0
|
|
141
146
|
return jnp.asarray(mask)
|
|
@@ -177,7 +182,7 @@ class BatchedAssembler:
|
|
|
177
182
|
|
|
178
183
|
def assemble_bilinear(
|
|
179
184
|
self,
|
|
180
|
-
form:
|
|
185
|
+
form: FormKernel[P],
|
|
181
186
|
params: P,
|
|
182
187
|
*,
|
|
183
188
|
mask: Array | None = None,
|
|
@@ -208,7 +213,7 @@ class BatchedAssembler:
|
|
|
208
213
|
|
|
209
214
|
def assemble_linear(
|
|
210
215
|
self,
|
|
211
|
-
form:
|
|
216
|
+
form: FormKernel[P],
|
|
212
217
|
params: P,
|
|
213
218
|
*,
|
|
214
219
|
mask: Array | None = None,
|
|
@@ -359,7 +364,7 @@ class SpaceLike(FESpaceBase, Protocol):
|
|
|
359
364
|
|
|
360
365
|
def assemble_bilinear_dense(
|
|
361
366
|
space: SpaceLike,
|
|
362
|
-
kernel:
|
|
367
|
+
kernel: FormKernel[P],
|
|
363
368
|
params: P,
|
|
364
369
|
*,
|
|
365
370
|
sparse: bool = False,
|
|
@@ -412,7 +417,7 @@ def assemble_bilinear_dense(
|
|
|
412
417
|
|
|
413
418
|
def assemble_bilinear_form(
|
|
414
419
|
space: SpaceLike,
|
|
415
|
-
form:
|
|
420
|
+
form: FormKernel[P],
|
|
416
421
|
params: P,
|
|
417
422
|
*,
|
|
418
423
|
pattern: SparsityPattern | None = None,
|
|
@@ -589,7 +594,7 @@ def assemble_mass_matrix(
|
|
|
589
594
|
|
|
590
595
|
def assemble_linear_form(
|
|
591
596
|
space: SpaceLike,
|
|
592
|
-
form:
|
|
597
|
+
form: FormKernel[P],
|
|
593
598
|
params: P,
|
|
594
599
|
*,
|
|
595
600
|
kernel: ElementLinearKernel | None = None,
|
|
@@ -670,7 +675,7 @@ def assemble_linear_form(
|
|
|
670
675
|
return F
|
|
671
676
|
|
|
672
677
|
|
|
673
|
-
def assemble_functional(space: SpaceLike, form:
|
|
678
|
+
def assemble_functional(space: SpaceLike, form: FormKernel[P], params: P) -> jnp.ndarray:
|
|
674
679
|
"""
|
|
675
680
|
Assemble scalar functional J = ∫ form(ctx, params) dΩ.
|
|
676
681
|
Expects form(ctx, params) -> (n_q,) or (n_q, 1).
|
|
@@ -798,7 +803,7 @@ def assemble_jacobian_elementwise(
|
|
|
798
803
|
)
|
|
799
804
|
K_flat = jnp.zeros(n_entries, dtype=data.dtype)
|
|
800
805
|
K_flat = jax.lax.scatter_add(K_flat, idx[:, None], data, sdn)
|
|
801
|
-
return K_flat.reshape(
|
|
806
|
+
return K_flat.reshape(n_dofs, n_dofs)
|
|
802
807
|
|
|
803
808
|
|
|
804
809
|
def assemble_residual_global(
|
|
@@ -885,7 +890,7 @@ assemble_residual_elementwise_xla = assemble_residual_elementwise
|
|
|
885
890
|
|
|
886
891
|
|
|
887
892
|
def make_element_bilinear_kernel(
|
|
888
|
-
form:
|
|
893
|
+
form: FormKernel[P], params: P, *, jit: bool = True
|
|
889
894
|
) -> ElementBilinearKernel:
|
|
890
895
|
"""Element kernel: (ctx) -> Ke."""
|
|
891
896
|
|
|
@@ -900,7 +905,7 @@ def make_element_bilinear_kernel(
|
|
|
900
905
|
|
|
901
906
|
|
|
902
907
|
def make_element_linear_kernel(
|
|
903
|
-
form:
|
|
908
|
+
form: FormKernel[P], params: P, *, jit: bool = True
|
|
904
909
|
) -> ElementLinearKernel:
|
|
905
910
|
"""Element kernel: (ctx) -> fe."""
|
|
906
911
|
|
|
@@ -945,8 +950,8 @@ def make_element_jacobian_kernel(
|
|
|
945
950
|
|
|
946
951
|
|
|
947
952
|
def element_residual(
|
|
948
|
-
res_form:
|
|
949
|
-
) ->
|
|
953
|
+
res_form: ResidualFormLike[P], ctx: FormContext, u_elem: ResidualInput, params: P
|
|
954
|
+
) -> ResidualValue:
|
|
950
955
|
"""
|
|
951
956
|
Element residual vector r_e(u_e) = sum_q w_q * detJ_q * res_form(ctx, u_e, params).
|
|
952
957
|
Returns shape (n_ldofs,).
|
|
@@ -972,8 +977,8 @@ def element_residual(
|
|
|
972
977
|
|
|
973
978
|
|
|
974
979
|
def element_jacobian(
|
|
975
|
-
res_form:
|
|
976
|
-
) ->
|
|
980
|
+
res_form: ResidualFormLike[P], ctx: FormContext, u_elem: ResidualInput, params: P
|
|
981
|
+
) -> ResidualValue:
|
|
977
982
|
"""
|
|
978
983
|
Element Jacobian K_e = d r_e / d u_e (AD via jacfwd), shape (n_ldofs, n_ldofs).
|
|
979
984
|
"""
|
|
@@ -984,7 +989,7 @@ def element_jacobian(
|
|
|
984
989
|
|
|
985
990
|
|
|
986
991
|
def make_element_kernel(
|
|
987
|
-
form:
|
|
992
|
+
form: FormKernel[P] | ResidualForm[P],
|
|
988
993
|
params: P,
|
|
989
994
|
*,
|
|
990
995
|
kind: Literal["bilinear", "linear", "residual", "jacobian"],
|
|
@@ -999,22 +1004,26 @@ def make_element_kernel(
|
|
|
999
1004
|
- "residual": kernel(ctx, u_elem) -> (n_ldofs,)
|
|
1000
1005
|
- "jacobian": kernel(u_elem, ctx) -> (n_ldofs, n_ldofs)
|
|
1001
1006
|
"""
|
|
1002
|
-
kind = kind.lower()
|
|
1007
|
+
kind = cast(Literal["bilinear", "linear", "residual", "jacobian"], kind.lower())
|
|
1003
1008
|
if kind == "bilinear":
|
|
1004
|
-
|
|
1009
|
+
form_bilinear = cast(FormKernel[P], form)
|
|
1010
|
+
return make_element_bilinear_kernel(form_bilinear, params, jit=jit)
|
|
1005
1011
|
if kind == "linear":
|
|
1012
|
+
form_linear = cast(FormKernel[P], form)
|
|
1006
1013
|
def per_element(ctx: FormContext):
|
|
1007
|
-
integrand =
|
|
1008
|
-
if getattr(
|
|
1014
|
+
integrand = form_linear(ctx, params)
|
|
1015
|
+
if getattr(form_linear, "_includes_measure", False):
|
|
1009
1016
|
return integrand.sum(axis=0)
|
|
1010
1017
|
wJ = ctx.w * ctx.test.detJ
|
|
1011
1018
|
return (integrand * wJ[:, None]).sum(axis=0)
|
|
1012
1019
|
|
|
1013
1020
|
return jax.jit(per_element) if jit else per_element
|
|
1014
1021
|
if kind == "residual":
|
|
1015
|
-
|
|
1022
|
+
form_residual = cast(ResidualForm[P], form)
|
|
1023
|
+
return make_element_residual_kernel(form_residual, params)
|
|
1016
1024
|
if kind == "jacobian":
|
|
1017
|
-
|
|
1025
|
+
form_residual = cast(ResidualForm[P], form)
|
|
1026
|
+
return make_element_jacobian_kernel(form_residual, params)
|
|
1018
1027
|
raise ValueError(f"Unknown kernel kind: {kind}")
|
|
1019
1028
|
|
|
1020
1029
|
|
|
@@ -1327,19 +1336,19 @@ def scalar_body_force_form(ctx: FormContext, load: float) -> jnp.ndarray:
|
|
|
1327
1336
|
return load * ctx.test.N # (n_q, n_ldofs)
|
|
1328
1337
|
|
|
1329
1338
|
|
|
1330
|
-
scalar_body_force_form._ff_kind = "linear"
|
|
1331
|
-
scalar_body_force_form._ff_domain = "volume"
|
|
1339
|
+
scalar_body_force_form._ff_kind = "linear" # type: ignore[attr-defined]
|
|
1340
|
+
scalar_body_force_form._ff_domain = "volume" # type: ignore[attr-defined]
|
|
1332
1341
|
|
|
1333
1342
|
|
|
1334
|
-
def make_scalar_body_force_form(body_force: Callable[[Array], Array]) ->
|
|
1343
|
+
def make_scalar_body_force_form(body_force: Callable[[Array], Array]) -> FormKernel[Any]:
|
|
1335
1344
|
"""
|
|
1336
1345
|
Build a scalar linear form from a callable f(x_q) -> (n_q,).
|
|
1337
1346
|
"""
|
|
1338
1347
|
def _form(ctx: FormContext, _params):
|
|
1339
1348
|
f_q = body_force(ctx.x_q)
|
|
1340
1349
|
return f_q[..., None] * ctx.test.N
|
|
1341
|
-
_form._ff_kind = "linear"
|
|
1342
|
-
_form._ff_domain = "volume"
|
|
1350
|
+
_form._ff_kind = "linear" # type: ignore[attr-defined]
|
|
1351
|
+
_form._ff_domain = "volume" # type: ignore[attr-defined]
|
|
1343
1352
|
return _form
|
|
1344
1353
|
|
|
1345
1354
|
|
fluxfem/core/basis.py
CHANGED
|
@@ -421,7 +421,7 @@ class HexTriLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
|
|
|
421
421
|
|
|
422
422
|
def tree_flatten(self):
|
|
423
423
|
children = (self.quad_points, self.quad_weights)
|
|
424
|
-
aux_data = {}
|
|
424
|
+
aux_data: dict[str, object] = {}
|
|
425
425
|
return children, aux_data
|
|
426
426
|
|
|
427
427
|
@classmethod
|
|
@@ -569,7 +569,7 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
|
|
|
569
569
|
|
|
570
570
|
def tree_flatten(self):
|
|
571
571
|
children = (self.quad_points, self.quad_weights)
|
|
572
|
-
aux_data = {}
|
|
572
|
+
aux_data: dict[str, object] = {}
|
|
573
573
|
return children, aux_data
|
|
574
574
|
|
|
575
575
|
@classmethod
|
fluxfem/core/context_types.py
CHANGED
|
@@ -1,36 +1,60 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from typing import Any, Protocol, TypeAlias, runtime_checkable
|
|
3
|
+
from typing import Any, Mapping, Protocol, TYPE_CHECKING, TypeAlias, runtime_checkable
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from jax import Array as JaxArray
|
|
9
|
+
|
|
10
|
+
ArrayLike: TypeAlias = np.ndarray | JaxArray
|
|
11
|
+
else:
|
|
12
|
+
ArrayLike: TypeAlias = np.ndarray
|
|
4
13
|
|
|
5
14
|
|
|
6
15
|
@runtime_checkable
|
|
7
16
|
class VolumeContext(Protocol):
|
|
8
17
|
"""Minimum interface for volume weak-form evaluation."""
|
|
9
18
|
|
|
10
|
-
test:
|
|
11
|
-
trial:
|
|
12
|
-
w:
|
|
19
|
+
test: "FormFieldLike"
|
|
20
|
+
trial: "FormFieldLike"
|
|
21
|
+
w: ArrayLike
|
|
13
22
|
|
|
14
23
|
|
|
15
24
|
@runtime_checkable
|
|
16
25
|
class SurfaceContext(Protocol):
|
|
17
26
|
"""Minimum interface for surface weak-form evaluation."""
|
|
18
27
|
|
|
19
|
-
v:
|
|
20
|
-
w:
|
|
21
|
-
detJ:
|
|
22
|
-
normal:
|
|
28
|
+
v: "FormFieldLike"
|
|
29
|
+
w: ArrayLike
|
|
30
|
+
detJ: ArrayLike
|
|
31
|
+
normal: ArrayLike
|
|
23
32
|
|
|
24
33
|
|
|
25
34
|
@runtime_checkable
|
|
26
35
|
class FormFieldLike(Protocol):
|
|
27
36
|
"""Minimum interface for form fields used in weak-form evaluation."""
|
|
28
37
|
|
|
29
|
-
N:
|
|
30
|
-
gradN:
|
|
31
|
-
detJ:
|
|
38
|
+
N: ArrayLike
|
|
39
|
+
gradN: ArrayLike
|
|
40
|
+
detJ: ArrayLike
|
|
32
41
|
value_dim: int
|
|
42
|
+
basis: Any
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@runtime_checkable
|
|
46
|
+
class WeakFormContext(Protocol):
|
|
47
|
+
"""Context interface used when resolving field references."""
|
|
48
|
+
|
|
49
|
+
test: FormFieldLike
|
|
50
|
+
trial: FormFieldLike
|
|
51
|
+
v: FormFieldLike
|
|
52
|
+
unknown: FormFieldLike | None
|
|
53
|
+
fields: Mapping[str, Any] | None
|
|
54
|
+
test_fields: Mapping[str, FormFieldLike] | None
|
|
55
|
+
trial_fields: Mapping[str, FormFieldLike] | None
|
|
56
|
+
unknown_fields: Mapping[str, FormFieldLike] | None
|
|
33
57
|
|
|
34
58
|
|
|
35
|
-
UElement: TypeAlias =
|
|
59
|
+
UElement: TypeAlias = ArrayLike | Mapping[str, ArrayLike]
|
|
36
60
|
ParamsLike: TypeAlias = Any
|
fluxfem/core/mixed_space.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass, field
|
|
4
|
-
from typing import Mapping, Sequence,
|
|
4
|
+
from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING, TypeAlias, TypeVar, cast
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
import jax.numpy as jnp
|
|
@@ -13,6 +13,16 @@ from ..solver.dirichlet import DirichletBC, free_dofs
|
|
|
13
13
|
from ..solver.sparse import FluxSparseMatrix
|
|
14
14
|
from .space import FESpaceClosure
|
|
15
15
|
|
|
16
|
+
P = TypeVar("P")
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from .assembly import JacobianReturn, LinearReturn
|
|
20
|
+
|
|
21
|
+
MixedResidualForm: TypeAlias = Callable[
|
|
22
|
+
[MixedFormContext, Mapping[str, jnp.ndarray], P],
|
|
23
|
+
Mapping[str, jnp.ndarray],
|
|
24
|
+
]
|
|
25
|
+
|
|
16
26
|
|
|
17
27
|
@dataclass(eq=False)
|
|
18
28
|
class MixedFESpace:
|
|
@@ -131,13 +141,25 @@ class MixedFESpace:
|
|
|
131
141
|
|
|
132
142
|
def get_sparsity_pattern(self, *, with_idx: bool = True):
|
|
133
143
|
from .assembly import make_sparsity_pattern
|
|
134
|
-
return make_sparsity_pattern(self, with_idx=with_idx)
|
|
144
|
+
return make_sparsity_pattern(cast(Any, self), with_idx=with_idx)
|
|
135
145
|
|
|
136
|
-
def assemble_residual(
|
|
146
|
+
def assemble_residual(
|
|
147
|
+
self,
|
|
148
|
+
res_form: MixedResidualForm[P],
|
|
149
|
+
u: Mapping[str, jnp.ndarray] | Sequence[jnp.ndarray] | jnp.ndarray,
|
|
150
|
+
params: P,
|
|
151
|
+
**kwargs,
|
|
152
|
+
) -> "LinearReturn":
|
|
137
153
|
from .mixed_assembly import assemble_mixed_residual
|
|
138
154
|
return assemble_mixed_residual(self, res_form, u, params, **kwargs)
|
|
139
155
|
|
|
140
|
-
def assemble_jacobian(
|
|
156
|
+
def assemble_jacobian(
|
|
157
|
+
self,
|
|
158
|
+
res_form: MixedResidualForm[P],
|
|
159
|
+
u: Mapping[str, jnp.ndarray] | Sequence[jnp.ndarray] | jnp.ndarray,
|
|
160
|
+
params: P,
|
|
161
|
+
**kwargs,
|
|
162
|
+
) -> "JacobianReturn":
|
|
141
163
|
from .mixed_assembly import assemble_mixed_jacobian
|
|
142
164
|
return assemble_mixed_jacobian(self, res_form, u, params, **kwargs)
|
|
143
165
|
|
|
@@ -233,7 +255,7 @@ class MixedProblem:
|
|
|
233
255
|
pattern: object | None = None
|
|
234
256
|
n_chunks: int | None = None
|
|
235
257
|
pad_trace: bool = False
|
|
236
|
-
_compiled:
|
|
258
|
+
_compiled: Callable[..., Any] = field(init=False, repr=False)
|
|
237
259
|
|
|
238
260
|
def __post_init__(self):
|
|
239
261
|
if isinstance(self.residuals, MixedWeakForm):
|
|
@@ -257,18 +279,30 @@ class MixedProblem:
|
|
|
257
279
|
def _wrapped(ctx, u_elem, _params):
|
|
258
280
|
return self._compiled(ctx, u_elem, params(ctx))
|
|
259
281
|
|
|
260
|
-
_wrapped._includes_measure = getattr(self._compiled, "_includes_measure", False)
|
|
282
|
+
_wrapped._includes_measure = getattr(self._compiled, "_includes_measure", False) # type: ignore[attr-defined]
|
|
261
283
|
return _wrapped, None
|
|
262
284
|
return self._compiled, params
|
|
263
285
|
|
|
264
|
-
def assemble_residual(
|
|
286
|
+
def assemble_residual(
|
|
287
|
+
self,
|
|
288
|
+
u: Mapping[str, jnp.ndarray] | Sequence[jnp.ndarray] | jnp.ndarray,
|
|
289
|
+
*,
|
|
290
|
+
params: P | None = None,
|
|
291
|
+
**kwargs,
|
|
292
|
+
) -> "LinearReturn":
|
|
265
293
|
use_params = self.params if params is None else params
|
|
266
294
|
res_form, use_params = self._wrap_params(use_params)
|
|
267
295
|
return self.space.assemble_residual(
|
|
268
296
|
res_form, u, use_params, **self._merge_kwargs(kwargs)
|
|
269
297
|
)
|
|
270
298
|
|
|
271
|
-
def assemble_jacobian(
|
|
299
|
+
def assemble_jacobian(
|
|
300
|
+
self,
|
|
301
|
+
u: Mapping[str, jnp.ndarray] | Sequence[jnp.ndarray] | jnp.ndarray,
|
|
302
|
+
*,
|
|
303
|
+
params: P | None = None,
|
|
304
|
+
**kwargs,
|
|
305
|
+
) -> "JacobianReturn":
|
|
272
306
|
use_params = self.params if params is None else params
|
|
273
307
|
res_form, use_params = self._wrap_params(use_params)
|
|
274
308
|
return self.space.assemble_jacobian(
|
fluxfem/core/mixed_weakform.py
CHANGED
|
@@ -40,7 +40,7 @@ def _wrap_params(res_form, params):
|
|
|
40
40
|
def _wrapped(ctx, u_elem, _params):
|
|
41
41
|
return res_form(ctx, u_elem, params(ctx))
|
|
42
42
|
|
|
43
|
-
_wrapped._includes_measure = getattr(res_form, "_includes_measure", False)
|
|
43
|
+
_wrapped._includes_measure = getattr(res_form, "_includes_measure", False) # type: ignore[attr-defined]
|
|
44
44
|
return _wrapped, None
|
|
45
45
|
return res_form, params
|
|
46
46
|
|