fluxfem 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. fluxfem/__init__.py +1 -13
  2. fluxfem/core/__init__.py +53 -71
  3. fluxfem/core/assembly.py +41 -32
  4. fluxfem/core/basis.py +2 -2
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/mixed_space.py +42 -8
  7. fluxfem/core/mixed_weakform.py +1 -1
  8. fluxfem/core/space.py +68 -28
  9. fluxfem/core/weakform.py +95 -77
  10. fluxfem/mesh/base.py +3 -3
  11. fluxfem/mesh/contact.py +33 -17
  12. fluxfem/mesh/io.py +3 -2
  13. fluxfem/mesh/mortar.py +106 -43
  14. fluxfem/mesh/supermesh.py +2 -0
  15. fluxfem/mesh/surface.py +82 -22
  16. fluxfem/mesh/tet.py +7 -4
  17. fluxfem/physics/elasticity/hyperelastic.py +32 -3
  18. fluxfem/physics/elasticity/linear.py +13 -2
  19. fluxfem/physics/elasticity/stress.py +9 -5
  20. fluxfem/physics/operators.py +12 -5
  21. fluxfem/physics/postprocess.py +29 -3
  22. fluxfem/solver/__init__.py +6 -1
  23. fluxfem/solver/block_matrix.py +165 -13
  24. fluxfem/solver/block_system.py +52 -29
  25. fluxfem/solver/cg.py +43 -30
  26. fluxfem/solver/dirichlet.py +35 -12
  27. fluxfem/solver/history.py +15 -3
  28. fluxfem/solver/newton.py +25 -12
  29. fluxfem/solver/petsc.py +13 -7
  30. fluxfem/solver/preconditioner.py +7 -4
  31. fluxfem/solver/solve_runner.py +42 -24
  32. fluxfem/solver/solver.py +23 -11
  33. fluxfem/solver/sparse.py +32 -13
  34. fluxfem/tools/jit.py +19 -7
  35. fluxfem/tools/timer.py +14 -12
  36. fluxfem/tools/visualizer.py +16 -4
  37. {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/METADATA +18 -7
  38. fluxfem-0.2.1.dist-info/RECORD +59 -0
  39. fluxfem-0.2.0.dist-info/RECORD +0 -59
  40. {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  41. {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/__init__.py CHANGED
@@ -31,9 +31,6 @@ __all__ = [
31
31
  "compile_surface_bilinear",
32
32
  "compile_mixed_surface_residual",
33
33
  "compile_mixed_residual",
34
- "outer",
35
- "sdot",
36
- "dOmega",
37
34
  "FormContext",
38
35
  "MixedFormContext",
39
36
  "VolumeContext",
@@ -85,11 +82,6 @@ __all__ = [
85
82
  "constant_body_force_form",
86
83
  "constant_body_force_vector_form",
87
84
  "diffusion_form",
88
- "dot",
89
- "ddot",
90
- "transpose_last2",
91
- "sym_grad",
92
- "sym_grad_u",
93
85
  "right_cauchy_green",
94
86
  "green_lagrange_strain",
95
87
  "deformation_gradient",
@@ -172,6 +164,7 @@ __all__ = [
172
164
  "coo_to_csr",
173
165
  "SparsityPattern",
174
166
  "FluxSparseMatrix",
167
+ "FluxBlockMatrix",
175
168
  "coalesce_coo",
176
169
  "concat_flux",
177
170
  "block_diag_flux",
@@ -233,11 +226,6 @@ _PHYSICS_EXPORTS = {
233
226
  "vector_body_force_form",
234
227
  "constant_body_force_vector_form",
235
228
  "diffusion_form",
236
- "dot",
237
- "ddot",
238
- "transpose_last2",
239
- "sym_grad",
240
- "sym_grad_u",
241
229
  "right_cauchy_green",
242
230
  "green_lagrange_strain",
243
231
  "deformation_gradient",
fluxfem/core/__init__.py CHANGED
@@ -80,76 +80,6 @@ from .weakform import (
80
80
  compile_linear,
81
81
  compile_residual,
82
82
  compile_mixed_residual,
83
- grad,
84
- sym_grad,
85
- outer,
86
- dot,
87
- sdot,
88
- ddot,
89
- inner,
90
- action,
91
- gaction,
92
- normal,
93
- ds,
94
- dOmega,
95
- I,
96
- det,
97
- inv,
98
- transpose,
99
- log,
100
- transpose_last2,
101
- matmul,
102
- matmul_std,
103
- einsum,
104
- )
105
- from ..mesh import (
106
- BaseMeshPytree,
107
- SurfaceWithFacetMap,
108
- bbox_predicate,
109
- plane_predicate,
110
- axis_plane_predicate,
111
- slab_predicate,
112
- HexMesh,
113
- HexMeshPytree,
114
- StructuredHexBox,
115
- SurfaceMesh,
116
- SurfaceMeshPytree,
117
- SurfaceWithElemConn,
118
- surface_with_elem_conn,
119
- SurfaceSupermesh,
120
- build_surface_supermesh,
121
- tag_axis_minmax_facets,
122
- TetMesh,
123
- TetMeshPytree,
124
- StructuredTetBox,
125
- StructuredTetTensorBox,
126
- load_gmsh_mesh,
127
- load_gmsh_hex_mesh,
128
- load_gmsh_tet_mesh,
129
- make_surface_from_facets,
130
- SurfaceSupermesh,
131
- build_surface_supermesh,
132
- MortarMatrix,
133
- assemble_mortar_matrices,
134
- assemble_contact_onesided_floor,
135
- assemble_mixed_surface_residual,
136
- assemble_mixed_surface_jacobian,
137
- map_surface_facets_to_tet_elements,
138
- map_surface_facets_to_hex_elements,
139
- tri_area,
140
- tri_quadrature,
141
- facet_triangles,
142
- facet_shape_values,
143
- volume_shape_values_at_points,
144
- quad_shape_and_local,
145
- quad9_shape_values,
146
- hex27_gradN,
147
- ContactSurfaceSpace,
148
- ContactSide,
149
- OneSidedContact,
150
- OneSidedContactSurfaceSpace,
151
- facet_gap_values,
152
- active_contact_facets,
153
83
  )
154
84
  from .basis import (
155
85
  HexTriLinearBasis,
@@ -167,9 +97,58 @@ import importlib
167
97
 
168
98
  from .solver import spdirect_solve_cpu, spdirect_solve_gpu, spdirect_solve_jax, coo_to_csr
169
99
 
100
+ _MESH_EXPORTS = {
101
+ "BaseMeshPytree",
102
+ "SurfaceWithFacetMap",
103
+ "bbox_predicate",
104
+ "plane_predicate",
105
+ "axis_plane_predicate",
106
+ "slab_predicate",
107
+ "HexMesh",
108
+ "HexMeshPytree",
109
+ "StructuredHexBox",
110
+ "SurfaceMesh",
111
+ "SurfaceMeshPytree",
112
+ "SurfaceWithElemConn",
113
+ "surface_with_elem_conn",
114
+ "SurfaceSupermesh",
115
+ "build_surface_supermesh",
116
+ "tag_axis_minmax_facets",
117
+ "TetMesh",
118
+ "TetMeshPytree",
119
+ "StructuredTetBox",
120
+ "StructuredTetTensorBox",
121
+ "load_gmsh_mesh",
122
+ "load_gmsh_hex_mesh",
123
+ "load_gmsh_tet_mesh",
124
+ "make_surface_from_facets",
125
+ "MortarMatrix",
126
+ "assemble_mortar_matrices",
127
+ "assemble_contact_onesided_floor",
128
+ "assemble_mixed_surface_residual",
129
+ "assemble_mixed_surface_jacobian",
130
+ "map_surface_facets_to_tet_elements",
131
+ "map_surface_facets_to_hex_elements",
132
+ "tri_area",
133
+ "tri_quadrature",
134
+ "facet_triangles",
135
+ "facet_shape_values",
136
+ "volume_shape_values_at_points",
137
+ "quad_shape_and_local",
138
+ "quad9_shape_values",
139
+ "hex27_gradN",
140
+ "ContactSurfaceSpace",
141
+ "ContactSide",
142
+ "OneSidedContact",
143
+ "OneSidedContactSurfaceSpace",
144
+ "facet_gap_values",
145
+ "active_contact_facets",
146
+ }
147
+
170
148
  _SOLVER_EXPORTS = {
171
149
  "SparsityPattern",
172
150
  "FluxSparseMatrix",
151
+ "FluxBlockMatrix",
173
152
  "coalesce_coo",
174
153
  "concat_flux",
175
154
  "block_diag_flux",
@@ -227,7 +206,9 @@ _SOLVER_BC_EXPORTS = {
227
206
 
228
207
 
229
208
  def __getattr__(name: str):
230
- if name in _SOLVER_EXPORTS:
209
+ if name in _MESH_EXPORTS:
210
+ module = importlib.import_module("..mesh", __name__)
211
+ elif name in _SOLVER_EXPORTS:
231
212
  module = importlib.import_module("..solver", __name__)
232
213
  elif name in _SOLVER_BC_EXPORTS:
233
214
  module = importlib.import_module("..solver.bc", __name__)
@@ -410,6 +391,7 @@ __all__ = [
410
391
  "coo_to_csr",
411
392
  "SparsityPattern",
412
393
  "FluxSparseMatrix",
394
+ "FluxBlockMatrix",
413
395
  "coalesce_coo",
414
396
  "concat_flux",
415
397
  "block_diag_flux",
fluxfem/core/assembly.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from __future__ import annotations
2
- from typing import Any, Callable, Literal, Optional, Protocol, TYPE_CHECKING, TypeAlias, TypeVar, Union
2
+ from typing import Any, Callable, Literal, Mapping, Optional, Protocol, TYPE_CHECKING, TypeAlias, TypeVar, Union, cast
3
3
  import numpy as np
4
4
  import jax
5
5
  import jax.numpy as jnp
@@ -10,11 +10,16 @@ from .forms import FormContext
10
10
  from .space import FESpaceBase
11
11
 
12
12
  # Shared call signatures for kernels/forms
13
- Array = jnp.ndarray
13
+ Array: TypeAlias = jnp.ndarray
14
14
  P = TypeVar("P")
15
15
 
16
- Kernel = Callable[[FormContext, P], Array]
16
+ FormKernel: TypeAlias = Callable[[FormContext, P], Array]
17
+ # Form kernels return integrands; element kernels return integrated element arrays.
18
+ Kernel: TypeAlias = Callable[[FormContext, P], Array]
19
+ ResidualInput: TypeAlias = Array | Mapping[str, Array]
20
+ ResidualValue: TypeAlias = Array | Mapping[str, Array]
17
21
  ResidualForm = Callable[[FormContext, Array, P], Array]
22
+ ResidualFormLike = Callable[[FormContext, ResidualInput, P], ResidualValue]
18
23
  ElementDofMapper = Callable[[Array], Array]
19
24
 
20
25
  if TYPE_CHECKING:
@@ -107,7 +112,7 @@ class BatchedAssembler:
107
112
  def __init__(
108
113
  self,
109
114
  space: SpaceLike,
110
- elem_data: Any,
115
+ elem_data: FormContext,
111
116
  elem_dofs: Array,
112
117
  *,
113
118
  pattern: SparsityPattern | None = None,
@@ -119,8 +124,8 @@ class BatchedAssembler:
119
124
  self.n_ldofs = int(space.n_ldofs)
120
125
  self.n_dofs = int(space.n_dofs)
121
126
  self.pattern = pattern
122
- self._rows = None
123
- self._cols = None
127
+ self._rows: Array | None = None
128
+ self._cols: Array | None = None
124
129
 
125
130
  @classmethod
126
131
  def from_space(
@@ -135,7 +140,7 @@ class BatchedAssembler:
135
140
 
136
141
  def make_mask(self, n_active: int) -> Array:
137
142
  n_active = max(0, min(int(n_active), self.n_elems))
138
- mask = np.zeros((self.n_elems,), dtype=float)
143
+ mask: np.ndarray = np.zeros((self.n_elems,), dtype=float)
139
144
  if n_active:
140
145
  mask[:n_active] = 1.0
141
146
  return jnp.asarray(mask)
@@ -177,7 +182,7 @@ class BatchedAssembler:
177
182
 
178
183
  def assemble_bilinear(
179
184
  self,
180
- form: Kernel[P],
185
+ form: FormKernel[P],
181
186
  params: P,
182
187
  *,
183
188
  mask: Array | None = None,
@@ -208,7 +213,7 @@ class BatchedAssembler:
208
213
 
209
214
  def assemble_linear(
210
215
  self,
211
- form: Kernel[P],
216
+ form: FormKernel[P],
212
217
  params: P,
213
218
  *,
214
219
  mask: Array | None = None,
@@ -359,7 +364,7 @@ class SpaceLike(FESpaceBase, Protocol):
359
364
 
360
365
  def assemble_bilinear_dense(
361
366
  space: SpaceLike,
362
- kernel: Kernel[P],
367
+ kernel: FormKernel[P],
363
368
  params: P,
364
369
  *,
365
370
  sparse: bool = False,
@@ -412,7 +417,7 @@ def assemble_bilinear_dense(
412
417
 
413
418
  def assemble_bilinear_form(
414
419
  space: SpaceLike,
415
- form: Kernel[P],
420
+ form: FormKernel[P],
416
421
  params: P,
417
422
  *,
418
423
  pattern: SparsityPattern | None = None,
@@ -589,7 +594,7 @@ def assemble_mass_matrix(
589
594
 
590
595
  def assemble_linear_form(
591
596
  space: SpaceLike,
592
- form: Kernel[P],
597
+ form: FormKernel[P],
593
598
  params: P,
594
599
  *,
595
600
  kernel: ElementLinearKernel | None = None,
@@ -670,7 +675,7 @@ def assemble_linear_form(
670
675
  return F
671
676
 
672
677
 
673
- def assemble_functional(space: SpaceLike, form: Kernel[P], params: P) -> jnp.ndarray:
678
+ def assemble_functional(space: SpaceLike, form: FormKernel[P], params: P) -> jnp.ndarray:
674
679
  """
675
680
  Assemble scalar functional J = ∫ form(ctx, params) dΩ.
676
681
  Expects form(ctx, params) -> (n_q,) or (n_q, 1).
@@ -798,7 +803,7 @@ def assemble_jacobian_elementwise(
798
803
  )
799
804
  K_flat = jnp.zeros(n_entries, dtype=data.dtype)
800
805
  K_flat = jax.lax.scatter_add(K_flat, idx[:, None], data, sdn)
801
- return K_flat.reshape(pat.n_dofs, pat.n_dofs)
806
+ return K_flat.reshape(n_dofs, n_dofs)
802
807
 
803
808
 
804
809
  def assemble_residual_global(
@@ -885,7 +890,7 @@ assemble_residual_elementwise_xla = assemble_residual_elementwise
885
890
 
886
891
 
887
892
  def make_element_bilinear_kernel(
888
- form: Kernel[P], params: P, *, jit: bool = True
893
+ form: FormKernel[P], params: P, *, jit: bool = True
889
894
  ) -> ElementBilinearKernel:
890
895
  """Element kernel: (ctx) -> Ke."""
891
896
 
@@ -900,7 +905,7 @@ def make_element_bilinear_kernel(
900
905
 
901
906
 
902
907
  def make_element_linear_kernel(
903
- form: Kernel[P], params: P, *, jit: bool = True
908
+ form: FormKernel[P], params: P, *, jit: bool = True
904
909
  ) -> ElementLinearKernel:
905
910
  """Element kernel: (ctx) -> fe."""
906
911
 
@@ -945,8 +950,8 @@ def make_element_jacobian_kernel(
945
950
 
946
951
 
947
952
  def element_residual(
948
- res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.ndarray, params: P
949
- ) -> Any:
953
+ res_form: ResidualFormLike[P], ctx: FormContext, u_elem: ResidualInput, params: P
954
+ ) -> ResidualValue:
950
955
  """
951
956
  Element residual vector r_e(u_e) = sum_q w_q * detJ_q * res_form(ctx, u_e, params).
952
957
  Returns shape (n_ldofs,).
@@ -972,8 +977,8 @@ def element_residual(
972
977
 
973
978
 
974
979
  def element_jacobian(
975
- res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.ndarray, params: P
976
- ) -> Any:
980
+ res_form: ResidualFormLike[P], ctx: FormContext, u_elem: ResidualInput, params: P
981
+ ) -> ResidualValue:
977
982
  """
978
983
  Element Jacobian K_e = d r_e / d u_e (AD via jacfwd), shape (n_ldofs, n_ldofs).
979
984
  """
@@ -984,7 +989,7 @@ def element_jacobian(
984
989
 
985
990
 
986
991
  def make_element_kernel(
987
- form: Kernel[P] | ResidualForm[P],
992
+ form: FormKernel[P] | ResidualForm[P],
988
993
  params: P,
989
994
  *,
990
995
  kind: Literal["bilinear", "linear", "residual", "jacobian"],
@@ -999,22 +1004,26 @@ def make_element_kernel(
999
1004
  - "residual": kernel(ctx, u_elem) -> (n_ldofs,)
1000
1005
  - "jacobian": kernel(u_elem, ctx) -> (n_ldofs, n_ldofs)
1001
1006
  """
1002
- kind = kind.lower()
1007
+ kind = cast(Literal["bilinear", "linear", "residual", "jacobian"], kind.lower())
1003
1008
  if kind == "bilinear":
1004
- return make_element_bilinear_kernel(form, params, jit=jit)
1009
+ form_bilinear = cast(FormKernel[P], form)
1010
+ return make_element_bilinear_kernel(form_bilinear, params, jit=jit)
1005
1011
  if kind == "linear":
1012
+ form_linear = cast(FormKernel[P], form)
1006
1013
  def per_element(ctx: FormContext):
1007
- integrand = form(ctx, params)
1008
- if getattr(form, "_includes_measure", False):
1014
+ integrand = form_linear(ctx, params)
1015
+ if getattr(form_linear, "_includes_measure", False):
1009
1016
  return integrand.sum(axis=0)
1010
1017
  wJ = ctx.w * ctx.test.detJ
1011
1018
  return (integrand * wJ[:, None]).sum(axis=0)
1012
1019
 
1013
1020
  return jax.jit(per_element) if jit else per_element
1014
1021
  if kind == "residual":
1015
- return make_element_residual_kernel(form, params)
1022
+ form_residual = cast(ResidualForm[P], form)
1023
+ return make_element_residual_kernel(form_residual, params)
1016
1024
  if kind == "jacobian":
1017
- return make_element_jacobian_kernel(form, params)
1025
+ form_residual = cast(ResidualForm[P], form)
1026
+ return make_element_jacobian_kernel(form_residual, params)
1018
1027
  raise ValueError(f"Unknown kernel kind: {kind}")
1019
1028
 
1020
1029
 
@@ -1327,19 +1336,19 @@ def scalar_body_force_form(ctx: FormContext, load: float) -> jnp.ndarray:
1327
1336
  return load * ctx.test.N # (n_q, n_ldofs)
1328
1337
 
1329
1338
 
1330
- scalar_body_force_form._ff_kind = "linear"
1331
- scalar_body_force_form._ff_domain = "volume"
1339
+ scalar_body_force_form._ff_kind = "linear" # type: ignore[attr-defined]
1340
+ scalar_body_force_form._ff_domain = "volume" # type: ignore[attr-defined]
1332
1341
 
1333
1342
 
1334
- def make_scalar_body_force_form(body_force: Callable[[Array], Array]) -> Kernel[Any]:
1343
+ def make_scalar_body_force_form(body_force: Callable[[Array], Array]) -> FormKernel[Any]:
1335
1344
  """
1336
1345
  Build a scalar linear form from a callable f(x_q) -> (n_q,).
1337
1346
  """
1338
1347
  def _form(ctx: FormContext, _params):
1339
1348
  f_q = body_force(ctx.x_q)
1340
1349
  return f_q[..., None] * ctx.test.N
1341
- _form._ff_kind = "linear"
1342
- _form._ff_domain = "volume"
1350
+ _form._ff_kind = "linear" # type: ignore[attr-defined]
1351
+ _form._ff_domain = "volume" # type: ignore[attr-defined]
1343
1352
  return _form
1344
1353
 
1345
1354
 
fluxfem/core/basis.py CHANGED
@@ -421,7 +421,7 @@ class HexTriLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
421
421
 
422
422
  def tree_flatten(self):
423
423
  children = (self.quad_points, self.quad_weights)
424
- aux_data = {}
424
+ aux_data: dict[str, object] = {}
425
425
  return children, aux_data
426
426
 
427
427
  @classmethod
@@ -569,7 +569,7 @@ class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
569
569
 
570
570
  def tree_flatten(self):
571
571
  children = (self.quad_points, self.quad_weights)
572
- aux_data = {}
572
+ aux_data: dict[str, object] = {}
573
573
  return children, aux_data
574
574
 
575
575
  @classmethod
@@ -1,36 +1,60 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Any, Protocol, TypeAlias, runtime_checkable
3
+ from typing import Any, Mapping, Protocol, TYPE_CHECKING, TypeAlias, runtime_checkable
4
+
5
+ import numpy as np
6
+
7
+ if TYPE_CHECKING:
8
+ from jax import Array as JaxArray
9
+
10
+ ArrayLike: TypeAlias = np.ndarray | JaxArray
11
+ else:
12
+ ArrayLike: TypeAlias = np.ndarray
4
13
 
5
14
 
6
15
  @runtime_checkable
7
16
  class VolumeContext(Protocol):
8
17
  """Minimum interface for volume weak-form evaluation."""
9
18
 
10
- test: Any
11
- trial: Any
12
- w: Any
19
+ test: "FormFieldLike"
20
+ trial: "FormFieldLike"
21
+ w: ArrayLike
13
22
 
14
23
 
15
24
  @runtime_checkable
16
25
  class SurfaceContext(Protocol):
17
26
  """Minimum interface for surface weak-form evaluation."""
18
27
 
19
- v: Any
20
- w: Any
21
- detJ: Any
22
- normal: Any
28
+ v: "FormFieldLike"
29
+ w: ArrayLike
30
+ detJ: ArrayLike
31
+ normal: ArrayLike
23
32
 
24
33
 
25
34
  @runtime_checkable
26
35
  class FormFieldLike(Protocol):
27
36
  """Minimum interface for form fields used in weak-form evaluation."""
28
37
 
29
- N: Any
30
- gradN: Any
31
- detJ: Any
38
+ N: ArrayLike
39
+ gradN: ArrayLike
40
+ detJ: ArrayLike
32
41
  value_dim: int
42
+ basis: Any
43
+
44
+
45
+ @runtime_checkable
46
+ class WeakFormContext(Protocol):
47
+ """Context interface used when resolving field references."""
48
+
49
+ test: FormFieldLike
50
+ trial: FormFieldLike
51
+ v: FormFieldLike
52
+ unknown: FormFieldLike | None
53
+ fields: Mapping[str, Any] | None
54
+ test_fields: Mapping[str, FormFieldLike] | None
55
+ trial_fields: Mapping[str, FormFieldLike] | None
56
+ unknown_fields: Mapping[str, FormFieldLike] | None
33
57
 
34
58
 
35
- UElement: TypeAlias = Any
59
+ UElement: TypeAlias = ArrayLike | Mapping[str, ArrayLike]
36
60
  ParamsLike: TypeAlias = Any
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass, field
4
- from typing import Mapping, Sequence, Callable
4
+ from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING, TypeAlias, TypeVar, cast
5
5
 
6
6
  import numpy as np
7
7
  import jax.numpy as jnp
@@ -13,6 +13,16 @@ from ..solver.dirichlet import DirichletBC, free_dofs
13
13
  from ..solver.sparse import FluxSparseMatrix
14
14
  from .space import FESpaceClosure
15
15
 
16
+ P = TypeVar("P")
17
+
18
+ if TYPE_CHECKING:
19
+ from .assembly import JacobianReturn, LinearReturn
20
+
21
+ MixedResidualForm: TypeAlias = Callable[
22
+ [MixedFormContext, Mapping[str, jnp.ndarray], P],
23
+ Mapping[str, jnp.ndarray],
24
+ ]
25
+
16
26
 
17
27
  @dataclass(eq=False)
18
28
  class MixedFESpace:
@@ -131,13 +141,25 @@ class MixedFESpace:
131
141
 
132
142
  def get_sparsity_pattern(self, *, with_idx: bool = True):
133
143
  from .assembly import make_sparsity_pattern
134
- return make_sparsity_pattern(self, with_idx=with_idx)
144
+ return make_sparsity_pattern(cast(Any, self), with_idx=with_idx)
135
145
 
136
- def assemble_residual(self, res_form, u, params, **kwargs):
146
+ def assemble_residual(
147
+ self,
148
+ res_form: MixedResidualForm[P],
149
+ u: Mapping[str, jnp.ndarray] | Sequence[jnp.ndarray] | jnp.ndarray,
150
+ params: P,
151
+ **kwargs,
152
+ ) -> "LinearReturn":
137
153
  from .mixed_assembly import assemble_mixed_residual
138
154
  return assemble_mixed_residual(self, res_form, u, params, **kwargs)
139
155
 
140
- def assemble_jacobian(self, res_form, u, params, **kwargs):
156
+ def assemble_jacobian(
157
+ self,
158
+ res_form: MixedResidualForm[P],
159
+ u: Mapping[str, jnp.ndarray] | Sequence[jnp.ndarray] | jnp.ndarray,
160
+ params: P,
161
+ **kwargs,
162
+ ) -> "JacobianReturn":
141
163
  from .mixed_assembly import assemble_mixed_jacobian
142
164
  return assemble_mixed_jacobian(self, res_form, u, params, **kwargs)
143
165
 
@@ -233,7 +255,7 @@ class MixedProblem:
233
255
  pattern: object | None = None
234
256
  n_chunks: int | None = None
235
257
  pad_trace: bool = False
236
- _compiled: object = field(init=False, repr=False)
258
+ _compiled: Callable[..., Any] = field(init=False, repr=False)
237
259
 
238
260
  def __post_init__(self):
239
261
  if isinstance(self.residuals, MixedWeakForm):
@@ -257,18 +279,30 @@ class MixedProblem:
257
279
  def _wrapped(ctx, u_elem, _params):
258
280
  return self._compiled(ctx, u_elem, params(ctx))
259
281
 
260
- _wrapped._includes_measure = getattr(self._compiled, "_includes_measure", False)
282
+ _wrapped._includes_measure = getattr(self._compiled, "_includes_measure", False) # type: ignore[attr-defined]
261
283
  return _wrapped, None
262
284
  return self._compiled, params
263
285
 
264
- def assemble_residual(self, u, *, params=None, **kwargs):
286
+ def assemble_residual(
287
+ self,
288
+ u: Mapping[str, jnp.ndarray] | Sequence[jnp.ndarray] | jnp.ndarray,
289
+ *,
290
+ params: P | None = None,
291
+ **kwargs,
292
+ ) -> "LinearReturn":
265
293
  use_params = self.params if params is None else params
266
294
  res_form, use_params = self._wrap_params(use_params)
267
295
  return self.space.assemble_residual(
268
296
  res_form, u, use_params, **self._merge_kwargs(kwargs)
269
297
  )
270
298
 
271
- def assemble_jacobian(self, u, *, params=None, **kwargs):
299
+ def assemble_jacobian(
300
+ self,
301
+ u: Mapping[str, jnp.ndarray] | Sequence[jnp.ndarray] | jnp.ndarray,
302
+ *,
303
+ params: P | None = None,
304
+ **kwargs,
305
+ ) -> "JacobianReturn":
272
306
  use_params = self.params if params is None else params
273
307
  res_form, use_params = self._wrap_params(use_params)
274
308
  return self.space.assemble_jacobian(
@@ -40,7 +40,7 @@ def _wrap_params(res_form, params):
40
40
  def _wrapped(ctx, u_elem, _params):
41
41
  return res_form(ctx, u_elem, params(ctx))
42
42
 
43
- _wrapped._includes_measure = getattr(res_form, "_includes_measure", False)
43
+ _wrapped._includes_measure = getattr(res_form, "_includes_measure", False) # type: ignore[attr-defined]
44
44
  return _wrapped, None
45
45
  return res_form, params
46
46