fluxfem 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. fluxfem/__init__.py +68 -0
  2. fluxfem/core/__init__.py +115 -10
  3. fluxfem/core/assembly.py +676 -91
  4. fluxfem/core/basis.py +73 -52
  5. fluxfem/core/dtypes.py +9 -1
  6. fluxfem/core/forms.py +10 -0
  7. fluxfem/core/mixed_assembly.py +263 -0
  8. fluxfem/core/mixed_space.py +348 -0
  9. fluxfem/core/mixed_weakform.py +97 -0
  10. fluxfem/core/solver.py +2 -0
  11. fluxfem/core/space.py +262 -17
  12. fluxfem/core/weakform.py +768 -7
  13. fluxfem/helpers_wf.py +49 -0
  14. fluxfem/mesh/__init__.py +54 -2
  15. fluxfem/mesh/base.py +316 -7
  16. fluxfem/mesh/contact.py +825 -0
  17. fluxfem/mesh/dtypes.py +12 -0
  18. fluxfem/mesh/hex.py +17 -16
  19. fluxfem/mesh/io.py +6 -4
  20. fluxfem/mesh/mortar.py +3907 -0
  21. fluxfem/mesh/supermesh.py +316 -0
  22. fluxfem/mesh/surface.py +22 -4
  23. fluxfem/mesh/tet.py +10 -4
  24. fluxfem/physics/diffusion.py +3 -0
  25. fluxfem/physics/elasticity/hyperelastic.py +3 -0
  26. fluxfem/physics/elasticity/linear.py +9 -2
  27. fluxfem/solver/__init__.py +42 -2
  28. fluxfem/solver/bc.py +38 -2
  29. fluxfem/solver/block_matrix.py +132 -0
  30. fluxfem/solver/block_system.py +454 -0
  31. fluxfem/solver/cg.py +115 -33
  32. fluxfem/solver/dirichlet.py +334 -4
  33. fluxfem/solver/newton.py +237 -60
  34. fluxfem/solver/petsc.py +439 -0
  35. fluxfem/solver/preconditioner.py +106 -0
  36. fluxfem/solver/result.py +18 -0
  37. fluxfem/solver/solve_runner.py +168 -1
  38. fluxfem/solver/solver.py +12 -1
  39. fluxfem/solver/sparse.py +124 -9
  40. fluxfem-0.2.0.dist-info/METADATA +303 -0
  41. fluxfem-0.2.0.dist-info/RECORD +59 -0
  42. fluxfem-0.1.4.dist-info/METADATA +0 -127
  43. fluxfem-0.1.4.dist-info/RECORD +0 -48
  44. {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/LICENSE +0 -0
  45. {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/WHEEL +0 -0
fluxfem/core/weakform.py CHANGED
@@ -3,6 +3,8 @@ from __future__ import annotations
3
3
  from dataclasses import dataclass
4
4
  from typing import Any, Callable, Iterator, Literal, get_args
5
5
  import inspect
6
+ from dataclasses import dataclass
7
+ from functools import update_wrapper
6
8
 
7
9
  import numpy as np
8
10
 
@@ -396,6 +398,26 @@ class Params:
396
398
  return cls(**dict(zip(keys, values)))
397
399
 
398
400
 
401
+ class _ZeroField:
402
+ """Field-like object that evaluates to zeros with the same shape."""
403
+
404
+ def __init__(self, base):
405
+ self.N = jnp.zeros_like(base.N)
406
+ self.gradN = None if getattr(base, "gradN", None) is None else jnp.zeros_like(base.gradN)
407
+ self.value_dim = int(getattr(base, "value_dim", 1))
408
+ self.basis = getattr(base, "basis", None)
409
+
410
+
411
+ class _ZeroFieldNp:
412
+ """Numpy variant of a zero-valued field (for numpy backend evaluation)."""
413
+
414
+ def __init__(self, base):
415
+ self.N = np.zeros_like(base.N)
416
+ self.gradN = None if getattr(base, "gradN", None) is None else np.zeros_like(base.gradN)
417
+ self.value_dim = int(getattr(base, "value_dim", 1))
418
+ self.basis = getattr(base, "basis", None)
419
+
420
+
399
421
def trial_ref(name: str | None = "u") -> FieldRef:
    """Return a symbolic reference to the trial field *name* (default ``"u"``)."""
    return FieldRef(name=name, role="trial")
@@ -416,12 +438,99 @@ def param_ref() -> ParamRef:
416
438
  return ParamRef()
417
439
 
418
440
 
441
def zero_ref(name: str) -> FieldRef:
    """Return a reference to an all-zero field shaped like field *name*.

    The concrete shapes are taken from the evaluation context when the
    reference is resolved.
    """
    return FieldRef(role="zero", name=name)
444
+
445
+
419
446
def _resolve_zero_base(name: str, ctx: Any) -> Any:
    """Find the field whose shapes ``zero_ref(name)`` should mimic, or None."""
    # Per-role dictionaries first (test preferred over trial); None entries are skipped.
    for attr in ("test_fields", "trial_fields"):
        mapping = getattr(ctx, attr, None)
        if mapping is not None and name in mapping and mapping[name] is not None:
            return mapping[name]
    fields = getattr(ctx, "fields", None)
    if fields is None or name not in fields:
        return None
    group = fields[name]
    if isinstance(group, dict):
        # Same dict conventions as the non-zero lookup path below.
        for key in ("test", "trial", "field"):
            if key in group:
                return group[key]
        return None
    if hasattr(group, "test"):
        return group.test
    if hasattr(group, "trial"):
        return group.trial
    # A plain field object stored directly under its name.
    return group if hasattr(group, "N") else None


def _eval_field(
    obj: Any,
    ctx: VolumeContext | SurfaceContext,
    params: ParamsLike,
) -> FormFieldLike:
    """Resolve a symbolic FieldRef to the concrete field object held by *ctx*.

    Named references are looked up, in order, in:
      1. ``ctx.fields`` groups exposing ``trial``/``test``/``unknown`` attributes;
      2. the per-role dicts ``ctx.trial_fields``/``ctx.test_fields``/``ctx.unknown_fields``;
      3. ``ctx.fields`` again, where dict groups are indexed by role (or ``"field"``)
         and any other group object is returned as-is.
    Otherwise the context's single-field attributes are used: ``ctx.trial``,
    ``ctx.test`` (or ``ctx.v``), and ``ctx.unknown`` (falling back to ``ctx.trial``
    when absent or None, matching the mixed-group behaviour).

    ``zero_ref`` references (role ``"zero"``) resolve to a ``_ZeroField`` shaped
    like the named test (preferred) or trial field.

    Args:
        obj: the reference to resolve.
        ctx: evaluation context.
        params: unused here; kept for a uniform evaluator signature.

    Raises:
        TypeError: *obj* is not a FieldRef.
        ValueError: unnamed or unresolvable zero_ref, missing test field, or an
            unknown role.
    """
    if not isinstance(obj, FieldRef):
        raise TypeError("Expected a field reference for this operator.")
    role = obj.role
    name = obj.name

    if role == "zero":
        if name is None:
            raise ValueError("zero_ref requires a named field.")
        base = _resolve_zero_base(name, ctx)
        if base is None:
            raise ValueError(f"zero_ref could not resolve field '{name}'.")
        return _ZeroField(base)

    if name is not None:
        fields = getattr(ctx, "fields", None)
        in_fields = fields is not None and name in fields
        if in_fields:
            group = fields[name]
            if role == "trial" and hasattr(group, "trial"):
                return group.trial
            if role == "test" and hasattr(group, "test"):
                return group.test
            if role == "unknown" and hasattr(group, "unknown"):
                return group.unknown if group.unknown is not None else group.trial
        per_role_attr = {
            "trial": "trial_fields",
            "test": "test_fields",
            "unknown": "unknown_fields",
        }.get(role)
        if per_role_attr is not None:
            mapping = getattr(ctx, per_role_attr, None)
            if mapping is not None and name in mapping:
                return mapping[name]
        if in_fields:
            group = fields[name]
            if isinstance(group, dict):
                if role in group:
                    return group[role]
                if "field" in group:
                    return group["field"]
            return group

    if role == "trial":
        return ctx.trial
    if role == "test":
        if hasattr(ctx, "test"):
            return ctx.test
        if hasattr(ctx, "v"):
            return ctx.v
        raise ValueError("Surface context is missing test field.")
    if role == "unknown":
        # Evaluate the trial fallback lazily: a context may expose `unknown`
        # without `trial`, and an explicit None should fall back like groups do.
        unknown = getattr(ctx, "unknown", None)
        return unknown if unknown is not None else ctx.trial
    raise ValueError(f"Unknown field role: {role}")
509
+
510
+
511
+ def _eval_field_np(
512
+ obj: Any,
513
+ ctx: VolumeContext | SurfaceContext,
514
+ params: ParamsLike,
515
+ ) -> FormFieldLike:
516
+ if isinstance(obj, FieldRef):
517
+ if obj.role == "zero":
518
+ if obj.name is None:
519
+ raise ValueError("zero_ref requires a named field.")
520
+ base = None
521
+ if getattr(ctx, "test_fields", None) is not None and obj.name in ctx.test_fields:
522
+ base = ctx.test_fields[obj.name]
523
+ if base is None and getattr(ctx, "trial_fields", None) is not None and obj.name in ctx.trial_fields:
524
+ base = ctx.trial_fields[obj.name]
525
+ if base is None and getattr(ctx, "fields", None) is not None and obj.name in ctx.fields:
526
+ group = ctx.fields[obj.name]
527
+ if hasattr(group, "test"):
528
+ base = group.test
529
+ elif hasattr(group, "trial"):
530
+ base = group.trial
531
+ if base is None:
532
+ raise ValueError(f"zero_ref could not resolve field '{obj.name}'.")
533
+ return _ZeroFieldNp(base)
425
534
  if obj.name is not None:
426
535
  mixed_fields = getattr(ctx, "fields", None)
427
536
  if mixed_fields is not None and obj.name in mixed_fields:
@@ -498,6 +607,16 @@ def _basis_outer(test: FieldRef, trial: FieldRef, ctx, params):
498
607
  return jnp.einsum("qi,qj->qij", v_field.N, u_field.N)
499
608
 
500
609
 
610
def _basis_outer_np(test: FieldRef, trial: FieldRef, ctx, params):
    """Per-point outer product of scalar test/trial shape functions (NumPy).

    Returns an array indexed ``[q, i, j]`` holding ``N_test[q, i] * N_trial[q, j]``.

    Raises:
        ValueError: either field is vector/tensor valued (``value_dim != 1``).
    """
    fields = [_eval_field_np(ref, ctx, params) for ref in (test, trial)]
    if any(getattr(f, "value_dim", 1) != 1 for f in fields):
        raise ValueError(
            "inner/outer is only defined for scalar fields; use dot/action/einsum for vector/tensor cases."
        )
    v_field, u_field = fields
    return np.einsum("qi,qj->qij", v_field.N, u_field.N)
618
+
619
+
501
620
  def _eval_unknown_value(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
502
621
  u_local = _extract_unknown_elem(field_ref, u_elem)
503
622
  value_dim = int(getattr(field, "value_dim", 1))
@@ -507,6 +626,21 @@ def _eval_unknown_value(field_ref: FieldRef, field: FormFieldLike, u_elem: UElem
507
626
  return jnp.einsum("qa,ai->qi", field.N, u_nodes)
508
627
 
509
628
 
629
def _eval_unknown_value_np(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
    """Interpolate the unknown's element coefficients to quadrature points (NumPy).

    Scalar fields (``value_dim == 1``) contract ``N[q, a]`` with coefficients of
    shape ``(n_nodes,)`` or ``(n_nodes, batch)``. Vector fields reshape the
    coefficients node-major to ``(n_nodes, value_dim[, batch])`` first.

    Returns:
        ``(q,)`` / ``(q, batch)`` for scalars, ``(q, value_dim)`` /
        ``(q, value_dim, batch)`` for vectors.

    Raises:
        ValueError: no coefficients could be extracted for *field_ref*.
    """
    u_local = _extract_unknown_elem(field_ref, u_elem)
    if u_local is None:
        # Mirror _eval_unknown_grad_np: report the missing input clearly instead
        # of letting np.asarray(None) surface as an opaque einsum error.
        raise ValueError("u_elem is required to evaluate unknown field value.")
    value_dim = int(getattr(field, "value_dim", 1))
    u_arr = np.asarray(u_local)
    batched = u_arr.ndim == 2
    if value_dim == 1:
        if batched:
            return np.einsum("qa,ab->qb", field.N, u_arr)
        return np.einsum("qa,a->q", field.N, u_arr)
    if batched:
        u_nodes = u_arr.reshape((-1, value_dim, u_arr.shape[1]))
        return np.einsum("qa,aib->qib", field.N, u_nodes)
    u_nodes = u_arr.reshape((-1, value_dim))
    return np.einsum("qa,ai->qi", field.N, u_nodes)
642
+
643
+
510
644
  def _eval_unknown_grad(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
511
645
  u_local = _extract_unknown_elem(field_ref, u_elem)
512
646
  if u_local is None:
@@ -518,6 +652,88 @@ def _eval_unknown_grad(field_ref: FieldRef, field: FormFieldLike, u_elem: UEleme
518
652
  return jnp.einsum("qaj,ai->qij", field.gradN, u_nodes)
519
653
 
520
654
 
655
def _eval_unknown_grad_np(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
    """Gradient of the unknown at quadrature points (NumPy).

    Contracts ``gradN[q, a, j]`` with the element coefficients. Vector fields
    reshape coefficients node-major to ``(n_nodes, value_dim[, batch])``.

    Returns:
        ``(q, dim)`` / ``(q, dim, batch)`` for scalars, ``(q, value_dim, dim)`` /
        ``(q, value_dim, dim, batch)`` for vectors.

    Raises:
        ValueError: no coefficients could be extracted for *field_ref*.
    """
    u_local = _extract_unknown_elem(field_ref, u_elem)
    if u_local is None:
        raise ValueError("u_elem is required to evaluate unknown field gradient.")
    dim = int(getattr(field, "value_dim", 1))
    coeffs = np.asarray(u_local)
    batched = coeffs.ndim == 2
    if dim == 1:
        spec = "qaj,ab->qjb" if batched else "qaj,a->qj"
        return np.einsum(spec, field.gradN, coeffs)
    if batched:
        nodal = coeffs.reshape((-1, dim, coeffs.shape[1]))
        return np.einsum("qaj,aib->qijb", field.gradN, nodal)
    return np.einsum("qaj,ai->qij", field.gradN, coeffs.reshape((-1, dim)))
670
+
671
+
672
+ def _vector_load_form_np(field: Any, load_vec: Any) -> np.ndarray:
673
+ lv = np.asarray(load_vec)
674
+ if lv.ndim == 1:
675
+ lv = lv[None, :]
676
+ elif lv.ndim not in (2, 3):
677
+ raise ValueError("load_vec must be shape (dim,), (n_q, dim), or (n_q, dim, batch)")
678
+ if lv.shape[0] == 1:
679
+ lv = np.broadcast_to(lv, (field.N.shape[0], lv.shape[1]))
680
+ elif lv.shape[0] != field.N.shape[0]:
681
+ raise ValueError("load_vec must be shape (dim,) or (n_q, dim)")
682
+ if lv.ndim == 3:
683
+ load = field.N[..., None, None] * lv[:, None, :, :]
684
+ return load.reshape(load.shape[0], -1, load.shape[-1])
685
+ load = field.N[..., None] * lv[:, None, :]
686
+ return load.reshape(load.shape[0], -1)
687
+
688
+
689
+ def _sym_grad_np(field) -> np.ndarray:
690
+ gradN = np.asarray(field.gradN)
691
+ dofs = int(getattr(field.basis, "dofs_per_node", 3))
692
+ n_q, n_nodes, _ = gradN.shape
693
+ n_dofs = dofs * n_nodes
694
+ B = np.zeros((n_q, 6, n_dofs), dtype=gradN.dtype)
695
+ for a in range(n_nodes):
696
+ col = dofs * a
697
+ dNdx = gradN[:, a, 0]
698
+ dNdy = gradN[:, a, 1]
699
+ dNdz = gradN[:, a, 2]
700
+ B[:, 0, col + 0] = dNdx
701
+ B[:, 1, col + 1] = dNdy
702
+ B[:, 2, col + 2] = dNdz
703
+ B[:, 3, col + 0] = dNdy
704
+ B[:, 3, col + 1] = dNdx
705
+ B[:, 4, col + 1] = dNdz
706
+ B[:, 4, col + 2] = dNdy
707
+ B[:, 5, col + 0] = dNdz
708
+ B[:, 5, col + 2] = dNdx
709
+ return B
710
+
711
+
712
def _sym_grad_u_np(field, u_elem: Any) -> np.ndarray:
    """Apply the Voigt B-matrix to element coefficients (``B @ u`` per point).

    Accepts coefficients of shape ``(n_dofs,)`` or ``(n_dofs, batch)``.
    """
    B = _sym_grad_np(field)
    coeffs = np.asarray(u_elem)
    spec = "qik,kb->qib" if coeffs.ndim == 2 else "qik,k->qi"
    return np.einsum(spec, B, coeffs)
718
+
719
+
720
+ def _ddot_np(a: Any, b: Any, c: Any | None = None) -> np.ndarray:
721
+ if c is None:
722
+ return np.einsum("...ij,...ij->...", a, b)
723
+ a_t = np.swapaxes(a, -1, -2)
724
+ return np.einsum("...ik,kl,...lm->...im", a_t, b, c)
725
+
726
+
727
+ def _dot_np(a: Any, b: Any) -> np.ndarray:
728
+ if hasattr(a, "N") and getattr(a, "value_dim", None) is not None:
729
+ return _vector_load_form_np(a, b)
730
+ return np.matmul(a, b)
731
+
732
+
733
+ def _transpose_last2_np(a: Any) -> np.ndarray:
734
+ return np.swapaxes(a, -1, -2)
735
+
736
+
521
737
def grad(field) -> Expr:
    """Build a symbolic ``grad`` node: basis gradients of a scalar or vector field."""
    operand = _as_expr(field)
    return Expr("grad", operand)
@@ -649,6 +865,62 @@ def _call_user(fn, *args, params):
649
865
  return fn(*args)
650
866
 
651
867
 
868
@dataclass(frozen=True)
class KernelSpec:
    """Immutable kind/domain tag attached to compiled or raw kernels.

    In this module, ``kind`` is one of ``"bilinear"``, ``"linear"`` or
    ``"residual"``, and ``domain`` is ``"volume"`` or ``"surface"``.
    """

    kind: str  # e.g. "bilinear", "linear", "residual"
    domain: str  # e.g. "volume", "surface"
872
+
873
+
874
class TaggedKernel:
    """Callable wrapper that tags a raw kernel with ``KernelSpec`` metadata.

    Calls are forwarded unchanged to the wrapped callable. The tags are exposed
    both as ``_ff_spec``/``_ff_kind``/``_ff_domain`` attributes (the names
    ``_tag_form`` sets on compiled forms) and as read-only ``spec``/``kind``/
    ``domain`` properties. ``functools.update_wrapper`` copies the callable's
    ``__name__``/``__doc__``/``__dict__`` and sets ``__wrapped__``.
    """

    def __init__(self, fn, spec: KernelSpec):
        # Copy metadata BEFORE assigning tags: update_wrapper merges fn.__dict__
        # into ours, so doing it last let stale _ff_* tags on an already-tagged
        # fn (a compiled form or another TaggedKernel) override *spec*.
        update_wrapper(self, fn)
        self._fn = fn
        self._ff_spec = spec
        self._ff_kind = spec.kind
        self._ff_domain = spec.domain

    def __call__(self, *args, **kwargs):
        return self._fn(*args, **kwargs)

    def __repr__(self) -> str:
        return f"TaggedKernel(kind={self._ff_kind!r}, domain={self._ff_domain!r})"

    @property
    def spec(self) -> KernelSpec:
        return self._ff_spec

    @property
    def kind(self) -> str:
        return self._ff_kind

    @property
    def domain(self) -> str:
        return self._ff_domain

    def __hash__(self) -> int:
        # Hash by the wrapped callable; equality stays identity-based.
        return hash(self._fn)
903
+
904
+
905
def _tag_form(fn, *, kind: str, domain: str):
    """Attach kind/domain tags to *fn* in place and return the same object.

    Sets ``_ff_spec`` (a ``KernelSpec``), ``_ff_kind`` and ``_ff_domain``.
    """
    fn._ff_spec = KernelSpec(kind=kind, domain=domain)
    fn._ff_kind, fn._ff_domain = kind, domain
    return fn
911
+
912
+
913
def kernel(*, kind: str, domain: str = "volume"):
    """Decorator that wraps a raw kernel in a ``TaggedKernel``.

    The kind/domain metadata lets assembly infer how to use the kernel.

    Example::

        @kernel(kind="bilinear", domain="surface")
        def my_kernel(...): ...
    """

    def decorate(fn):
        return TaggedKernel(fn, KernelSpec(kind=kind, domain=domain))

    return decorate
922
+
923
+
652
924
  def compile_bilinear(fn):
653
925
  """get_compiled a bilinear weak form (u, v, params) -> Expr into a kernel."""
654
926
  if isinstance(fn, Expr):
@@ -677,7 +949,7 @@ def compile_bilinear(fn):
677
949
  return eval_with_plan(plan, ctx, params)
678
950
 
679
951
  _form._includes_measure = True
680
- return _form
952
+ return _tag_form(_form, kind="bilinear", domain="volume")
681
953
 
682
954
 
683
955
  def compile_linear(fn):
@@ -707,7 +979,7 @@ def compile_linear(fn):
707
979
  return eval_with_plan(plan, ctx, params)
708
980
 
709
981
  _form._includes_measure = True
710
- return _form
982
+ return _tag_form(_form, kind="linear", domain="volume")
711
983
 
712
984
 
713
985
  def _expr_contains(expr: Expr, op: str) -> bool:
@@ -1043,7 +1315,10 @@ def eval_with_plan(
1043
1315
  continue
1044
1316
  if op == "einsum":
1045
1317
  subscripts = args[0]
1046
- operands = [get(arg) for arg in args[1:]]
1318
+ operands = [
1319
+ (jnp.asarray(arg) if isinstance(arg, tuple) else arg)
1320
+ for arg in (get(arg) for arg in args[1:])
1321
+ ]
1047
1322
  vals[i] = jnp.einsum(subscripts, *operands)
1048
1323
  continue
1049
1324
 
@@ -1052,6 +1327,293 @@ def eval_with_plan(
1052
1327
  return vals[index[id(plan.expr)]]
1053
1328
 
1054
1329
 
1330
def _require_field_ref(obj: Any, op: str) -> FieldRef:
    """Return *obj* if it is a FieldRef; otherwise raise TypeError naming *op*."""
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(obj, FieldRef):
        raise TypeError(f"Expr '{op}' expects a FieldRef as its first operand.")
    return obj


def _auto_ellipsis_subscripts(subscripts: str, operands: list[Any]) -> str:
    """Append '...' to einsum terms whose operand has more axes than letters.

    This lets a subscript written for unbatched operands accept trailing extra
    (e.g. batch) axes. When any term is extended, the output term (if given)
    is extended too. Subscripts that already contain '...' are returned
    unchanged.
    """
    if "..." in subscripts:
        return subscripts
    lhs, arrow, rhs = subscripts.partition("->")
    extended = []
    has_extra = False
    for term, operand in zip(lhs.split(","), operands):
        if hasattr(operand, "ndim") and operand.ndim > len(term):
            has_extra = True
            extended.append(term + "...")
        else:
            extended.append(term)
    if not has_extra:
        return subscripts
    joined = ",".join(extended)
    return f"{joined}->{rhs}..." if arrow else joined


def eval_with_plan_numpy(
    plan: EvalPlan,
    ctx: VolumeContext | SurfaceContext,
    params: ParamsLike,
    u_elem: UElement | None = None,
):
    """Evaluate a compiled expression plan with the NumPy backend.

    NumPy counterpart of ``eval_with_plan``. It walks ``plan.nodes`` in order,
    stores each node's value in a slot so later nodes can reference it, and
    returns the value of ``plan.expr``.

    Args:
        plan: plan whose ``index`` maps ``id(expr)`` to a node slot.
        ctx: evaluation context (fields, weights, and a normal for surfaces).
        params: value substituted for ``ParamRef`` operands.
        u_elem: element unknowns. ``value``/``grad``/``sym_grad`` need these
            for fields with role ``"unknown"``.

    Raises:
        TypeError: a bare FieldRef is used as a value, a field op gets a
            non-FieldRef, a measure is used in the wrong context, or '@' is
            used with unsupported operands.
        ValueError: unknown op, missing context data, or shape mismatches.
    """
    nodes = plan.nodes
    index = plan.index
    vals: list[Any] = [None] * len(nodes)

    def get(obj):
        # Sub-expressions come from earlier slots, ParamRef becomes `params`,
        # and anything else is a literal passed through unchanged.
        if isinstance(obj, Expr):
            return vals[index[id(obj)]]
        if isinstance(obj, FieldRef):
            raise TypeError(
                "FieldRef must be wrapped with .val/.grad/.sym_grad or used as the first arg of dot/action."
            )
        if isinstance(obj, ParamRef):
            return params
        return obj

    for i, node in enumerate(nodes):
        op = node.op
        args = node.args

        if op == "lit":
            vals[i] = args[0]
        elif op == "getattr":
            base = get(args[0])
            name = args[1]
            vals[i] = base[name] if isinstance(base, dict) else getattr(base, name)
        elif op == "value":
            ref = _require_field_ref(args[0], op)
            field = _eval_field_np(ref, ctx, params)
            if ref.role == "unknown":
                vals[i] = _eval_unknown_value_np(ref, field, u_elem)
            else:
                vals[i] = field.N
        elif op == "grad":
            ref = _require_field_ref(args[0], op)
            field = _eval_field_np(ref, ctx, params)
            if ref.role == "unknown":
                vals[i] = _eval_unknown_grad_np(ref, field, u_elem)
            else:
                vals[i] = field.gradN
        elif op == "pow":
            vals[i] = get(args[0]) ** get(args[1])
        elif op == "eye":
            vals[i] = np.eye(int(args[0]))
        elif op == "det":
            vals[i] = np.linalg.det(get(args[0]))
        elif op == "inv":
            vals[i] = np.linalg.inv(get(args[0]))
        elif op in ("transpose", "transpose_last2"):
            # Both ops swap the last two axes in the NumPy backend.
            vals[i] = np.swapaxes(get(args[0]), -1, -2)
        elif op == "log":
            vals[i] = np.log(get(args[0]))
        elif op == "surface_normal":
            normal = getattr(ctx, "normal", None)
            if normal is None:
                raise ValueError("surface normal is not available in context")
            vals[i] = normal
        elif op == "surface_measure":
            if not hasattr(ctx, "w") or not hasattr(ctx, "detJ"):
                raise TypeError("surface measure requires SurfaceContext.")
            vals[i] = ctx.w * ctx.detJ
        elif op == "volume_measure":
            if not hasattr(ctx, "w") or not hasattr(ctx, "test"):
                raise TypeError("volume measure requires VolumeContext.")
            vals[i] = ctx.w * ctx.test.detJ
        elif op == "sym_grad":
            ref = _require_field_ref(args[0], op)
            field = _eval_field_np(ref, ctx, params)
            if ref.role == "unknown":
                if u_elem is None:
                    raise ValueError("u_elem is required to evaluate unknown sym_grad.")
                vals[i] = _sym_grad_u_np(field, _extract_unknown_elem(ref, u_elem))
            else:
                vals[i] = _sym_grad_np(field)
        elif op == "outer":
            test, trial = args
            if not isinstance(test, FieldRef) or not isinstance(trial, FieldRef):
                raise TypeError("outer expects FieldRef operands.")
            vals[i] = _basis_outer_np(test, trial, ctx, params)
        elif op == "add":
            vals[i] = get(args[0]) + get(args[1])
        elif op == "sub":
            vals[i] = get(args[0]) - get(args[1])
        elif op == "mul":
            a = get(args[0])
            b = get(args[1])
            if hasattr(a, "ndim") and hasattr(b, "ndim"):
                # A 1-D operand whose length matches the other operand's leading
                # axis is padded with trailing singleton axes so it scales along
                # that leading axis instead of broadcasting against the last one.
                if a.ndim == 1 and b.ndim >= 2 and a.shape[0] == b.shape[0]:
                    a = a.reshape((a.shape[0],) + (1,) * (b.ndim - 1))
                elif b.ndim == 1 and a.ndim >= 2 and b.shape[0] == a.shape[0]:
                    b = b.reshape((b.shape[0],) + (1,) * (a.ndim - 1))
            vals[i] = a * b
        elif op == "matmul":
            a = get(args[0])
            b = get(args[1])
            if (
                hasattr(a, "ndim")
                and hasattr(b, "ndim")
                and a.ndim == 3
                and b.ndim == 3
                and a.shape[0] == b.shape[0]
                and a.shape[-1] == b.shape[-1]
            ):
                # FEM-style '@': contract the trailing axis, keeping both middle axes.
                vals[i] = np.einsum("qia,qja->qij", a, b)
            else:
                raise TypeError(
                    "Expr '@' (matmul) is FEM-specific; use matmul_std(a, b) for standard matmul."
                )
        elif op == "matmul_std":
            vals[i] = np.matmul(get(args[0]), get(args[1]))
        elif op == "neg":
            vals[i] = -get(args[0])
        elif op in ("dot", "sdot"):
            # dot and sdot share one implementation in the NumPy backend.
            if isinstance(args[0], FieldRef):
                vals[i] = _dot_np(_eval_field_np(args[0], ctx, params), get(args[1]))
            else:
                a = get(args[0])
                b = get(args[1])
                if (
                    hasattr(a, "ndim")
                    and hasattr(b, "ndim")
                    and a.ndim >= 3
                    and b.ndim >= 3
                    and a.shape[0] == b.shape[0]
                    and a.shape[1] == b.shape[1]
                ):
                    vals[i] = np.einsum("qi...,qj...->qij...", a, b)
                else:
                    vals[i] = np.matmul(a, b)
        elif op == "ddot":
            if len(args) == 2:
                a = get(args[0])
                b = get(args[1])
                if (
                    hasattr(a, "ndim")
                    and hasattr(b, "ndim")
                    and a.ndim == 3
                    and b.ndim == 3
                    and a.shape[0] == b.shape[0]
                    and a.shape[1] == b.shape[1]
                ):
                    vals[i] = np.einsum("qik,qim->qkm", a, b)
                else:
                    vals[i] = _ddot_np(a, b)
            else:
                vals[i] = _ddot_np(get(args[0]), get(args[1]), get(args[2]))
        elif op == "inner":
            vals[i] = np.einsum("...i,...i->...", get(args[0]), get(args[1]))
        elif op == "action":
            ref = _require_field_ref(args[0], op)
            if isinstance(args[1], FieldRef):
                raise ValueError("action expects a scalar expression; use u.val for unknowns.")
            v_field = _eval_field_np(ref, ctx, params)
            s = get(args[1])
            value_dim = int(getattr(v_field, "value_dim", 1))
            if value_dim == 1:
                if v_field.N.ndim != 2:
                    raise ValueError("action expects scalar test field with N shape (q, ndofs).")
                if hasattr(s, "ndim") and s.ndim not in (0, 1):
                    raise ValueError("action expects scalar s with shape (q,) or scalar.")
                if hasattr(s, "ndim") and s.ndim == 1:
                    # A per-point scalar (q,) must scale each row of N (q, ndofs).
                    # Plain broadcasting would align it with the dof axis instead,
                    # which silently gives wrong results when q == ndofs.
                    s = s[:, None]
                vals[i] = v_field.N * s
            else:
                if hasattr(s, "ndim") and s.ndim not in (1, 2):
                    raise ValueError("action expects vector s with shape (q, dim) or (dim,).")
                vals[i] = _dot_np(v_field, s)
        elif op == "gaction":
            ref = _require_field_ref(args[0], op)
            v_field = _eval_field_np(ref, ctx, params)
            q = get(args[1])
            gradN = getattr(v_field, "gradN", None)
            if gradN is None or gradN.ndim != 3:
                raise ValueError("gaction expects test gradient with shape (q, ndofs, dim).")
            if not hasattr(q, "ndim"):
                raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
            if q.ndim == 2:
                vals[i] = np.einsum("qaj,qj->qa", gradN, q)
            elif q.ndim == 3:
                if int(getattr(v_field, "value_dim", 1)) == 1:
                    raise ValueError("gaction tensor flux requires vector test field.")
                # Flatten node-major (node, component), matching the vector load layout.
                vals[i] = np.einsum("qij,qaj->qai", q, gradN).reshape(q.shape[0], -1)
            else:
                raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
        elif op == "einsum":
            operands = [
                np.asarray(value) if isinstance(value, tuple) else value
                for value in (get(arg) for arg in args[1:])
            ]
            vals[i] = np.einsum(_auto_ellipsis_subscripts(args[0], operands), *operands)
        else:
            raise ValueError(f"Unknown Expr op: {op}")

    return vals[index[id(plan.expr)]]
1615
+
1616
+
1055
1617
  def compile_surface_linear(fn):
1056
1618
  """get_compiled a surface linear form into a kernel (ctx, params) -> ndarray."""
1057
1619
  if isinstance(fn, Expr):
@@ -1080,7 +1642,39 @@ def compile_surface_linear(fn):
1080
1642
  return eval_with_plan(plan, ctx, params)
1081
1643
 
1082
1644
  _form._includes_measure = True # type: ignore[attr-defined]
1083
- return _form
1645
+ return _tag_form(_form, kind="linear", domain="surface")
1646
+
1647
+
1648
def compile_surface_bilinear(fn):
    """Compile a surface bilinear weak form into a kernel ``(ctx, params) -> ndarray``.

    *fn* is either an ``Expr`` or a callable invoked as ``fn(u, v, params=p)``
    with symbolic trial/test/param references. The expression must contain
    exactly one ``ds()`` and no ``dOmega()``. The returned kernel is tagged
    ``kind="bilinear"``, ``domain="surface"``.

    Raises:
        ValueError: the form does not yield an Expr, or the measure rules are violated.
    """
    if isinstance(fn, Expr):
        expr = fn
    else:
        test, trial, prm = test_ref(), trial_ref(), param_ref()
        expr = _call_user(fn, trial, test, params=prm)

    expr = _as_expr(expr)
    if not isinstance(expr, Expr):
        raise ValueError("Surface bilinear form must return an Expr; use ds() in the expression.")

    n_ds = _count_op(expr, "surface_measure")
    n_domega = _count_op(expr, "volume_measure")
    if n_ds == 0:
        raise ValueError("Surface bilinear form must include ds().")
    elif n_ds > 1:
        raise ValueError("Surface bilinear form must include ds() exactly once.")
    if n_domega > 0:
        raise ValueError("Surface bilinear form must not include dOmega().")

    plan = make_eval_plan(expr)

    def _form(ctx, params):
        return eval_with_plan(plan, ctx, params)

    _form._includes_measure = True  # type: ignore[attr-defined]
    return _tag_form(_form, kind="bilinear", domain="surface")
1084
1678
 
1085
1679
 
1086
1680
  class LinearForm:
@@ -1163,7 +1757,7 @@ def compile_residual(fn):
1163
1757
  return eval_with_plan(plan, ctx, params, u_elem=u_elem)
1164
1758
 
1165
1759
  _form._includes_measure = True
1166
- return _form
1760
+ return _tag_form(_form, kind="residual", domain="volume")
1167
1761
 
1168
1762
 
1169
1763
  def compile_mixed_residual(residuals: dict[str, Callable]):
@@ -1194,11 +1788,158 @@ def compile_mixed_residual(residuals: dict[str, Callable]):
1194
1788
  if surface_count > 0:
1195
1789
  raise ValueError(f"Mixed residual '{name}' must not include ds().")
1196
1790
 
1791
+ class _MixedContextView:
1792
+ def __init__(self, ctx, field_name: str):
1793
+ self._ctx = ctx
1794
+ self.fields = ctx.fields
1795
+ self.x_q = ctx.x_q
1796
+ self.w = ctx.w
1797
+ self.elem_id = ctx.elem_id
1798
+ self.trial_fields = ctx.trial_fields
1799
+ self.test_fields = ctx.test_fields
1800
+ self.unknown_fields = ctx.unknown_fields
1801
+ self.unknown = ctx.unknown
1802
+
1803
+ pair = ctx.fields[field_name]
1804
+ self.test = pair.test
1805
+ self.trial = pair.trial
1806
+ self.v = pair.test
1807
+ self.u = pair.trial
1808
+
1809
+ if hasattr(ctx, "normal"):
1810
+ self.normal = ctx.normal
1811
+
1812
+ def __getattr__(self, name: str):
1813
+ return getattr(self._ctx, name)
1814
+
1197
1815
  def _form(ctx, u_elem, params):
1198
- return {name: eval_with_plan(plan, ctx, params, u_elem=u_elem) for name, plan in plans.items()}
1816
+ return {
1817
+ name: eval_with_plan(plan, _MixedContextView(ctx, name), params, u_elem=u_elem)
1818
+ for name, plan in plans.items()
1819
+ }
1199
1820
 
1200
1821
  _form._includes_measure = includes_measure
1201
- return _form
1822
+ return _tag_form(_form, kind="residual", domain="volume")
1823
+
1824
+
1825
def compile_mixed_surface_residual(residuals: dict[str, Callable]):
    """Compile per-field surface residuals of a mixed problem (JAX backend).

    Each entry is an ``Expr`` or a callable invoked as
    ``fn(test_ref(name), unknown_ref(name), params=param_ref())``. Every
    residual must contain exactly one ``ds()`` and no ``dOmega()``.

    The returned kernel ``(ctx, u_elem, params)`` evaluates each residual
    against a per-field view of *ctx* and returns ``{name: value}``. It is
    tagged ``kind="residual"``, ``domain="surface"``.

    Raises:
        TypeError: a residual does not yield an Expr.
        ValueError: a residual violates the measure rules.
    """
    plans = {}
    includes_measure = {}
    for name, fn in residuals.items():
        if isinstance(fn, Expr):
            expr = fn
        else:
            expr = _call_user(fn, test_ref(name), unknown_ref(name), params=param_ref())
        expr = _as_expr(expr)
        if not isinstance(expr, Expr):
            raise TypeError(f"Mixed surface residual '{name}' must return an Expr.")
        plans[name] = make_eval_plan(expr)
        n_domega = _count_op(expr, "volume_measure")
        n_ds = _count_op(expr, "surface_measure")
        includes_measure[name] = n_ds == 1
        if n_ds == 0:
            raise ValueError(f"Mixed surface residual '{name}' must include ds().")
        if n_ds > 1:
            raise ValueError(f"Mixed surface residual '{name}' must include ds() exactly once.")
        if n_domega > 0:
            raise ValueError(f"Mixed surface residual '{name}' must not include dOmega().")

    class _MixedContextView:
        """Per-field view of a mixed surface context.

        Copies shared context data. It points ``test``/``trial`` (aliased as
        ``v``/``u``) at ``ctx.fields[field_name]`` and forwards any other
        attribute lookup to the wrapped context.
        """

        def __init__(self, ctx, field_name: str):
            self._ctx = ctx
            for attr in ("fields", "x_q", "w", "detJ"):
                setattr(self, attr, getattr(ctx, attr))
            self.normal = getattr(ctx, "normal", None)
            for attr in ("trial_fields", "test_fields", "unknown_fields"):
                setattr(self, attr, getattr(ctx, attr))
            self.unknown = getattr(ctx, "unknown", None)

            pair = ctx.fields[field_name]
            self.test = self.v = pair.test
            self.trial = self.u = pair.trial

        def __getattr__(self, name: str):
            return getattr(self._ctx, name)

    def _form(ctx, u_elem, params):
        out = {}
        for field_name, plan in plans.items():
            view = _MixedContextView(ctx, field_name)
            out[field_name] = eval_with_plan(plan, view, params, u_elem=u_elem)
        return out

    _form._includes_measure = includes_measure
    return _tag_form(_form, kind="residual", domain="surface")
1883
+
1884
+
1885
def compile_mixed_surface_residual_numpy(residuals: dict[str, Callable]):
    """Compile per-field surface residuals of a mixed problem for NumPy evaluation.

    This applies the same validation as the JAX variant: each residual is an
    ``Expr`` or a callable invoked as
    ``fn(test_ref(name), unknown_ref(name), params=param_ref())``, with exactly
    one ``ds()`` and no ``dOmega()``. The returned kernel
    ``(ctx, u_elem, params)`` evaluates with ``eval_with_plan_numpy`` and
    returns ``{name: value}``. It is tagged ``kind="residual"``,
    ``domain="surface"``.

    Raises:
        TypeError: a residual does not yield an Expr.
        ValueError: a residual violates the measure rules.
    """
    plans = {}
    includes_measure = {}
    for field_name, fn in residuals.items():
        expr = fn if isinstance(fn, Expr) else _call_user(
            fn, test_ref(field_name), unknown_ref(field_name), params=param_ref()
        )
        expr = _as_expr(expr)
        if not isinstance(expr, Expr):
            raise TypeError(f"Mixed surface residual '{field_name}' must return an Expr.")
        plans[field_name] = make_eval_plan(expr)
        volume_hits = _count_op(expr, "volume_measure")
        surface_hits = _count_op(expr, "surface_measure")
        includes_measure[field_name] = surface_hits == 1
        if surface_hits == 0:
            raise ValueError(f"Mixed surface residual '{field_name}' must include ds().")
        if surface_hits > 1:
            raise ValueError(f"Mixed surface residual '{field_name}' must include ds() exactly once.")
        if volume_hits > 0:
            raise ValueError(f"Mixed surface residual '{field_name}' must not include dOmega().")

    class _MixedContextView:
        """Per-field view of a mixed surface context (NumPy backend).

        Exposes the field pair ``ctx.fields[field_name]`` as ``test``/``trial``
        (aliased as ``v``/``u``). Other attribute lookups are delegated to the
        wrapped context.
        """

        def __init__(self, ctx, field_name: str):
            self._ctx = ctx
            self.fields = ctx.fields
            self.x_q = ctx.x_q
            self.w = ctx.w
            self.detJ = ctx.detJ
            self.normal = getattr(ctx, "normal", None)
            self.trial_fields = ctx.trial_fields
            self.test_fields = ctx.test_fields
            self.unknown_fields = ctx.unknown_fields
            self.unknown = getattr(ctx, "unknown", None)

            pair = ctx.fields[field_name]
            self.test = self.v = pair.test
            self.trial = self.u = pair.trial

        def __getattr__(self, name: str):
            return getattr(self._ctx, name)

    def _form(ctx, u_elem, params):
        results = {}
        for field_name, plan in plans.items():
            results[field_name] = eval_with_plan_numpy(
                plan, _MixedContextView(ctx, field_name), params, u_elem=u_elem
            )
        return results

    _form._includes_measure = includes_measure
    return _tag_form(_form, kind="residual", domain="surface")
1202
1943
 
1203
1944
 
1204
1945
  class MixedWeakForm:
@@ -1213,6 +1954,20 @@ class MixedWeakForm:
1213
1954
  return compile_mixed_residual(self.residuals)
1214
1955
 
1215
1956
 
1957
def make_mixed_residuals(residuals: dict[str, Callable] | None = None, **kwargs) -> dict[str, Callable]:
    """Build a fresh ``{field_name: residual}`` dict for mixed weak forms.

    Pass either a mapping or keyword arguments, not both. The result is always
    a new dict.

    Example:
        res = make_mixed_residuals(u=res_u, p=res_p)

    Raises:
        ValueError: both a mapping and keyword residuals were supplied.
    """
    if residuals is None:
        return dict(kwargs)
    if kwargs:
        raise ValueError("Pass either residuals dict or keyword residuals, not both.")
    return dict(residuals)
1969
+
1970
+
1216
1971
  def _eval_expr(
1217
1972
  expr: Expr,
1218
1973
  ctx: VolumeContext | SurfaceContext,
@@ -1230,13 +1985,19 @@ __all__ = [
1230
1985
  "trial_ref",
1231
1986
  "test_ref",
1232
1987
  "unknown_ref",
1988
+ "zero_ref",
1233
1989
  "param_ref",
1234
1990
  "Params",
1235
1991
  "MixedWeakForm",
1992
+ "make_mixed_residuals",
1993
+ "kernel",
1236
1994
  "ResidualForm",
1237
1995
  "compile_bilinear",
1238
1996
  "compile_linear",
1239
1997
  "compile_residual",
1998
+ "compile_surface_bilinear",
1999
+ "compile_mixed_surface_residual",
2000
+ "compile_mixed_surface_residual_numpy",
1240
2001
  "compile_mixed_residual",
1241
2002
  "grad",
1242
2003
  "sym_grad",