fluxfem-0.2.0-py3-none-any.whl → fluxfem-0.2.1-py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (41):
  1. fluxfem/__init__.py +1 -13
  2. fluxfem/core/__init__.py +53 -71
  3. fluxfem/core/assembly.py +41 -32
  4. fluxfem/core/basis.py +2 -2
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/mixed_space.py +42 -8
  7. fluxfem/core/mixed_weakform.py +1 -1
  8. fluxfem/core/space.py +68 -28
  9. fluxfem/core/weakform.py +95 -77
  10. fluxfem/mesh/base.py +3 -3
  11. fluxfem/mesh/contact.py +33 -17
  12. fluxfem/mesh/io.py +3 -2
  13. fluxfem/mesh/mortar.py +106 -43
  14. fluxfem/mesh/supermesh.py +2 -0
  15. fluxfem/mesh/surface.py +82 -22
  16. fluxfem/mesh/tet.py +7 -4
  17. fluxfem/physics/elasticity/hyperelastic.py +32 -3
  18. fluxfem/physics/elasticity/linear.py +13 -2
  19. fluxfem/physics/elasticity/stress.py +9 -5
  20. fluxfem/physics/operators.py +12 -5
  21. fluxfem/physics/postprocess.py +29 -3
  22. fluxfem/solver/__init__.py +6 -1
  23. fluxfem/solver/block_matrix.py +165 -13
  24. fluxfem/solver/block_system.py +52 -29
  25. fluxfem/solver/cg.py +43 -30
  26. fluxfem/solver/dirichlet.py +35 -12
  27. fluxfem/solver/history.py +15 -3
  28. fluxfem/solver/newton.py +25 -12
  29. fluxfem/solver/petsc.py +13 -7
  30. fluxfem/solver/preconditioner.py +7 -4
  31. fluxfem/solver/solve_runner.py +42 -24
  32. fluxfem/solver/solver.py +23 -11
  33. fluxfem/solver/sparse.py +32 -13
  34. fluxfem/tools/jit.py +19 -7
  35. fluxfem/tools/timer.py +14 -12
  36. fluxfem/tools/visualizer.py +16 -4
  37. {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/METADATA +18 -7
  38. fluxfem-0.2.1.dist-info/RECORD +59 -0
  39. fluxfem-0.2.0.dist-info/RECORD +0 -59
  40. {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  41. {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/core/weakform.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Any, Callable, Iterator, Literal, get_args
4
+ from typing import Any, Callable, Iterator, Literal, Mapping, TypeAlias, cast, get_args
5
5
  import inspect
6
6
  from dataclasses import dataclass
7
7
  from functools import update_wrapper
@@ -12,7 +12,7 @@ import jax.numpy as jnp
12
12
  import jax
13
13
 
14
14
  from ..physics import operators as _ops
15
- from .context_types import FormFieldLike, ParamsLike, SurfaceContext, UElement, VolumeContext
15
+ from .context_types import ArrayLike, FormFieldLike, ParamsLike, SurfaceContext, UElement, VolumeContext, WeakFormContext
16
16
 
17
17
 
18
18
  OpName = Literal[
@@ -47,6 +47,7 @@ OpName = Literal[
47
47
  "einsum",
48
48
  ]
49
49
 
50
+
50
51
  # Use OpName as the single source of truth for valid ops.
51
52
  _OP_NAMES: frozenset[str] = frozenset(get_args(OpName))
52
53
 
@@ -404,6 +405,7 @@ class _ZeroField:
404
405
  def __init__(self, base):
405
406
  self.N = jnp.zeros_like(base.N)
406
407
  self.gradN = None if getattr(base, "gradN", None) is None else jnp.zeros_like(base.gradN)
408
+ self.detJ = getattr(base, "detJ", None)
407
409
  self.value_dim = int(getattr(base, "value_dim", 1))
408
410
  self.basis = getattr(base, "basis", None)
409
411
 
@@ -414,6 +416,7 @@ class _ZeroFieldNp:
414
416
  def __init__(self, base):
415
417
  self.N = np.zeros_like(base.N)
416
418
  self.gradN = None if getattr(base, "gradN", None) is None else np.zeros_like(base.gradN)
419
+ self.detJ = getattr(base, "detJ", None)
417
420
  self.value_dim = int(getattr(base, "value_dim", 1))
418
421
  self.basis = getattr(base, "basis", None)
419
422
 
@@ -444,8 +447,8 @@ def zero_ref(name: str) -> FieldRef:
444
447
 
445
448
 
446
449
  def _eval_field(
447
- obj: Any,
448
- ctx: VolumeContext | SurfaceContext,
450
+ obj: FieldRef,
451
+ ctx: WeakFormContext,
449
452
  params: ParamsLike,
450
453
  ) -> FormFieldLike:
451
454
  if isinstance(obj, FieldRef):
@@ -453,12 +456,15 @@ def _eval_field(
453
456
  if obj.name is None:
454
457
  raise ValueError("zero_ref requires a named field.")
455
458
  base = None
456
- if getattr(ctx, "test_fields", None) is not None and obj.name in ctx.test_fields:
457
- base = ctx.test_fields[obj.name]
458
- if base is None and getattr(ctx, "trial_fields", None) is not None and obj.name in ctx.trial_fields:
459
- base = ctx.trial_fields[obj.name]
460
- if base is None and getattr(ctx, "fields", None) is not None and obj.name in ctx.fields:
461
- group = ctx.fields[obj.name]
459
+ test_fields = getattr(ctx, "test_fields", None)
460
+ if test_fields is not None and obj.name in test_fields:
461
+ base = test_fields[obj.name]
462
+ trial_fields = getattr(ctx, "trial_fields", None)
463
+ if base is None and trial_fields is not None and obj.name in trial_fields:
464
+ base = trial_fields[obj.name]
465
+ fields = getattr(ctx, "fields", None)
466
+ if base is None and fields is not None and obj.name in fields:
467
+ group = fields[obj.name]
462
468
  if hasattr(group, "test"):
463
469
  base = group.test
464
470
  elif hasattr(group, "trial"):
@@ -476,15 +482,15 @@ def _eval_field(
476
482
  return group.test
477
483
  if hasattr(group, "unknown") and obj.role == "unknown":
478
484
  return group.unknown if group.unknown is not None else group.trial
479
- if obj.role == "trial" and getattr(ctx, "trial_fields", None) is not None:
480
- if obj.name in ctx.trial_fields:
481
- return ctx.trial_fields[obj.name]
482
- if obj.role == "test" and getattr(ctx, "test_fields", None) is not None:
483
- if obj.name in ctx.test_fields:
484
- return ctx.test_fields[obj.name]
485
- if obj.role == "unknown" and getattr(ctx, "unknown_fields", None) is not None:
486
- if obj.name in ctx.unknown_fields:
487
- return ctx.unknown_fields[obj.name]
485
+ trial_fields = getattr(ctx, "trial_fields", None)
486
+ if obj.role == "trial" and trial_fields is not None and obj.name in trial_fields:
487
+ return trial_fields[obj.name]
488
+ test_fields = getattr(ctx, "test_fields", None)
489
+ if obj.role == "test" and test_fields is not None and obj.name in test_fields:
490
+ return test_fields[obj.name]
491
+ unknown_fields = getattr(ctx, "unknown_fields", None)
492
+ if obj.role == "unknown" and unknown_fields is not None and obj.name in unknown_fields:
493
+ return unknown_fields[obj.name]
488
494
  fields = getattr(ctx, "fields", None)
489
495
  if fields is not None and obj.name in fields:
490
496
  group = fields[obj.name]
@@ -509,8 +515,8 @@ def _eval_field(
509
515
 
510
516
 
511
517
  def _eval_field_np(
512
- obj: Any,
513
- ctx: VolumeContext | SurfaceContext,
518
+ obj: FieldRef,
519
+ ctx: WeakFormContext,
514
520
  params: ParamsLike,
515
521
  ) -> FormFieldLike:
516
522
  if isinstance(obj, FieldRef):
@@ -518,12 +524,15 @@ def _eval_field_np(
518
524
  if obj.name is None:
519
525
  raise ValueError("zero_ref requires a named field.")
520
526
  base = None
521
- if getattr(ctx, "test_fields", None) is not None and obj.name in ctx.test_fields:
522
- base = ctx.test_fields[obj.name]
523
- if base is None and getattr(ctx, "trial_fields", None) is not None and obj.name in ctx.trial_fields:
524
- base = ctx.trial_fields[obj.name]
525
- if base is None and getattr(ctx, "fields", None) is not None and obj.name in ctx.fields:
526
- group = ctx.fields[obj.name]
527
+ test_fields = getattr(ctx, "test_fields", None)
528
+ if test_fields is not None and obj.name in test_fields:
529
+ base = test_fields[obj.name]
530
+ trial_fields = getattr(ctx, "trial_fields", None)
531
+ if base is None and trial_fields is not None and obj.name in trial_fields:
532
+ base = trial_fields[obj.name]
533
+ fields = getattr(ctx, "fields", None)
534
+ if base is None and fields is not None and obj.name in fields:
535
+ group = fields[obj.name]
527
536
  if hasattr(group, "test"):
528
537
  base = group.test
529
538
  elif hasattr(group, "trial"):
@@ -541,15 +550,15 @@ def _eval_field_np(
541
550
  return group.test
542
551
  if hasattr(group, "unknown") and obj.role == "unknown":
543
552
  return group.unknown if group.unknown is not None else group.trial
544
- if obj.role == "trial" and getattr(ctx, "trial_fields", None) is not None:
545
- if obj.name in ctx.trial_fields:
546
- return ctx.trial_fields[obj.name]
547
- if obj.role == "test" and getattr(ctx, "test_fields", None) is not None:
548
- if obj.name in ctx.test_fields:
549
- return ctx.test_fields[obj.name]
550
- if obj.role == "unknown" and getattr(ctx, "unknown_fields", None) is not None:
551
- if obj.name in ctx.unknown_fields:
552
- return ctx.unknown_fields[obj.name]
553
+ trial_fields = getattr(ctx, "trial_fields", None)
554
+ if obj.role == "trial" and trial_fields is not None and obj.name in trial_fields:
555
+ return trial_fields[obj.name]
556
+ test_fields = getattr(ctx, "test_fields", None)
557
+ if obj.role == "test" and test_fields is not None and obj.name in test_fields:
558
+ return test_fields[obj.name]
559
+ unknown_fields = getattr(ctx, "unknown_fields", None)
560
+ if obj.role == "unknown" and unknown_fields is not None and obj.name in unknown_fields:
561
+ return unknown_fields[obj.name]
553
562
  fields = getattr(ctx, "fields", None)
554
563
  if fields is not None and obj.name in fields:
555
564
  group = fields[obj.name]
@@ -586,7 +595,7 @@ def _eval_field_np(
586
595
  # return obj
587
596
 
588
597
 
589
- def _extract_unknown_elem(field_ref: FieldRef, u_elem: UElement):
598
+ def _extract_unknown_elem(field_ref: FieldRef, u_elem: UElement) -> ArrayLike:
590
599
  if u_elem is None:
591
600
  raise ValueError("u_elem is required to evaluate unknown field value.")
592
601
  if isinstance(u_elem, dict):
@@ -598,8 +607,9 @@ def _extract_unknown_elem(field_ref: FieldRef, u_elem: UElement):
598
607
 
599
608
 
600
609
  def _basis_outer(test: FieldRef, trial: FieldRef, ctx, params):
601
- v_field = _eval_field(test, ctx, params)
602
- u_field = _eval_field(trial, ctx, params)
610
+ ctx_w = cast(WeakFormContext, ctx)
611
+ v_field = _eval_field(test, ctx_w, params)
612
+ u_field = _eval_field(trial, ctx_w, params)
603
613
  if getattr(v_field, "value_dim", 1) != 1 or getattr(u_field, "value_dim", 1) != 1:
604
614
  raise ValueError(
605
615
  "inner/outer is only defined for scalar fields; use dot/action/einsum for vector/tensor cases."
@@ -608,8 +618,9 @@ def _basis_outer(test: FieldRef, trial: FieldRef, ctx, params):
608
618
 
609
619
 
610
620
  def _basis_outer_np(test: FieldRef, trial: FieldRef, ctx, params):
611
- v_field = _eval_field_np(test, ctx, params)
612
- u_field = _eval_field_np(trial, ctx, params)
621
+ ctx_w = cast(WeakFormContext, ctx)
622
+ v_field = _eval_field_np(test, ctx_w, params)
623
+ u_field = _eval_field_np(trial, ctx_w, params)
613
624
  if getattr(v_field, "value_dim", 1) != 1 or getattr(u_field, "value_dim", 1) != 1:
614
625
  raise ValueError(
615
626
  "inner/outer is only defined for scalar fields; use dot/action/einsum for vector/tensor cases."
@@ -669,7 +680,7 @@ def _eval_unknown_grad_np(field_ref: FieldRef, field: FormFieldLike, u_elem: UEl
669
680
  return np.einsum("qaj,ai->qij", field.gradN, u_nodes)
670
681
 
671
682
 
672
- def _vector_load_form_np(field: Any, load_vec: Any) -> np.ndarray:
683
+ def _vector_load_form_np(field: FormFieldLike, load_vec: ArrayLike) -> np.ndarray:
673
684
  lv = np.asarray(load_vec)
674
685
  if lv.ndim == 1:
675
686
  lv = lv[None, :]
@@ -709,7 +720,7 @@ def _sym_grad_np(field) -> np.ndarray:
709
720
  return B
710
721
 
711
722
 
712
- def _sym_grad_u_np(field, u_elem: Any) -> np.ndarray:
723
+ def _sym_grad_u_np(field: FormFieldLike, u_elem: ArrayLike) -> np.ndarray:
713
724
  B = _sym_grad_np(field)
714
725
  u_arr = np.asarray(u_elem)
715
726
  if u_arr.ndim == 2:
@@ -717,20 +728,20 @@ def _sym_grad_u_np(field, u_elem: Any) -> np.ndarray:
717
728
  return np.einsum("qik,k->qi", B, u_arr)
718
729
 
719
730
 
720
- def _ddot_np(a: Any, b: Any, c: Any | None = None) -> np.ndarray:
731
+ def _ddot_np(a: ArrayLike, b: ArrayLike, c: ArrayLike | None = None) -> np.ndarray:
721
732
  if c is None:
722
733
  return np.einsum("...ij,...ij->...", a, b)
723
734
  a_t = np.swapaxes(a, -1, -2)
724
735
  return np.einsum("...ik,kl,...lm->...im", a_t, b, c)
725
736
 
726
737
 
727
- def _dot_np(a: Any, b: Any) -> np.ndarray:
738
+ def _dot_np(a: FormFieldLike | ArrayLike, b: ArrayLike) -> np.ndarray:
728
739
  if hasattr(a, "N") and getattr(a, "value_dim", None) is not None:
729
- return _vector_load_form_np(a, b)
740
+ return _vector_load_form_np(cast(FormFieldLike, a), b)
730
741
  return np.matmul(a, b)
731
742
 
732
743
 
733
- def _transpose_last2_np(a: Any) -> np.ndarray:
744
+ def _transpose_last2_np(a: ArrayLike) -> np.ndarray:
734
745
  return np.swapaxes(a, -1, -2)
735
746
 
736
747
 
@@ -930,9 +941,10 @@ def compile_bilinear(fn):
930
941
  v = test_ref()
931
942
  p = param_ref()
932
943
  expr = _call_user(fn, u, v, params=p)
933
- expr = _as_expr(expr)
934
- if not isinstance(expr, Expr):
944
+ expr_raw = _as_expr(expr)
945
+ if not isinstance(expr_raw, Expr):
935
946
  raise TypeError("Bilinear form must return an Expr.")
947
+ expr = cast(Expr, expr_raw)
936
948
 
937
949
  volume_count = _count_op(expr, "volume_measure")
938
950
  surface_count = _count_op(expr, "surface_measure")
@@ -948,7 +960,7 @@ def compile_bilinear(fn):
948
960
  def _form(ctx, params):
949
961
  return eval_with_plan(plan, ctx, params)
950
962
 
951
- _form._includes_measure = True
963
+ _form._includes_measure = True # type: ignore[attr-defined]
952
964
  return _tag_form(_form, kind="bilinear", domain="volume")
953
965
 
954
966
 
@@ -960,9 +972,10 @@ def compile_linear(fn):
960
972
  v = test_ref()
961
973
  p = param_ref()
962
974
  expr = _call_user(fn, v, params=p)
963
- expr = _as_expr(expr)
964
- if not isinstance(expr, Expr):
975
+ expr_raw = _as_expr(expr)
976
+ if not isinstance(expr_raw, Expr):
965
977
  raise TypeError("Linear form must return an Expr.")
978
+ expr = cast(Expr, expr_raw)
966
979
 
967
980
  volume_count = _count_op(expr, "volume_measure")
968
981
  surface_count = _count_op(expr, "surface_measure")
@@ -978,7 +991,7 @@ def compile_linear(fn):
978
991
  def _form(ctx, params):
979
992
  return eval_with_plan(plan, ctx, params)
980
993
 
981
- _form._includes_measure = True
994
+ _form._includes_measure = True # type: ignore[attr-defined]
982
995
  return _tag_form(_form, kind="linear", domain="volume")
983
996
 
984
997
 
@@ -1067,6 +1080,7 @@ def eval_with_plan(
1067
1080
  nodes = plan.nodes
1068
1081
  index = plan.index
1069
1082
  vals: list[Any] = [None] * len(nodes)
1083
+ ctx_w = cast(WeakFormContext, ctx)
1070
1084
 
1071
1085
  def get(obj):
1072
1086
  if isinstance(obj, Expr):
@@ -1097,7 +1111,7 @@ def eval_with_plan(
1097
1111
  if op == "value":
1098
1112
  ref = args[0]
1099
1113
  assert isinstance(ref, FieldRef)
1100
- field = _eval_field(ref, ctx, params)
1114
+ field = _eval_field(ref, ctx_w, params)
1101
1115
  if ref.role == "unknown":
1102
1116
  vals[i] = _eval_unknown_value(ref, field, u_elem)
1103
1117
  else:
@@ -1106,7 +1120,7 @@ def eval_with_plan(
1106
1120
  if op == "grad":
1107
1121
  ref = args[0]
1108
1122
  assert isinstance(ref, FieldRef)
1109
- field = _eval_field(ref, ctx, params)
1123
+ field = _eval_field(ref, ctx_w, params)
1110
1124
  if ref.role == "unknown":
1111
1125
  vals[i] = _eval_unknown_grad(ref, field, u_elem)
1112
1126
  else:
@@ -1151,7 +1165,7 @@ def eval_with_plan(
1151
1165
  if op == "sym_grad":
1152
1166
  ref = args[0]
1153
1167
  assert isinstance(ref, FieldRef)
1154
- field = _eval_field(ref, ctx, params)
1168
+ field = _eval_field(ref, ctx_w, params)
1155
1169
  if ref.role == "unknown":
1156
1170
  if u_elem is None:
1157
1171
  raise ValueError("u_elem is required to evaluate unknown sym_grad.")
@@ -1215,7 +1229,7 @@ def eval_with_plan(
1215
1229
  if op == "dot":
1216
1230
  ref = args[0]
1217
1231
  if isinstance(ref, FieldRef):
1218
- vals[i] = _ops.dot(_eval_field(ref, ctx, params), get(args[1]))
1232
+ vals[i] = _ops.dot(_eval_field(ref, ctx_w, params), get(args[1]))
1219
1233
  else:
1220
1234
  a = get(args[0])
1221
1235
  b = get(args[1])
@@ -1233,7 +1247,7 @@ def eval_with_plan(
1233
1247
  if op == "sdot":
1234
1248
  ref = args[0]
1235
1249
  if isinstance(ref, FieldRef):
1236
- vals[i] = _ops.dot(_eval_field(ref, ctx, params), get(args[1]))
1250
+ vals[i] = _ops.dot(_eval_field(ref, ctx_w, params), get(args[1]))
1237
1251
  else:
1238
1252
  a = get(args[0])
1239
1253
  b = get(args[1])
@@ -1276,7 +1290,7 @@ def eval_with_plan(
1276
1290
  assert isinstance(ref, FieldRef)
1277
1291
  if isinstance(args[1], FieldRef):
1278
1292
  raise ValueError("action expects a scalar expression; use u.val for unknowns.")
1279
- v_field = _eval_field(ref, ctx, params)
1293
+ v_field = _eval_field(ref, ctx_w, params)
1280
1294
  s = get(args[1])
1281
1295
  value_dim = int(getattr(v_field, "value_dim", 1))
1282
1296
  # action maps a test field with a scalar/vector expression into nodal space.
@@ -1294,7 +1308,7 @@ def eval_with_plan(
1294
1308
  if op == "gaction":
1295
1309
  ref = args[0]
1296
1310
  assert isinstance(ref, FieldRef)
1297
- v_field = _eval_field(ref, ctx, params)
1311
+ v_field = _eval_field(ref, ctx_w, params)
1298
1312
  q = get(args[1])
1299
1313
  # gaction maps a flux-like expression to nodal space via test gradients.
1300
1314
  if v_field.gradN.ndim != 3:
@@ -1336,6 +1350,7 @@ def eval_with_plan_numpy(
1336
1350
  nodes = plan.nodes
1337
1351
  index = plan.index
1338
1352
  vals: list[Any] = [None] * len(nodes)
1353
+ ctx_w = cast(WeakFormContext, ctx)
1339
1354
 
1340
1355
  def get(obj):
1341
1356
  if isinstance(obj, Expr):
@@ -1366,7 +1381,7 @@ def eval_with_plan_numpy(
1366
1381
  if op == "value":
1367
1382
  ref = args[0]
1368
1383
  assert isinstance(ref, FieldRef)
1369
- field = _eval_field_np(ref, ctx, params)
1384
+ field = _eval_field_np(ref, ctx_w, params)
1370
1385
  if ref.role == "unknown":
1371
1386
  vals[i] = _eval_unknown_value_np(ref, field, u_elem)
1372
1387
  else:
@@ -1375,7 +1390,7 @@ def eval_with_plan_numpy(
1375
1390
  if op == "grad":
1376
1391
  ref = args[0]
1377
1392
  assert isinstance(ref, FieldRef)
1378
- field = _eval_field_np(ref, ctx, params)
1393
+ field = _eval_field_np(ref, ctx_w, params)
1379
1394
  if ref.role == "unknown":
1380
1395
  vals[i] = _eval_unknown_grad_np(ref, field, u_elem)
1381
1396
  else:
@@ -1420,7 +1435,7 @@ def eval_with_plan_numpy(
1420
1435
  if op == "sym_grad":
1421
1436
  ref = args[0]
1422
1437
  assert isinstance(ref, FieldRef)
1423
- field = _eval_field_np(ref, ctx, params)
1438
+ field = _eval_field_np(ref, ctx_w, params)
1424
1439
  if ref.role == "unknown":
1425
1440
  if u_elem is None:
1426
1441
  raise ValueError("u_elem is required to evaluate unknown sym_grad.")
@@ -1484,7 +1499,7 @@ def eval_with_plan_numpy(
1484
1499
  if op == "dot":
1485
1500
  ref = args[0]
1486
1501
  if isinstance(ref, FieldRef):
1487
- vals[i] = _dot_np(_eval_field_np(ref, ctx, params), get(args[1]))
1502
+ vals[i] = _dot_np(_eval_field_np(ref, ctx_w, params), get(args[1]))
1488
1503
  else:
1489
1504
  a = get(args[0])
1490
1505
  b = get(args[1])
@@ -1503,7 +1518,7 @@ def eval_with_plan_numpy(
1503
1518
  if op == "sdot":
1504
1519
  ref = args[0]
1505
1520
  if isinstance(ref, FieldRef):
1506
- vals[i] = _dot_np(_eval_field_np(ref, ctx, params), get(args[1]))
1521
+ vals[i] = _dot_np(_eval_field_np(ref, ctx_w, params), get(args[1]))
1507
1522
  else:
1508
1523
  a = get(args[0])
1509
1524
  b = get(args[1])
@@ -1547,7 +1562,7 @@ def eval_with_plan_numpy(
1547
1562
  assert isinstance(ref, FieldRef)
1548
1563
  if isinstance(args[1], FieldRef):
1549
1564
  raise ValueError("action expects a scalar expression; use u.val for unknowns.")
1550
- v_field = _eval_field_np(ref, ctx, params)
1565
+ v_field = _eval_field_np(ref, ctx_w, params)
1551
1566
  s = get(args[1])
1552
1567
  value_dim = int(getattr(v_field, "value_dim", 1))
1553
1568
  if value_dim == 1:
@@ -1564,7 +1579,7 @@ def eval_with_plan_numpy(
1564
1579
  if op == "gaction":
1565
1580
  ref = args[0]
1566
1581
  assert isinstance(ref, FieldRef)
1567
- v_field = _eval_field_np(ref, ctx, params)
1582
+ v_field = _eval_field_np(ref, ctx_w, params)
1568
1583
  q = get(args[1])
1569
1584
  if v_field.gradN.ndim != 3:
1570
1585
  raise ValueError("gaction expects test gradient with shape (q, ndofs, dim).")
@@ -1623,9 +1638,10 @@ def compile_surface_linear(fn):
1623
1638
  p = param_ref()
1624
1639
  expr = _call_user(fn, v, params=p)
1625
1640
 
1626
- expr = _as_expr(expr)
1627
- if not isinstance(expr, Expr):
1641
+ expr_raw = _as_expr(expr)
1642
+ if not isinstance(expr_raw, Expr):
1628
1643
  raise ValueError("Surface linear form must return an Expr; use ds() in the expression.")
1644
+ expr = cast(Expr, expr_raw)
1629
1645
 
1630
1646
  surface_count = _count_op(expr, "surface_measure")
1631
1647
  volume_count = _count_op(expr, "volume_measure")
@@ -1655,9 +1671,10 @@ def compile_surface_bilinear(fn):
1655
1671
  p = param_ref()
1656
1672
  expr = _call_user(fn, u, v, params=p)
1657
1673
 
1658
- expr = _as_expr(expr)
1659
- if not isinstance(expr, Expr):
1674
+ expr_raw = _as_expr(expr)
1675
+ if not isinstance(expr_raw, Expr):
1660
1676
  raise ValueError("Surface bilinear form must return an Expr; use ds() in the expression.")
1677
+ expr = cast(Expr, expr_raw)
1661
1678
 
1662
1679
  surface_count = _count_op(expr, "surface_measure")
1663
1680
  volume_count = _count_op(expr, "volume_measure")
@@ -1738,9 +1755,10 @@ def compile_residual(fn):
1738
1755
  u = unknown_ref()
1739
1756
  p = param_ref()
1740
1757
  expr = _call_user(fn, v, u, params=p)
1741
- expr = _as_expr(expr)
1742
- if not isinstance(expr, Expr):
1758
+ expr_raw = _as_expr(expr)
1759
+ if not isinstance(expr_raw, Expr):
1743
1760
  raise TypeError("Residual form must return an Expr.")
1761
+ expr = cast(Expr, expr_raw)
1744
1762
 
1745
1763
  volume_count = _count_op(expr, "volume_measure")
1746
1764
  surface_count = _count_op(expr, "surface_measure")
@@ -1756,7 +1774,7 @@ def compile_residual(fn):
1756
1774
  def _form(ctx, u_elem, params):
1757
1775
  return eval_with_plan(plan, ctx, params, u_elem=u_elem)
1758
1776
 
1759
- _form._includes_measure = True
1777
+ _form._includes_measure = True # type: ignore[attr-defined]
1760
1778
  return _tag_form(_form, kind="residual", domain="volume")
1761
1779
 
1762
1780
 
@@ -1818,7 +1836,7 @@ def compile_mixed_residual(residuals: dict[str, Callable]):
1818
1836
  for name, plan in plans.items()
1819
1837
  }
1820
1838
 
1821
- _form._includes_measure = includes_measure
1839
+ _form._includes_measure = includes_measure # type: ignore[attr-defined]
1822
1840
  return _tag_form(_form, kind="residual", domain="volume")
1823
1841
 
1824
1842
 
@@ -1878,7 +1896,7 @@ def compile_mixed_surface_residual(residuals: dict[str, Callable]):
1878
1896
  for name, plan in plans.items()
1879
1897
  }
1880
1898
 
1881
- _form._includes_measure = includes_measure
1899
+ _form._includes_measure = includes_measure # type: ignore[attr-defined]
1882
1900
  return _tag_form(_form, kind="residual", domain="surface")
1883
1901
 
1884
1902
 
@@ -1938,7 +1956,7 @@ def compile_mixed_surface_residual_numpy(residuals: dict[str, Callable]):
1938
1956
  for name, plan in plans.items()
1939
1957
  }
1940
1958
 
1941
- _form._includes_measure = includes_measure
1959
+ _form._includes_measure = includes_measure # type: ignore[attr-defined]
1942
1960
  return _tag_form(_form, kind="residual", domain="surface")
1943
1961
 
1944
1962
 
fluxfem/mesh/base.py CHANGED
@@ -160,7 +160,7 @@ class BaseMeshClosure:
160
160
  for pattern in patterns:
161
161
  nodes = tuple(sorted(int(elem_conn[i]) for i in pattern))
162
162
  face_counts[nodes] = face_counts.get(nodes, 0) + 1
163
- bnodes = set()
163
+ bnodes: set[int] = set()
164
164
  for nodes, count in face_counts.items():
165
165
  if count == 1:
166
166
  bnodes.update(nodes)
@@ -172,7 +172,7 @@ class BaseMeshClosure:
172
172
  """
173
173
  Return boolean mask for boundary nodes (shape: n_nodes).
174
174
  """
175
- mask = np.zeros(self.n_nodes, dtype=bool)
175
+ mask: np.ndarray = np.zeros(self.n_nodes, dtype=bool)
176
176
  nodes = self.boundary_node_indices()
177
177
  mask[nodes] = True
178
178
  return mask
@@ -449,7 +449,7 @@ class BaseMeshClosure:
449
449
  nodes_arr = np.asarray(list(nodes), dtype=int)
450
450
  if nodes_arr.size == 0:
451
451
  return np.asarray([], dtype=int)
452
- mark = np.zeros(self.n_nodes, dtype=bool)
452
+ mark: np.ndarray = np.zeros(self.n_nodes, dtype=bool)
453
453
  mark[nodes_arr] = True
454
454
  conn = np.asarray(self.conn)
455
455
  return np.nonzero(np.any(mark[conn], axis=1))[0]
fluxfem/mesh/contact.py CHANGED
@@ -1,9 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Mapping, Sequence
4
+ from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING, TypeAlias
5
5
 
6
6
  import numpy as np
7
+ import numpy.typing as npt
7
8
 
8
9
  from .mortar import (
9
10
  assemble_mixed_surface_jacobian,
@@ -18,6 +19,20 @@ from .supermesh import build_surface_supermesh
18
19
  from .surface import SurfaceMesh
19
20
  from .base import BaseMesh
20
21
 
22
+ if TYPE_CHECKING:
23
+ from .mortar import MortarMatrix
24
+ from ..core.weakform import Params as WeakParams
25
+ from .mortar import SurfaceMixedFormContext
26
+
27
+ ContactJacobianReturn: TypeAlias = np.ndarray | tuple[np.ndarray, np.ndarray, np.ndarray, int]
28
+ MixedSurfaceResidualForm: TypeAlias = Callable[
29
+ ["SurfaceMixedFormContext", Mapping[str, npt.ArrayLike], Any],
30
+ Mapping[str, npt.ArrayLike],
31
+ ]
32
+ SurfaceHatFn: TypeAlias = Callable[[np.ndarray], npt.ArrayLike]
33
+
34
+ _CONTACT_SETUP_CACHE: dict[tuple, "ContactSurfaceSpace"] = {}
35
+
21
36
 
22
37
  @dataclass(frozen=True)
23
38
  class ContactSide:
@@ -257,7 +272,7 @@ class OneSidedContactSurfaceSpace:
257
272
 
258
273
  def assemble_bilinear(
259
274
  self,
260
- u_hat_fn,
275
+ u_hat_fn: SurfaceHatFn | None,
261
276
  params: "WeakParams",
262
277
  *,
263
278
  u_master: np.ndarray | None = None,
@@ -559,7 +574,7 @@ class ContactSurfaceSpace:
559
574
  setup_cache_trace=setup_cache_trace,
560
575
  )
561
576
 
562
- @classmethod
577
+ @classmethod # type: ignore[no-redef]
563
578
  def from_facets(
564
579
  cls,
565
580
  coords_master: np.ndarray,
@@ -636,7 +651,8 @@ class ContactSurfaceSpace:
636
651
  raise ValueError("backend must be 'jax' or 'numpy'")
637
652
  return use_backend
638
653
 
639
- def assemble_mortar_matrices(self):
654
+ def assemble_mortar_matrices(self) -> tuple["MortarMatrix", "MortarMatrix"]:
655
+ """Return (M_aa, M_ab) mortar coupling matrices."""
640
656
  return assemble_mortar_matrices(
641
657
  self.supermesh_coords,
642
658
  self.supermesh_conn,
@@ -648,13 +664,13 @@ class ContactSurfaceSpace:
648
664
 
649
665
  def assemble_residual(
650
666
  self,
651
- res_form,
652
- u,
653
- params,
667
+ res_form: MixedSurfaceResidualForm,
668
+ u: Mapping[str, npt.ArrayLike] | Sequence[npt.ArrayLike],
669
+ params: "WeakParams",
654
670
  *,
655
671
  normal_sign: float | None = None,
656
672
  normal_source: str = "master",
657
- ):
673
+ ) -> np.ndarray:
658
674
  u_master, u_slave = self._split_fields(u)
659
675
  if normal_sign is None:
660
676
  normal_sign = self.normal_sign
@@ -691,16 +707,16 @@ class ContactSurfaceSpace:
691
707
 
692
708
  def assemble_jacobian(
693
709
  self,
694
- res_form,
695
- u,
696
- params,
710
+ res_form: MixedSurfaceResidualForm,
711
+ u: Mapping[str, npt.ArrayLike] | Sequence[npt.ArrayLike],
712
+ params: "WeakParams",
697
713
  *,
698
714
  normal_sign: float | None = None,
699
715
  normal_source: str = "master",
700
716
  sparse: bool = False,
701
717
  backend: str | None = None,
702
718
  batch_jac: bool | None = None,
703
- ):
719
+ ) -> ContactJacobianReturn:
704
720
  u_master, u_slave = self._split_fields(u)
705
721
  if normal_sign is None:
706
722
  normal_sign = self.normal_sign
@@ -745,14 +761,14 @@ class ContactSurfaceSpace:
745
761
 
746
762
  def assemble_bilinear(
747
763
  self,
748
- bilin,
749
- u_master,
750
- u_slave=None,
751
- params=None,
764
+ bilin: Callable[..., Any],
765
+ u_master: Mapping[str, npt.ArrayLike] | Sequence[npt.ArrayLike] | npt.ArrayLike,
766
+ u_slave: npt.ArrayLike | None = None,
767
+ params: "WeakParams" | None = None,
752
768
  *,
753
769
  sparse: bool = False,
754
770
  normal_source: str = "master",
755
- ):
771
+ ) -> ContactJacobianReturn:
756
772
  """
757
773
  Assemble a mixed surface bilinear form with signature (v1, v2, u1, u2, params).
758
774
 
fluxfem/mesh/io.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  import numpy as np
4
4
  import jax
5
5
  import jax.numpy as jnp
6
+ from typing import Optional
6
7
 
7
8
  from .dtypes import NP_INDEX_DTYPE
8
9
 
@@ -12,7 +13,7 @@ try:
12
13
  import meshio
13
14
  except Exception as e: # pragma: no cover
14
15
  meshio = None
15
- meshio_import_error = e
16
+ meshio_import_error: Optional[Exception] = e
16
17
  else:
17
18
  meshio_import_error = None
18
19
 
@@ -36,7 +37,7 @@ def load_gmsh_mesh(path: str):
36
37
  msh = meshio.read(path)
37
38
  coords = np.asarray(msh.points[:, :3], dtype=DTYPE)
38
39
 
39
- mesh = None
40
+ mesh: HexMesh | TetMesh | None = None
40
41
  if "hexahedron" in msh.cells_dict:
41
42
  conn = np.asarray(msh.cells_dict["hexahedron"], dtype=NP_INDEX_DTYPE)
42
43
  mesh = HexMesh(jnp.asarray(coords), jnp.asarray(conn))