fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. fluxfem/__init__.py +69 -13
  2. fluxfem/core/__init__.py +140 -53
  3. fluxfem/core/assembly.py +691 -97
  4. fluxfem/core/basis.py +75 -54
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/dtypes.py +9 -1
  7. fluxfem/core/forms.py +10 -0
  8. fluxfem/core/mixed_assembly.py +263 -0
  9. fluxfem/core/mixed_space.py +382 -0
  10. fluxfem/core/mixed_weakform.py +97 -0
  11. fluxfem/core/solver.py +2 -0
  12. fluxfem/core/space.py +315 -30
  13. fluxfem/core/weakform.py +821 -42
  14. fluxfem/helpers_wf.py +49 -0
  15. fluxfem/mesh/__init__.py +54 -2
  16. fluxfem/mesh/base.py +318 -9
  17. fluxfem/mesh/contact.py +841 -0
  18. fluxfem/mesh/dtypes.py +12 -0
  19. fluxfem/mesh/hex.py +17 -16
  20. fluxfem/mesh/io.py +9 -6
  21. fluxfem/mesh/mortar.py +3970 -0
  22. fluxfem/mesh/supermesh.py +318 -0
  23. fluxfem/mesh/surface.py +104 -26
  24. fluxfem/mesh/tet.py +16 -7
  25. fluxfem/physics/diffusion.py +3 -0
  26. fluxfem/physics/elasticity/hyperelastic.py +35 -3
  27. fluxfem/physics/elasticity/linear.py +22 -4
  28. fluxfem/physics/elasticity/stress.py +9 -5
  29. fluxfem/physics/operators.py +12 -5
  30. fluxfem/physics/postprocess.py +29 -3
  31. fluxfem/solver/__init__.py +47 -2
  32. fluxfem/solver/bc.py +38 -2
  33. fluxfem/solver/block_matrix.py +284 -0
  34. fluxfem/solver/block_system.py +477 -0
  35. fluxfem/solver/cg.py +150 -55
  36. fluxfem/solver/dirichlet.py +358 -5
  37. fluxfem/solver/history.py +15 -3
  38. fluxfem/solver/newton.py +260 -70
  39. fluxfem/solver/petsc.py +445 -0
  40. fluxfem/solver/preconditioner.py +109 -0
  41. fluxfem/solver/result.py +18 -0
  42. fluxfem/solver/solve_runner.py +208 -23
  43. fluxfem/solver/solver.py +35 -12
  44. fluxfem/solver/sparse.py +149 -15
  45. fluxfem/tools/jit.py +19 -7
  46. fluxfem/tools/timer.py +14 -12
  47. fluxfem/tools/visualizer.py +16 -4
  48. fluxfem-0.2.1.dist-info/METADATA +314 -0
  49. fluxfem-0.2.1.dist-info/RECORD +59 -0
  50. fluxfem-0.1.4.dist-info/METADATA +0 -127
  51. fluxfem-0.1.4.dist-info/RECORD +0 -48
  52. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  53. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/core/weakform.py CHANGED
@@ -1,8 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Any, Callable, Iterator, Literal, get_args
4
+ from typing import Any, Callable, Iterator, Literal, Mapping, TypeAlias, cast, get_args
5
5
  import inspect
6
+ from dataclasses import dataclass
7
+ from functools import update_wrapper
6
8
 
7
9
  import numpy as np
8
10
 
@@ -10,7 +12,7 @@ import jax.numpy as jnp
10
12
  import jax
11
13
 
12
14
  from ..physics import operators as _ops
13
- from .context_types import FormFieldLike, ParamsLike, SurfaceContext, UElement, VolumeContext
15
+ from .context_types import ArrayLike, FormFieldLike, ParamsLike, SurfaceContext, UElement, VolumeContext, WeakFormContext
14
16
 
15
17
 
16
18
  OpName = Literal[
@@ -45,6 +47,7 @@ OpName = Literal[
45
47
  "einsum",
46
48
  ]
47
49
 
50
+
48
51
  # Use OpName as the single source of truth for valid ops.
49
52
  _OP_NAMES: frozenset[str] = frozenset(get_args(OpName))
50
53
 
@@ -396,6 +399,28 @@ class Params:
396
399
  return cls(**dict(zip(keys, values)))
397
400
 
398
401
 
402
+ class _ZeroField:
403
+ """Field-like object that evaluates to zeros with the same shape."""
404
+
405
+ def __init__(self, base):
406
+ self.N = jnp.zeros_like(base.N)
407
+ self.gradN = None if getattr(base, "gradN", None) is None else jnp.zeros_like(base.gradN)
408
+ self.detJ = getattr(base, "detJ", None)
409
+ self.value_dim = int(getattr(base, "value_dim", 1))
410
+ self.basis = getattr(base, "basis", None)
411
+
412
+
413
+ class _ZeroFieldNp:
414
+ """Numpy variant of a zero-valued field (for numpy backend evaluation)."""
415
+
416
+ def __init__(self, base):
417
+ self.N = np.zeros_like(base.N)
418
+ self.gradN = None if getattr(base, "gradN", None) is None else np.zeros_like(base.gradN)
419
+ self.detJ = getattr(base, "detJ", None)
420
+ self.value_dim = int(getattr(base, "value_dim", 1))
421
+ self.basis = getattr(base, "basis", None)
422
+
423
+
399
424
  def trial_ref(name: str | None = "u") -> FieldRef:
400
425
  """Create a symbolic trial field reference."""
401
426
  return FieldRef(role="trial", name=name)
@@ -416,12 +441,105 @@ def param_ref() -> ParamRef:
416
441
  return ParamRef()
417
442
 
418
443
 
444
+ def zero_ref(name: str) -> FieldRef:
445
+ """Create a zero-valued field reference (shape derived from context)."""
446
+ return FieldRef("zero", name)
447
+
448
+
419
449
  def _eval_field(
420
- obj: Any,
421
- ctx: VolumeContext | SurfaceContext,
450
+ obj: FieldRef,
451
+ ctx: WeakFormContext,
422
452
  params: ParamsLike,
423
453
  ) -> FormFieldLike:
424
454
  if isinstance(obj, FieldRef):
455
+ if obj.role == "zero":
456
+ if obj.name is None:
457
+ raise ValueError("zero_ref requires a named field.")
458
+ base = None
459
+ test_fields = getattr(ctx, "test_fields", None)
460
+ if test_fields is not None and obj.name in test_fields:
461
+ base = test_fields[obj.name]
462
+ trial_fields = getattr(ctx, "trial_fields", None)
463
+ if base is None and trial_fields is not None and obj.name in trial_fields:
464
+ base = trial_fields[obj.name]
465
+ fields = getattr(ctx, "fields", None)
466
+ if base is None and fields is not None and obj.name in fields:
467
+ group = fields[obj.name]
468
+ if hasattr(group, "test"):
469
+ base = group.test
470
+ elif hasattr(group, "trial"):
471
+ base = group.trial
472
+ if base is None:
473
+ raise ValueError(f"zero_ref could not resolve field '{obj.name}'.")
474
+ return _ZeroField(base)
475
+ if obj.name is not None:
476
+ mixed_fields = getattr(ctx, "fields", None)
477
+ if mixed_fields is not None and obj.name in mixed_fields:
478
+ group = mixed_fields[obj.name]
479
+ if hasattr(group, "trial") and obj.role == "trial":
480
+ return group.trial
481
+ if hasattr(group, "test") and obj.role == "test":
482
+ return group.test
483
+ if hasattr(group, "unknown") and obj.role == "unknown":
484
+ return group.unknown if group.unknown is not None else group.trial
485
+ trial_fields = getattr(ctx, "trial_fields", None)
486
+ if obj.role == "trial" and trial_fields is not None and obj.name in trial_fields:
487
+ return trial_fields[obj.name]
488
+ test_fields = getattr(ctx, "test_fields", None)
489
+ if obj.role == "test" and test_fields is not None and obj.name in test_fields:
490
+ return test_fields[obj.name]
491
+ unknown_fields = getattr(ctx, "unknown_fields", None)
492
+ if obj.role == "unknown" and unknown_fields is not None and obj.name in unknown_fields:
493
+ return unknown_fields[obj.name]
494
+ fields = getattr(ctx, "fields", None)
495
+ if fields is not None and obj.name in fields:
496
+ group = fields[obj.name]
497
+ if isinstance(group, dict):
498
+ if obj.role in group:
499
+ return group[obj.role]
500
+ if "field" in group:
501
+ return group["field"]
502
+ return group
503
+ if obj.role == "trial":
504
+ return ctx.trial
505
+ if obj.role == "test":
506
+ if hasattr(ctx, "test"):
507
+ return ctx.test
508
+ if hasattr(ctx, "v"):
509
+ return ctx.v
510
+ raise ValueError("Surface context is missing test field.")
511
+ if obj.role == "unknown":
512
+ return getattr(ctx, "unknown", ctx.trial)
513
+ raise ValueError(f"Unknown field role: {obj.role}")
514
+ raise TypeError("Expected a field reference for this operator.")
515
+
516
+
517
+ def _eval_field_np(
518
+ obj: FieldRef,
519
+ ctx: WeakFormContext,
520
+ params: ParamsLike,
521
+ ) -> FormFieldLike:
522
+ if isinstance(obj, FieldRef):
523
+ if obj.role == "zero":
524
+ if obj.name is None:
525
+ raise ValueError("zero_ref requires a named field.")
526
+ base = None
527
+ test_fields = getattr(ctx, "test_fields", None)
528
+ if test_fields is not None and obj.name in test_fields:
529
+ base = test_fields[obj.name]
530
+ trial_fields = getattr(ctx, "trial_fields", None)
531
+ if base is None and trial_fields is not None and obj.name in trial_fields:
532
+ base = trial_fields[obj.name]
533
+ fields = getattr(ctx, "fields", None)
534
+ if base is None and fields is not None and obj.name in fields:
535
+ group = fields[obj.name]
536
+ if hasattr(group, "test"):
537
+ base = group.test
538
+ elif hasattr(group, "trial"):
539
+ base = group.trial
540
+ if base is None:
541
+ raise ValueError(f"zero_ref could not resolve field '{obj.name}'.")
542
+ return _ZeroFieldNp(base)
425
543
  if obj.name is not None:
426
544
  mixed_fields = getattr(ctx, "fields", None)
427
545
  if mixed_fields is not None and obj.name in mixed_fields:
@@ -432,15 +550,15 @@ def _eval_field(
432
550
  return group.test
433
551
  if hasattr(group, "unknown") and obj.role == "unknown":
434
552
  return group.unknown if group.unknown is not None else group.trial
435
- if obj.role == "trial" and getattr(ctx, "trial_fields", None) is not None:
436
- if obj.name in ctx.trial_fields:
437
- return ctx.trial_fields[obj.name]
438
- if obj.role == "test" and getattr(ctx, "test_fields", None) is not None:
439
- if obj.name in ctx.test_fields:
440
- return ctx.test_fields[obj.name]
441
- if obj.role == "unknown" and getattr(ctx, "unknown_fields", None) is not None:
442
- if obj.name in ctx.unknown_fields:
443
- return ctx.unknown_fields[obj.name]
553
+ trial_fields = getattr(ctx, "trial_fields", None)
554
+ if obj.role == "trial" and trial_fields is not None and obj.name in trial_fields:
555
+ return trial_fields[obj.name]
556
+ test_fields = getattr(ctx, "test_fields", None)
557
+ if obj.role == "test" and test_fields is not None and obj.name in test_fields:
558
+ return test_fields[obj.name]
559
+ unknown_fields = getattr(ctx, "unknown_fields", None)
560
+ if obj.role == "unknown" and unknown_fields is not None and obj.name in unknown_fields:
561
+ return unknown_fields[obj.name]
444
562
  fields = getattr(ctx, "fields", None)
445
563
  if fields is not None and obj.name in fields:
446
564
  group = fields[obj.name]
@@ -477,7 +595,7 @@ def _eval_field(
477
595
  # return obj
478
596
 
479
597
 
480
- def _extract_unknown_elem(field_ref: FieldRef, u_elem: UElement):
598
+ def _extract_unknown_elem(field_ref: FieldRef, u_elem: UElement) -> ArrayLike:
481
599
  if u_elem is None:
482
600
  raise ValueError("u_elem is required to evaluate unknown field value.")
483
601
  if isinstance(u_elem, dict):
@@ -489,8 +607,9 @@ def _extract_unknown_elem(field_ref: FieldRef, u_elem: UElement):
489
607
 
490
608
 
491
609
  def _basis_outer(test: FieldRef, trial: FieldRef, ctx, params):
492
- v_field = _eval_field(test, ctx, params)
493
- u_field = _eval_field(trial, ctx, params)
610
+ ctx_w = cast(WeakFormContext, ctx)
611
+ v_field = _eval_field(test, ctx_w, params)
612
+ u_field = _eval_field(trial, ctx_w, params)
494
613
  if getattr(v_field, "value_dim", 1) != 1 or getattr(u_field, "value_dim", 1) != 1:
495
614
  raise ValueError(
496
615
  "inner/outer is only defined for scalar fields; use dot/action/einsum for vector/tensor cases."
@@ -498,6 +617,17 @@ def _basis_outer(test: FieldRef, trial: FieldRef, ctx, params):
498
617
  return jnp.einsum("qi,qj->qij", v_field.N, u_field.N)
499
618
 
500
619
 
620
+ def _basis_outer_np(test: FieldRef, trial: FieldRef, ctx, params):
621
+ ctx_w = cast(WeakFormContext, ctx)
622
+ v_field = _eval_field_np(test, ctx_w, params)
623
+ u_field = _eval_field_np(trial, ctx_w, params)
624
+ if getattr(v_field, "value_dim", 1) != 1 or getattr(u_field, "value_dim", 1) != 1:
625
+ raise ValueError(
626
+ "inner/outer is only defined for scalar fields; use dot/action/einsum for vector/tensor cases."
627
+ )
628
+ return np.einsum("qi,qj->qij", v_field.N, u_field.N)
629
+
630
+
501
631
  def _eval_unknown_value(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
502
632
  u_local = _extract_unknown_elem(field_ref, u_elem)
503
633
  value_dim = int(getattr(field, "value_dim", 1))
@@ -507,6 +637,21 @@ def _eval_unknown_value(field_ref: FieldRef, field: FormFieldLike, u_elem: UElem
507
637
  return jnp.einsum("qa,ai->qi", field.N, u_nodes)
508
638
 
509
639
 
640
+ def _eval_unknown_value_np(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
641
+ u_local = _extract_unknown_elem(field_ref, u_elem)
642
+ value_dim = int(getattr(field, "value_dim", 1))
643
+ u_arr = np.asarray(u_local)
644
+ if value_dim == 1:
645
+ if u_arr.ndim == 2:
646
+ return np.einsum("qa,ab->qb", field.N, u_arr)
647
+ return np.einsum("qa,a->q", field.N, u_arr)
648
+ if u_arr.ndim == 2:
649
+ u_nodes = u_arr.reshape((-1, value_dim, u_arr.shape[1]))
650
+ return np.einsum("qa,aib->qib", field.N, u_nodes)
651
+ u_nodes = u_arr.reshape((-1, value_dim))
652
+ return np.einsum("qa,ai->qi", field.N, u_nodes)
653
+
654
+
510
655
  def _eval_unknown_grad(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
511
656
  u_local = _extract_unknown_elem(field_ref, u_elem)
512
657
  if u_local is None:
@@ -518,6 +663,88 @@ def _eval_unknown_grad(field_ref: FieldRef, field: FormFieldLike, u_elem: UEleme
518
663
  return jnp.einsum("qaj,ai->qij", field.gradN, u_nodes)
519
664
 
520
665
 
666
+ def _eval_unknown_grad_np(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
667
+ u_local = _extract_unknown_elem(field_ref, u_elem)
668
+ if u_local is None:
669
+ raise ValueError("u_elem is required to evaluate unknown field gradient.")
670
+ value_dim = int(getattr(field, "value_dim", 1))
671
+ u_arr = np.asarray(u_local)
672
+ if value_dim == 1:
673
+ if u_arr.ndim == 2:
674
+ return np.einsum("qaj,ab->qjb", field.gradN, u_arr)
675
+ return np.einsum("qaj,a->qj", field.gradN, u_arr)
676
+ if u_arr.ndim == 2:
677
+ u_nodes = u_arr.reshape((-1, value_dim, u_arr.shape[1]))
678
+ return np.einsum("qaj,aib->qijb", field.gradN, u_nodes)
679
+ u_nodes = u_arr.reshape((-1, value_dim))
680
+ return np.einsum("qaj,ai->qij", field.gradN, u_nodes)
681
+
682
+
683
+ def _vector_load_form_np(field: FormFieldLike, load_vec: ArrayLike) -> np.ndarray:
684
+ lv = np.asarray(load_vec)
685
+ if lv.ndim == 1:
686
+ lv = lv[None, :]
687
+ elif lv.ndim not in (2, 3):
688
+ raise ValueError("load_vec must be shape (dim,), (n_q, dim), or (n_q, dim, batch)")
689
+ if lv.shape[0] == 1:
690
+ lv = np.broadcast_to(lv, (field.N.shape[0], lv.shape[1]))
691
+ elif lv.shape[0] != field.N.shape[0]:
692
+ raise ValueError("load_vec must be shape (dim,) or (n_q, dim)")
693
+ if lv.ndim == 3:
694
+ load = field.N[..., None, None] * lv[:, None, :, :]
695
+ return load.reshape(load.shape[0], -1, load.shape[-1])
696
+ load = field.N[..., None] * lv[:, None, :]
697
+ return load.reshape(load.shape[0], -1)
698
+
699
+
700
+ def _sym_grad_np(field) -> np.ndarray:
701
+ gradN = np.asarray(field.gradN)
702
+ dofs = int(getattr(field.basis, "dofs_per_node", 3))
703
+ n_q, n_nodes, _ = gradN.shape
704
+ n_dofs = dofs * n_nodes
705
+ B = np.zeros((n_q, 6, n_dofs), dtype=gradN.dtype)
706
+ for a in range(n_nodes):
707
+ col = dofs * a
708
+ dNdx = gradN[:, a, 0]
709
+ dNdy = gradN[:, a, 1]
710
+ dNdz = gradN[:, a, 2]
711
+ B[:, 0, col + 0] = dNdx
712
+ B[:, 1, col + 1] = dNdy
713
+ B[:, 2, col + 2] = dNdz
714
+ B[:, 3, col + 0] = dNdy
715
+ B[:, 3, col + 1] = dNdx
716
+ B[:, 4, col + 1] = dNdz
717
+ B[:, 4, col + 2] = dNdy
718
+ B[:, 5, col + 0] = dNdz
719
+ B[:, 5, col + 2] = dNdx
720
+ return B
721
+
722
+
723
+ def _sym_grad_u_np(field: FormFieldLike, u_elem: ArrayLike) -> np.ndarray:
724
+ B = _sym_grad_np(field)
725
+ u_arr = np.asarray(u_elem)
726
+ if u_arr.ndim == 2:
727
+ return np.einsum("qik,kb->qib", B, u_arr)
728
+ return np.einsum("qik,k->qi", B, u_arr)
729
+
730
+
731
+ def _ddot_np(a: ArrayLike, b: ArrayLike, c: ArrayLike | None = None) -> np.ndarray:
732
+ if c is None:
733
+ return np.einsum("...ij,...ij->...", a, b)
734
+ a_t = np.swapaxes(a, -1, -2)
735
+ return np.einsum("...ik,kl,...lm->...im", a_t, b, c)
736
+
737
+
738
+ def _dot_np(a: FormFieldLike | ArrayLike, b: ArrayLike) -> np.ndarray:
739
+ if hasattr(a, "N") and getattr(a, "value_dim", None) is not None:
740
+ return _vector_load_form_np(cast(FormFieldLike, a), b)
741
+ return np.matmul(a, b)
742
+
743
+
744
+ def _transpose_last2_np(a: ArrayLike) -> np.ndarray:
745
+ return np.swapaxes(a, -1, -2)
746
+
747
+
521
748
  def grad(field) -> Expr:
522
749
  """Return basis gradients for a scalar or vector FormField."""
523
750
  return Expr("grad", _as_expr(field))
@@ -649,6 +876,62 @@ def _call_user(fn, *args, params):
649
876
  return fn(*args)
650
877
 
651
878
 
879
+ @dataclass(frozen=True)
880
+ class KernelSpec:
881
+ kind: str
882
+ domain: str
883
+
884
+
885
+ class TaggedKernel:
886
+ def __init__(self, fn, spec: KernelSpec):
887
+ self._fn = fn
888
+ self._ff_spec = spec
889
+ self._ff_kind = spec.kind
890
+ self._ff_domain = spec.domain
891
+ update_wrapper(self, fn)
892
+ self.__wrapped__ = fn
893
+
894
+ def __call__(self, *args, **kwargs):
895
+ return self._fn(*args, **kwargs)
896
+
897
+ def __repr__(self) -> str:
898
+ return f"TaggedKernel(kind={self._ff_kind!r}, domain={self._ff_domain!r})"
899
+
900
+ @property
901
+ def spec(self) -> KernelSpec:
902
+ return self._ff_spec
903
+
904
+ @property
905
+ def kind(self) -> str:
906
+ return self._ff_kind
907
+
908
+ @property
909
+ def domain(self) -> str:
910
+ return self._ff_domain
911
+
912
+ def __hash__(self) -> int:
913
+ return hash(self._fn)
914
+
915
+
916
+ def _tag_form(fn, *, kind: str, domain: str):
917
+ spec = KernelSpec(kind=kind, domain=domain)
918
+ fn._ff_spec = spec
919
+ fn._ff_kind = kind
920
+ fn._ff_domain = domain
921
+ return fn
922
+
923
+
924
+ def kernel(*, kind: str, domain: str = "volume"):
925
+ """
926
+ Decorator to tag raw kernels with kind/domain metadata for assembly inference.
927
+ """
928
+ def _deco(fn):
929
+ spec = KernelSpec(kind=kind, domain=domain)
930
+ return TaggedKernel(fn, spec)
931
+
932
+ return _deco
933
+
934
+
652
935
  def compile_bilinear(fn):
653
936
  """get_compiled a bilinear weak form (u, v, params) -> Expr into a kernel."""
654
937
  if isinstance(fn, Expr):
@@ -658,9 +941,10 @@ def compile_bilinear(fn):
658
941
  v = test_ref()
659
942
  p = param_ref()
660
943
  expr = _call_user(fn, u, v, params=p)
661
- expr = _as_expr(expr)
662
- if not isinstance(expr, Expr):
944
+ expr_raw = _as_expr(expr)
945
+ if not isinstance(expr_raw, Expr):
663
946
  raise TypeError("Bilinear form must return an Expr.")
947
+ expr = cast(Expr, expr_raw)
664
948
 
665
949
  volume_count = _count_op(expr, "volume_measure")
666
950
  surface_count = _count_op(expr, "surface_measure")
@@ -676,8 +960,8 @@ def compile_bilinear(fn):
676
960
  def _form(ctx, params):
677
961
  return eval_with_plan(plan, ctx, params)
678
962
 
679
- _form._includes_measure = True
680
- return _form
963
+ _form._includes_measure = True # type: ignore[attr-defined]
964
+ return _tag_form(_form, kind="bilinear", domain="volume")
681
965
 
682
966
 
683
967
  def compile_linear(fn):
@@ -688,9 +972,10 @@ def compile_linear(fn):
688
972
  v = test_ref()
689
973
  p = param_ref()
690
974
  expr = _call_user(fn, v, params=p)
691
- expr = _as_expr(expr)
692
- if not isinstance(expr, Expr):
975
+ expr_raw = _as_expr(expr)
976
+ if not isinstance(expr_raw, Expr):
693
977
  raise TypeError("Linear form must return an Expr.")
978
+ expr = cast(Expr, expr_raw)
694
979
 
695
980
  volume_count = _count_op(expr, "volume_measure")
696
981
  surface_count = _count_op(expr, "surface_measure")
@@ -706,8 +991,8 @@ def compile_linear(fn):
706
991
  def _form(ctx, params):
707
992
  return eval_with_plan(plan, ctx, params)
708
993
 
709
- _form._includes_measure = True
710
- return _form
994
+ _form._includes_measure = True # type: ignore[attr-defined]
995
+ return _tag_form(_form, kind="linear", domain="volume")
711
996
 
712
997
 
713
998
  def _expr_contains(expr: Expr, op: str) -> bool:
@@ -795,6 +1080,7 @@ def eval_with_plan(
795
1080
  nodes = plan.nodes
796
1081
  index = plan.index
797
1082
  vals: list[Any] = [None] * len(nodes)
1083
+ ctx_w = cast(WeakFormContext, ctx)
798
1084
 
799
1085
  def get(obj):
800
1086
  if isinstance(obj, Expr):
@@ -825,7 +1111,7 @@ def eval_with_plan(
825
1111
  if op == "value":
826
1112
  ref = args[0]
827
1113
  assert isinstance(ref, FieldRef)
828
- field = _eval_field(ref, ctx, params)
1114
+ field = _eval_field(ref, ctx_w, params)
829
1115
  if ref.role == "unknown":
830
1116
  vals[i] = _eval_unknown_value(ref, field, u_elem)
831
1117
  else:
@@ -834,7 +1120,7 @@ def eval_with_plan(
834
1120
  if op == "grad":
835
1121
  ref = args[0]
836
1122
  assert isinstance(ref, FieldRef)
837
- field = _eval_field(ref, ctx, params)
1123
+ field = _eval_field(ref, ctx_w, params)
838
1124
  if ref.role == "unknown":
839
1125
  vals[i] = _eval_unknown_grad(ref, field, u_elem)
840
1126
  else:
@@ -879,7 +1165,7 @@ def eval_with_plan(
879
1165
  if op == "sym_grad":
880
1166
  ref = args[0]
881
1167
  assert isinstance(ref, FieldRef)
882
- field = _eval_field(ref, ctx, params)
1168
+ field = _eval_field(ref, ctx_w, params)
883
1169
  if ref.role == "unknown":
884
1170
  if u_elem is None:
885
1171
  raise ValueError("u_elem is required to evaluate unknown sym_grad.")
@@ -943,7 +1229,7 @@ def eval_with_plan(
943
1229
  if op == "dot":
944
1230
  ref = args[0]
945
1231
  if isinstance(ref, FieldRef):
946
- vals[i] = _ops.dot(_eval_field(ref, ctx, params), get(args[1]))
1232
+ vals[i] = _ops.dot(_eval_field(ref, ctx_w, params), get(args[1]))
947
1233
  else:
948
1234
  a = get(args[0])
949
1235
  b = get(args[1])
@@ -961,7 +1247,7 @@ def eval_with_plan(
961
1247
  if op == "sdot":
962
1248
  ref = args[0]
963
1249
  if isinstance(ref, FieldRef):
964
- vals[i] = _ops.dot(_eval_field(ref, ctx, params), get(args[1]))
1250
+ vals[i] = _ops.dot(_eval_field(ref, ctx_w, params), get(args[1]))
965
1251
  else:
966
1252
  a = get(args[0])
967
1253
  b = get(args[1])
@@ -1004,7 +1290,7 @@ def eval_with_plan(
1004
1290
  assert isinstance(ref, FieldRef)
1005
1291
  if isinstance(args[1], FieldRef):
1006
1292
  raise ValueError("action expects a scalar expression; use u.val for unknowns.")
1007
- v_field = _eval_field(ref, ctx, params)
1293
+ v_field = _eval_field(ref, ctx_w, params)
1008
1294
  s = get(args[1])
1009
1295
  value_dim = int(getattr(v_field, "value_dim", 1))
1010
1296
  # action maps a test field with a scalar/vector expression into nodal space.
@@ -1022,7 +1308,7 @@ def eval_with_plan(
1022
1308
  if op == "gaction":
1023
1309
  ref = args[0]
1024
1310
  assert isinstance(ref, FieldRef)
1025
- v_field = _eval_field(ref, ctx, params)
1311
+ v_field = _eval_field(ref, ctx_w, params)
1026
1312
  q = get(args[1])
1027
1313
  # gaction maps a flux-like expression to nodal space via test gradients.
1028
1314
  if v_field.gradN.ndim != 3:
@@ -1043,7 +1329,10 @@ def eval_with_plan(
1043
1329
  continue
1044
1330
  if op == "einsum":
1045
1331
  subscripts = args[0]
1046
- operands = [get(arg) for arg in args[1:]]
1332
+ operands = [
1333
+ (jnp.asarray(arg) if isinstance(arg, tuple) else arg)
1334
+ for arg in (get(arg) for arg in args[1:])
1335
+ ]
1047
1336
  vals[i] = jnp.einsum(subscripts, *operands)
1048
1337
  continue
1049
1338
 
@@ -1052,6 +1341,294 @@ def eval_with_plan(
1052
1341
  return vals[index[id(plan.expr)]]
1053
1342
 
1054
1343
 
1344
+ def eval_with_plan_numpy(
1345
+ plan: EvalPlan,
1346
+ ctx: VolumeContext | SurfaceContext,
1347
+ params: ParamsLike,
1348
+ u_elem: UElement | None = None,
1349
+ ):
1350
+ nodes = plan.nodes
1351
+ index = plan.index
1352
+ vals: list[Any] = [None] * len(nodes)
1353
+ ctx_w = cast(WeakFormContext, ctx)
1354
+
1355
+ def get(obj):
1356
+ if isinstance(obj, Expr):
1357
+ return vals[index[id(obj)]]
1358
+ if isinstance(obj, FieldRef):
1359
+ raise TypeError(
1360
+ "FieldRef must be wrapped with .val/.grad/.sym_grad or used as the first arg of dot/action."
1361
+ )
1362
+ if isinstance(obj, ParamRef):
1363
+ return params
1364
+ return obj
1365
+
1366
+ for i, node in enumerate(nodes):
1367
+ op = node.op
1368
+ args = node.args
1369
+
1370
+ if op == "lit":
1371
+ vals[i] = args[0]
1372
+ continue
1373
+ if op == "getattr":
1374
+ base = get(args[0])
1375
+ name = args[1]
1376
+ if isinstance(base, dict):
1377
+ vals[i] = base[name]
1378
+ else:
1379
+ vals[i] = getattr(base, name)
1380
+ continue
1381
+ if op == "value":
1382
+ ref = args[0]
1383
+ assert isinstance(ref, FieldRef)
1384
+ field = _eval_field_np(ref, ctx_w, params)
1385
+ if ref.role == "unknown":
1386
+ vals[i] = _eval_unknown_value_np(ref, field, u_elem)
1387
+ else:
1388
+ vals[i] = field.N
1389
+ continue
1390
+ if op == "grad":
1391
+ ref = args[0]
1392
+ assert isinstance(ref, FieldRef)
1393
+ field = _eval_field_np(ref, ctx_w, params)
1394
+ if ref.role == "unknown":
1395
+ vals[i] = _eval_unknown_grad_np(ref, field, u_elem)
1396
+ else:
1397
+ vals[i] = field.gradN
1398
+ continue
1399
+ if op == "pow":
1400
+ base = get(args[0])
1401
+ exp = get(args[1])
1402
+ vals[i] = base**exp
1403
+ continue
1404
+ if op == "eye":
1405
+ vals[i] = np.eye(int(args[0]))
1406
+ continue
1407
+ if op == "det":
1408
+ vals[i] = np.linalg.det(get(args[0]))
1409
+ continue
1410
+ if op == "inv":
1411
+ vals[i] = np.linalg.inv(get(args[0]))
1412
+ continue
1413
+ if op == "transpose":
1414
+ vals[i] = np.swapaxes(get(args[0]), -1, -2)
1415
+ continue
1416
+ if op == "log":
1417
+ vals[i] = np.log(get(args[0]))
1418
+ continue
1419
+ if op == "surface_normal":
1420
+ normal = getattr(ctx, "normal", None)
1421
+ if normal is None:
1422
+ raise ValueError("surface normal is not available in context")
1423
+ vals[i] = normal
1424
+ continue
1425
+ if op == "surface_measure":
1426
+ if not hasattr(ctx, "w") or not hasattr(ctx, "detJ"):
1427
+ raise TypeError("surface measure requires SurfaceContext.")
1428
+ vals[i] = ctx.w * ctx.detJ
1429
+ continue
1430
+ if op == "volume_measure":
1431
+ if not hasattr(ctx, "w") or not hasattr(ctx, "test"):
1432
+ raise TypeError("volume measure requires VolumeContext.")
1433
+ vals[i] = ctx.w * ctx.test.detJ
1434
+ continue
1435
+ if op == "sym_grad":
1436
+ ref = args[0]
1437
+ assert isinstance(ref, FieldRef)
1438
+ field = _eval_field_np(ref, ctx_w, params)
1439
+ if ref.role == "unknown":
1440
+ if u_elem is None:
1441
+ raise ValueError("u_elem is required to evaluate unknown sym_grad.")
1442
+ u_local = _extract_unknown_elem(ref, u_elem)
1443
+ vals[i] = _sym_grad_u_np(field, u_local)
1444
+ else:
1445
+ vals[i] = _sym_grad_np(field)
1446
+ continue
1447
+ if op == "outer":
1448
+ a, b = args
1449
+ if not isinstance(a, FieldRef) or not isinstance(b, FieldRef):
1450
+ raise TypeError("outer expects FieldRef operands.")
1451
+ test, trial = a, b
1452
+ vals[i] = _basis_outer_np(test, trial, ctx, params)
1453
+ continue
1454
+ if op == "add":
1455
+ vals[i] = get(args[0]) + get(args[1])
1456
+ continue
1457
+ if op == "sub":
1458
+ vals[i] = get(args[0]) - get(args[1])
1459
+ continue
1460
+ if op == "mul":
1461
+ a = get(args[0])
1462
+ b = get(args[1])
1463
+ if hasattr(a, "ndim") and hasattr(b, "ndim"):
1464
+ if a.ndim == 1 and b.ndim == 2 and a.shape[0] == b.shape[0]:
1465
+ a = a[:, None]
1466
+ elif b.ndim == 1 and a.ndim == 2 and b.shape[0] == a.shape[0]:
1467
+ b = b[:, None]
1468
+ elif a.ndim >= 2 and b.ndim == 1 and a.shape[0] == b.shape[0]:
1469
+ b = b.reshape((b.shape[0],) + (1,) * (a.ndim - 1))
1470
+ elif b.ndim >= 2 and a.ndim == 1 and b.shape[0] == a.shape[0]:
1471
+ a = a.reshape((a.shape[0],) + (1,) * (b.ndim - 1))
1472
+ vals[i] = a * b
1473
+ continue
1474
+ if op == "matmul":
1475
+ a = get(args[0])
1476
+ b = get(args[1])
1477
+ if (
1478
+ hasattr(a, "ndim")
1479
+ and hasattr(b, "ndim")
1480
+ and a.ndim == 3
1481
+ and b.ndim == 3
1482
+ and a.shape[0] == b.shape[0]
1483
+ and a.shape[-1] == b.shape[-1]
1484
+ ):
1485
+ vals[i] = np.einsum("qia,qja->qij", a, b)
1486
+ else:
1487
+ raise TypeError(
1488
+ "Expr '@' (matmul) is FEM-specific; use matmul_std(a, b) for standard matmul."
1489
+ )
1490
+ continue
1491
+ if op == "matmul_std":
1492
+ a = get(args[0])
1493
+ b = get(args[1])
1494
+ vals[i] = np.matmul(a, b)
1495
+ continue
1496
+ if op == "neg":
1497
+ vals[i] = -get(args[0])
1498
+ continue
1499
+ if op == "dot":
1500
+ ref = args[0]
1501
+ if isinstance(ref, FieldRef):
1502
+ vals[i] = _dot_np(_eval_field_np(ref, ctx_w, params), get(args[1]))
1503
+ else:
1504
+ a = get(args[0])
1505
+ b = get(args[1])
1506
+ if (
1507
+ hasattr(a, "ndim")
1508
+ and hasattr(b, "ndim")
1509
+ and a.ndim >= 3
1510
+ and b.ndim >= 3
1511
+ and a.shape[0] == b.shape[0]
1512
+ and a.shape[1] == b.shape[1]
1513
+ ):
1514
+ vals[i] = np.einsum("qi...,qj...->qij...", a, b)
1515
+ else:
1516
+ vals[i] = np.matmul(a, b)
1517
+ continue
1518
+ if op == "sdot":
1519
+ ref = args[0]
1520
+ if isinstance(ref, FieldRef):
1521
+ vals[i] = _dot_np(_eval_field_np(ref, ctx_w, params), get(args[1]))
1522
+ else:
1523
+ a = get(args[0])
1524
+ b = get(args[1])
1525
+ if (
1526
+ hasattr(a, "ndim")
1527
+ and hasattr(b, "ndim")
1528
+ and a.ndim >= 3
1529
+ and b.ndim >= 3
1530
+ and a.shape[0] == b.shape[0]
1531
+ and a.shape[1] == b.shape[1]
1532
+ ):
1533
+ vals[i] = np.einsum("qi...,qj...->qij...", a, b)
1534
+ else:
1535
+ vals[i] = np.matmul(a, b)
1536
+ continue
1537
+ if op == "ddot":
1538
+ if len(args) == 2:
1539
+ a = get(args[0])
1540
+ b = get(args[1])
1541
+ if (
1542
+ hasattr(a, "ndim")
1543
+ and hasattr(b, "ndim")
1544
+ and a.ndim == 3
1545
+ and b.ndim == 3
1546
+ and a.shape[0] == b.shape[0]
1547
+ and a.shape[1] == b.shape[1]
1548
+ ):
1549
+ vals[i] = np.einsum("qik,qim->qkm", a, b)
1550
+ else:
1551
+ vals[i] = _ddot_np(a, b)
1552
+ else:
1553
+ vals[i] = _ddot_np(get(args[0]), get(args[1]), get(args[2]))
1554
+ continue
1555
+ if op == "inner":
1556
+ a = get(args[0])
1557
+ b = get(args[1])
1558
+ vals[i] = np.einsum("...i,...i->...", a, b)
1559
+ continue
1560
+ if op == "action":
1561
+ ref = args[0]
1562
+ assert isinstance(ref, FieldRef)
1563
+ if isinstance(args[1], FieldRef):
1564
+ raise ValueError("action expects a scalar expression; use u.val for unknowns.")
1565
+ v_field = _eval_field_np(ref, ctx_w, params)
1566
+ s = get(args[1])
1567
+ value_dim = int(getattr(v_field, "value_dim", 1))
1568
+ if value_dim == 1:
1569
+ if v_field.N.ndim != 2:
1570
+ raise ValueError("action expects scalar test field with N shape (q, ndofs).")
1571
+ if hasattr(s, "ndim") and s.ndim not in (0, 1):
1572
+ raise ValueError("action expects scalar s with shape (q,) or scalar.")
1573
+ vals[i] = v_field.N * s
1574
+ else:
1575
+ if hasattr(s, "ndim") and s.ndim not in (1, 2):
1576
+ raise ValueError("action expects vector s with shape (q, dim) or (dim,).")
1577
+ vals[i] = _dot_np(v_field, s)
1578
+ continue
1579
+ if op == "gaction":
1580
+ ref = args[0]
1581
+ assert isinstance(ref, FieldRef)
1582
+ v_field = _eval_field_np(ref, ctx_w, params)
1583
+ q = get(args[1])
1584
+ if v_field.gradN.ndim != 3:
1585
+ raise ValueError("gaction expects test gradient with shape (q, ndofs, dim).")
1586
+ if not hasattr(q, "ndim"):
1587
+ raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
1588
+ if q.ndim == 2:
1589
+ vals[i] = np.einsum("qaj,qj->qa", v_field.gradN, q)
1590
+ elif q.ndim == 3:
1591
+ if int(getattr(v_field, "value_dim", 1)) == 1:
1592
+ raise ValueError("gaction tensor flux requires vector test field.")
1593
+ vals[i] = np.einsum("qij,qaj->qai", q, v_field.gradN).reshape(q.shape[0], -1)
1594
+ else:
1595
+ raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
1596
+ continue
1597
+ if op == "transpose_last2":
1598
+ vals[i] = _transpose_last2_np(get(args[0]))
1599
+ continue
1600
+ if op == "einsum":
1601
+ subscripts = args[0]
1602
+ operands = [
1603
+ (np.asarray(arg) if isinstance(arg, tuple) else arg)
1604
+ for arg in (get(arg) for arg in args[1:])
1605
+ ]
1606
+ if "..." not in subscripts:
1607
+ has_extra = False
1608
+ parts = subscripts.split("->")
1609
+ in_terms = parts[0].split(",")
1610
+ out_term = parts[1] if len(parts) > 1 else None
1611
+ updated_terms = []
1612
+ for term, opnd in zip(in_terms, operands):
1613
+ if hasattr(opnd, "ndim") and opnd.ndim > len(term):
1614
+ has_extra = True
1615
+ updated_terms.append(term + "...")
1616
+ else:
1617
+ updated_terms.append(term)
1618
+ if has_extra:
1619
+ if out_term is not None:
1620
+ out_term = out_term + "..."
1621
+ subscripts = ",".join(updated_terms) + "->" + out_term
1622
+ else:
1623
+ subscripts = ",".join(updated_terms)
1624
+ vals[i] = np.einsum(subscripts, *operands)
1625
+ continue
1626
+
1627
+ raise ValueError(f"Unknown Expr op: {op}")
1628
+
1629
+ return vals[index[id(plan.expr)]]
1630
+
1631
+
1055
1632
  def compile_surface_linear(fn):
1056
1633
  """get_compiled a surface linear form into a kernel (ctx, params) -> ndarray."""
1057
1634
  if isinstance(fn, Expr):
@@ -1061,9 +1638,10 @@ def compile_surface_linear(fn):
1061
1638
  p = param_ref()
1062
1639
  expr = _call_user(fn, v, params=p)
1063
1640
 
1064
- expr = _as_expr(expr)
1065
- if not isinstance(expr, Expr):
1641
+ expr_raw = _as_expr(expr)
1642
+ if not isinstance(expr_raw, Expr):
1066
1643
  raise ValueError("Surface linear form must return an Expr; use ds() in the expression.")
1644
+ expr = cast(Expr, expr_raw)
1067
1645
 
1068
1646
  surface_count = _count_op(expr, "surface_measure")
1069
1647
  volume_count = _count_op(expr, "volume_measure")
@@ -1080,7 +1658,40 @@ def compile_surface_linear(fn):
1080
1658
  return eval_with_plan(plan, ctx, params)
1081
1659
 
1082
1660
  _form._includes_measure = True # type: ignore[attr-defined]
1083
- return _form
1661
+ return _tag_form(_form, kind="linear", domain="surface")
1662
+
1663
+
1664
def compile_surface_bilinear(fn):
    """Compile a surface bilinear form into a kernel ``(ctx, params) -> ndarray``.

    ``fn`` is either a ready-made ``Expr`` or a callable that is invoked as
    ``fn(u, v, params=p)`` with symbolic trial (``u``), test (``v``) and
    parameter (``p``) references.  The resulting expression must contain the
    surface measure ``ds()`` exactly once and must not contain ``dOmega()``.

    Returns:
        The compiled kernel, tagged with ``kind="bilinear"`` and
        ``domain="surface"`` via ``_tag_form``.

    Raises:
        ValueError: if ``fn`` does not yield an ``Expr``, or the expression
            omits ``ds()``, repeats it, or includes ``dOmega()``.
    """
    if isinstance(fn, Expr):
        candidate = fn
    else:
        # Build the symbolic references in the same order as the other
        # compile_* helpers: test, then trial, then params.
        test_sym = test_ref()
        trial_sym = trial_ref()
        param_sym = param_ref()
        candidate = _call_user(fn, trial_sym, test_sym, params=param_sym)

    as_expr = _as_expr(candidate)
    if not isinstance(as_expr, Expr):
        raise ValueError("Surface bilinear form must return an Expr; use ds() in the expression.")
    expr = cast(Expr, as_expr)

    n_surface = _count_op(expr, "surface_measure")
    n_volume = _count_op(expr, "volume_measure")
    if n_surface == 0:
        raise ValueError("Surface bilinear form must include ds().")
    if n_surface > 1:
        raise ValueError("Surface bilinear form must include ds() exactly once.")
    if n_volume > 0:
        raise ValueError("Surface bilinear form must not include dOmega().")

    compiled_plan = make_eval_plan(expr)

    def _form(ctx, params):
        # Evaluate the precompiled plan against the given context.
        return eval_with_plan(compiled_plan, ctx, params)

    # Flag indicating the measure (ds) is already part of the expression;
    # presumably consumed by the assembly layer.
    _form._includes_measure = True  # type: ignore[attr-defined]
    return _tag_form(_form, kind="bilinear", domain="surface")
1084
1695
 
1085
1696
 
1086
1697
  class LinearForm:
@@ -1144,9 +1755,10 @@ def compile_residual(fn):
1144
1755
  u = unknown_ref()
1145
1756
  p = param_ref()
1146
1757
  expr = _call_user(fn, v, u, params=p)
1147
- expr = _as_expr(expr)
1148
- if not isinstance(expr, Expr):
1758
+ expr_raw = _as_expr(expr)
1759
+ if not isinstance(expr_raw, Expr):
1149
1760
  raise TypeError("Residual form must return an Expr.")
1761
+ expr = cast(Expr, expr_raw)
1150
1762
 
1151
1763
  volume_count = _count_op(expr, "volume_measure")
1152
1764
  surface_count = _count_op(expr, "surface_measure")
@@ -1162,8 +1774,8 @@ def compile_residual(fn):
1162
1774
  def _form(ctx, u_elem, params):
1163
1775
  return eval_with_plan(plan, ctx, params, u_elem=u_elem)
1164
1776
 
1165
- _form._includes_measure = True
1166
- return _form
1777
+ _form._includes_measure = True # type: ignore[attr-defined]
1778
+ return _tag_form(_form, kind="residual", domain="volume")
1167
1779
 
1168
1780
 
1169
1781
  def compile_mixed_residual(residuals: dict[str, Callable]):
@@ -1194,11 +1806,158 @@ def compile_mixed_residual(residuals: dict[str, Callable]):
1194
1806
  if surface_count > 0:
1195
1807
  raise ValueError(f"Mixed residual '{name}' must not include ds().")
1196
1808
 
1809
+ class _MixedContextView:
1810
+ def __init__(self, ctx, field_name: str):
1811
+ self._ctx = ctx
1812
+ self.fields = ctx.fields
1813
+ self.x_q = ctx.x_q
1814
+ self.w = ctx.w
1815
+ self.elem_id = ctx.elem_id
1816
+ self.trial_fields = ctx.trial_fields
1817
+ self.test_fields = ctx.test_fields
1818
+ self.unknown_fields = ctx.unknown_fields
1819
+ self.unknown = ctx.unknown
1820
+
1821
+ pair = ctx.fields[field_name]
1822
+ self.test = pair.test
1823
+ self.trial = pair.trial
1824
+ self.v = pair.test
1825
+ self.u = pair.trial
1826
+
1827
+ if hasattr(ctx, "normal"):
1828
+ self.normal = ctx.normal
1829
+
1830
+ def __getattr__(self, name: str):
1831
+ return getattr(self._ctx, name)
1832
+
1833
+ def _form(ctx, u_elem, params):
1834
+ return {
1835
+ name: eval_with_plan(plan, _MixedContextView(ctx, name), params, u_elem=u_elem)
1836
+ for name, plan in plans.items()
1837
+ }
1838
+
1839
+ _form._includes_measure = includes_measure # type: ignore[attr-defined]
1840
+ return _tag_form(_form, kind="residual", domain="volume")
1841
+
1842
+
1843
def compile_mixed_surface_residual(residuals: dict[str, Callable]):
    """Compile mixed surface residuals keyed by field name.

    Each value in ``residuals`` is either an ``Expr`` or a callable invoked as
    ``fn(v, u, params=p)``, where ``v``/``u`` are the test/unknown references
    for that field and ``p`` is a parameter reference.  Every resulting
    expression must contain ``ds()`` exactly once and must not contain
    ``dOmega()``.

    Returns:
        A kernel ``_form(ctx, u_elem, params)`` that returns a dict mapping each
        field name to ``eval_with_plan`` of that field's residual, evaluated on
        a per-field view of ``ctx``.  The kernel carries an
        ``_includes_measure`` dict keyed by field name and is tagged with
        ``kind="residual"``, ``domain="surface"``.

    Raises:
        TypeError: if a residual does not produce an ``Expr``.
        ValueError: if a residual omits ``ds()``, repeats it, or uses
            ``dOmega()``.
    """
    plans = {}
    includes_measure = {}
    for name, fn in residuals.items():
        if isinstance(fn, Expr):
            expr = fn
        else:
            v = test_ref(name)
            u = unknown_ref(name)
            p = param_ref()
            expr = _call_user(fn, v, u, params=p)
        expr = _as_expr(expr)
        if not isinstance(expr, Expr):
            raise TypeError(f"Mixed surface residual '{name}' must return an Expr.")

        # Validate measures before building the plan so a malformed residual
        # fails fast with a descriptive message.
        volume_count = _count_op(expr, "volume_measure")
        surface_count = _count_op(expr, "surface_measure")
        if surface_count == 0:
            raise ValueError(f"Mixed surface residual '{name}' must include ds().")
        if surface_count > 1:
            raise ValueError(f"Mixed surface residual '{name}' must include ds() exactly once.")
        if volume_count > 0:
            raise ValueError(f"Mixed surface residual '{name}' must not include dOmega().")

        plans[name] = make_eval_plan(expr)
        # Exactly one ds() was verified above, so the measure is in the expression.
        includes_measure[name] = True

    class _MixedContextView:
        """Per-field view over a mixed surface context.

        Copies commonly used context attributes, exposes the selected field's
        test/trial pair as ``test``/``trial`` (aliased ``v``/``u``), and
        delegates every other attribute lookup to the wrapped context.
        """

        def __init__(self, ctx, field_name: str):
            self._ctx = ctx
            self.fields = ctx.fields
            self.x_q = ctx.x_q
            self.w = ctx.w
            self.detJ = ctx.detJ
            # Optional on the wrapped context; default to None when absent.
            self.normal = getattr(ctx, "normal", None)
            self.trial_fields = ctx.trial_fields
            self.test_fields = ctx.test_fields
            self.unknown_fields = ctx.unknown_fields
            self.unknown = getattr(ctx, "unknown", None)

            pair = ctx.fields[field_name]
            self.test = pair.test
            self.trial = pair.trial
            self.v = pair.test
            self.u = pair.trial

        def __getattr__(self, name: str):
            # Only reached when normal lookup fails.  Read ``_ctx`` directly from
            # the instance dict: on an instance created without ``__init__``
            # (e.g. by copy/pickle probing for ``__setstate__``), ``self._ctx``
            # would re-enter ``__getattr__`` and recurse until RecursionError.
            try:
                ctx = self.__dict__["_ctx"]
            except KeyError:
                raise AttributeError(name) from None
            return getattr(ctx, name)

    def _form(ctx, u_elem, params):
        return {
            name: eval_with_plan(plan, _MixedContextView(ctx, name), params, u_elem=u_elem)
            for name, plan in plans.items()
        }

    _form._includes_measure = includes_measure  # type: ignore[attr-defined]
    return _tag_form(_form, kind="residual", domain="surface")
1901
+
1902
+
1903
def compile_mixed_surface_residual_numpy(residuals: dict[str, Callable]):
    """Compile mixed surface residuals, keyed by field name, for numpy evaluation.

    Same contract as the non-numpy variant: each value in ``residuals`` is an
    ``Expr`` or a callable invoked as ``fn(v, u, params=p)`` with that field's
    test/unknown references.  Every resulting expression must contain ``ds()``
    exactly once and must not contain ``dOmega()``.

    Returns:
        A kernel ``_form(ctx, u_elem, params)`` that returns a dict mapping each
        field name to ``eval_with_plan_numpy`` of that field's residual,
        evaluated on a per-field view of ``ctx``.  The kernel carries an
        ``_includes_measure`` dict keyed by field name and is tagged with
        ``kind="residual"``, ``domain="surface"``.

    Raises:
        TypeError: if a residual does not produce an ``Expr``.
        ValueError: if a residual omits ``ds()``, repeats it, or uses
            ``dOmega()``.
    """
    plans = {}
    includes_measure = {}
    for name, fn in residuals.items():
        if isinstance(fn, Expr):
            expr = fn
        else:
            v = test_ref(name)
            u = unknown_ref(name)
            p = param_ref()
            expr = _call_user(fn, v, u, params=p)
        expr = _as_expr(expr)
        if not isinstance(expr, Expr):
            raise TypeError(f"Mixed surface residual '{name}' must return an Expr.")

        # Validate measures before building the plan so a malformed residual
        # fails fast with a descriptive message.
        volume_count = _count_op(expr, "volume_measure")
        surface_count = _count_op(expr, "surface_measure")
        if surface_count == 0:
            raise ValueError(f"Mixed surface residual '{name}' must include ds().")
        if surface_count > 1:
            raise ValueError(f"Mixed surface residual '{name}' must include ds() exactly once.")
        if volume_count > 0:
            raise ValueError(f"Mixed surface residual '{name}' must not include dOmega().")

        plans[name] = make_eval_plan(expr)
        # Exactly one ds() was verified above, so the measure is in the expression.
        includes_measure[name] = True

    class _MixedContextView:
        """Per-field view over a mixed surface context.

        Copies commonly used context attributes, exposes the selected field's
        test/trial pair as ``test``/``trial`` (aliased ``v``/``u``), and
        delegates every other attribute lookup to the wrapped context.
        """

        def __init__(self, ctx, field_name: str):
            self._ctx = ctx
            self.fields = ctx.fields
            self.x_q = ctx.x_q
            self.w = ctx.w
            self.detJ = ctx.detJ
            # Optional on the wrapped context; default to None when absent.
            self.normal = getattr(ctx, "normal", None)
            self.trial_fields = ctx.trial_fields
            self.test_fields = ctx.test_fields
            self.unknown_fields = ctx.unknown_fields
            self.unknown = getattr(ctx, "unknown", None)

            pair = ctx.fields[field_name]
            self.test = pair.test
            self.trial = pair.trial
            self.v = pair.test
            self.u = pair.trial

        def __getattr__(self, name: str):
            # Only reached when normal lookup fails.  Read ``_ctx`` directly from
            # the instance dict: on an instance created without ``__init__``
            # (e.g. by copy/pickle probing for ``__setstate__``), ``self._ctx``
            # would re-enter ``__getattr__`` and recurse until RecursionError.
            try:
                ctx = self.__dict__["_ctx"]
            except KeyError:
                raise AttributeError(name) from None
            return getattr(ctx, name)

    def _form(ctx, u_elem, params):
        return {
            name: eval_with_plan_numpy(plan, _MixedContextView(ctx, name), params, u_elem=u_elem)
            for name, plan in plans.items()
        }

    _form._includes_measure = includes_measure  # type: ignore[attr-defined]
    return _tag_form(_form, kind="residual", domain="surface")
1202
1961
 
1203
1962
 
1204
1963
  class MixedWeakForm:
@@ -1213,6 +1972,20 @@ class MixedWeakForm:
1213
1972
  return compile_mixed_residual(self.residuals)
1214
1973
 
1215
1974
 
1975
+ def make_mixed_residuals(residuals: dict[str, Callable] | None = None, **kwargs) -> dict[str, Callable]:
1976
+ """
1977
+ Helper to build mixed residual dictionaries.
1978
+
1979
+ Example:
1980
+ res = make_mixed_residuals(u=res_u, p=res_p)
1981
+ """
1982
+ if residuals is not None and kwargs:
1983
+ raise ValueError("Pass either residuals dict or keyword residuals, not both.")
1984
+ if residuals is None:
1985
+ return dict(kwargs)
1986
+ return dict(residuals)
1987
+
1988
+
1216
1989
  def _eval_expr(
1217
1990
  expr: Expr,
1218
1991
  ctx: VolumeContext | SurfaceContext,
@@ -1230,13 +2003,19 @@ __all__ = [
1230
2003
  "trial_ref",
1231
2004
  "test_ref",
1232
2005
  "unknown_ref",
2006
+ "zero_ref",
1233
2007
  "param_ref",
1234
2008
  "Params",
1235
2009
  "MixedWeakForm",
2010
+ "make_mixed_residuals",
2011
+ "kernel",
1236
2012
  "ResidualForm",
1237
2013
  "compile_bilinear",
1238
2014
  "compile_linear",
1239
2015
  "compile_residual",
2016
+ "compile_surface_bilinear",
2017
+ "compile_mixed_surface_residual",
2018
+ "compile_mixed_surface_residual_numpy",
1240
2019
  "compile_mixed_residual",
1241
2020
  "grad",
1242
2021
  "sym_grad",