fluxfem 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. fluxfem/__init__.py +136 -161
  2. fluxfem/core/__init__.py +172 -41
  3. fluxfem/core/assembly.py +676 -91
  4. fluxfem/core/basis.py +73 -52
  5. fluxfem/core/context_types.py +36 -0
  6. fluxfem/core/dtypes.py +9 -1
  7. fluxfem/core/forms.py +15 -1
  8. fluxfem/core/mixed_assembly.py +263 -0
  9. fluxfem/core/mixed_space.py +348 -0
  10. fluxfem/core/mixed_weakform.py +97 -0
  11. fluxfem/core/solver.py +2 -0
  12. fluxfem/core/space.py +262 -17
  13. fluxfem/core/weakform.py +1503 -312
  14. fluxfem/helpers_wf.py +53 -0
  15. fluxfem/mesh/__init__.py +54 -2
  16. fluxfem/mesh/base.py +322 -8
  17. fluxfem/mesh/contact.py +825 -0
  18. fluxfem/mesh/dtypes.py +12 -0
  19. fluxfem/mesh/hex.py +18 -16
  20. fluxfem/mesh/io.py +8 -4
  21. fluxfem/mesh/mortar.py +3907 -0
  22. fluxfem/mesh/supermesh.py +316 -0
  23. fluxfem/mesh/surface.py +22 -4
  24. fluxfem/mesh/tet.py +10 -4
  25. fluxfem/physics/diffusion.py +3 -0
  26. fluxfem/physics/elasticity/hyperelastic.py +3 -0
  27. fluxfem/physics/elasticity/linear.py +9 -2
  28. fluxfem/solver/__init__.py +42 -2
  29. fluxfem/solver/bc.py +38 -2
  30. fluxfem/solver/block_matrix.py +132 -0
  31. fluxfem/solver/block_system.py +454 -0
  32. fluxfem/solver/cg.py +115 -33
  33. fluxfem/solver/dirichlet.py +334 -4
  34. fluxfem/solver/newton.py +237 -60
  35. fluxfem/solver/petsc.py +439 -0
  36. fluxfem/solver/preconditioner.py +106 -0
  37. fluxfem/solver/result.py +18 -0
  38. fluxfem/solver/solve_runner.py +168 -1
  39. fluxfem/solver/solver.py +12 -1
  40. fluxfem/solver/sparse.py +124 -9
  41. fluxfem-0.2.0.dist-info/METADATA +303 -0
  42. fluxfem-0.2.0.dist-info/RECORD +59 -0
  43. fluxfem-0.1.3.dist-info/METADATA +0 -125
  44. fluxfem-0.1.3.dist-info/RECORD +0 -47
  45. {fluxfem-0.1.3.dist-info → fluxfem-0.2.0.dist-info}/LICENSE +0 -0
  46. {fluxfem-0.1.3.dist-info → fluxfem-0.2.0.dist-info}/WHEEL +0 -0
fluxfem/core/weakform.py CHANGED
@@ -1,82 +1,314 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Any, Callable
4
+ from typing import Any, Callable, Iterator, Literal, get_args
5
+ import inspect
6
+ from dataclasses import dataclass
7
+ from functools import update_wrapper
8
+
9
+ import numpy as np
5
10
 
6
11
  import jax.numpy as jnp
7
12
  import jax
8
13
 
9
14
  from ..physics import operators as _ops
15
+ from .context_types import FormFieldLike, ParamsLike, SurfaceContext, UElement, VolumeContext
10
16
 
11
17
 
18
+ OpName = Literal[
19
+ "lit",
20
+ "getattr",
21
+ "value",
22
+ "grad",
23
+ "pow",
24
+ "eye",
25
+ "det",
26
+ "inv",
27
+ "transpose",
28
+ "log",
29
+ "surface_normal",
30
+ "surface_measure",
31
+ "volume_measure",
32
+ "sym_grad",
33
+ "outer",
34
+ "add",
35
+ "sub",
36
+ "mul",
37
+ "matmul",
38
+ "matmul_std",
39
+ "neg",
40
+ "dot",
41
+ "sdot",
42
+ "ddot",
43
+ "inner",
44
+ "action",
45
+ "gaction",
46
+ "transpose_last2",
47
+ "einsum",
48
+ ]
49
+
50
+ # Use OpName as the single source of truth for valid ops.
51
+ _OP_NAMES: frozenset[str] = frozenset(get_args(OpName))
52
+
53
+
54
+ _PRECEDENCE: dict[str, int] = {
55
+ "add": 10,
56
+ "sub": 10,
57
+ "mul": 20,
58
+ "matmul": 20,
59
+ "matmul_std": 20,
60
+ "inner": 20,
61
+ "dot": 20,
62
+ "sdot": 20,
63
+ "ddot": 20,
64
+ "pow": 30,
65
+ "neg": 40,
66
+ "transpose": 50,
67
+ }
68
+
69
+
70
+ def _pretty_render_arg(arg, prec: int | None = None) -> str:
71
+ if isinstance(arg, Expr):
72
+ return _pretty_expr(arg, prec or 0)
73
+ if isinstance(arg, FieldRef):
74
+ if arg.name is None:
75
+ return f"{arg.role}"
76
+ return f"{arg.role}:{arg.name}"
77
+ if isinstance(arg, ParamRef):
78
+ return "param"
79
+ return repr(arg)
80
+
81
+
82
+ def _pretty_wrap(text: str, prec: int, parent_prec: int) -> str:
83
+ if prec < parent_prec:
84
+ return f"({text})"
85
+ return text
86
+
87
+
88
+ def _pretty_expr(expr: Expr, parent_prec: int = 0) -> str:
89
+ op = expr.op
90
+ args = expr.args
91
+
92
+ if op == "lit":
93
+ return repr(args[0])
94
+ if op == "getattr":
95
+ base = _pretty_render_arg(args[0], _PRECEDENCE.get("transpose", 50))
96
+ return f"{base}.{args[1]}"
97
+ if op == "value":
98
+ return f"val({_pretty_render_arg(args[0])})"
99
+ if op == "grad":
100
+ return f"grad({_pretty_render_arg(args[0])})"
101
+ if op == "sym_grad":
102
+ return f"sym_grad({_pretty_render_arg(args[0])})"
103
+ if op == "neg":
104
+ inner = _pretty_render_arg(args[0], _PRECEDENCE["neg"])
105
+ return _pretty_wrap(f"-{inner}", _PRECEDENCE["neg"], parent_prec)
106
+ if op == "transpose":
107
+ inner = _pretty_render_arg(args[0], _PRECEDENCE["transpose"])
108
+ return _pretty_wrap(f"{inner}.T", _PRECEDENCE["transpose"], parent_prec)
109
+ if op == "pow":
110
+ base = _pretty_render_arg(args[0], _PRECEDENCE["pow"])
111
+ exp = _pretty_render_arg(args[1], _PRECEDENCE["pow"] + 1)
112
+ return _pretty_wrap(f"{base}**{exp}", _PRECEDENCE["pow"], parent_prec)
113
+ if op in {"add", "sub", "mul", "matmul", "dot", "sdot", "ddot"}:
114
+ left = _pretty_render_arg(args[0], _PRECEDENCE[op])
115
+ right = _pretty_render_arg(args[1], _PRECEDENCE[op] + 1)
116
+ symbol = {
117
+ "add": "+",
118
+ "sub": "-",
119
+ "mul": "*",
120
+ "matmul": "@",
121
+ "inner": "|",
122
+ "dot": "dot",
123
+ "sdot": "sdot",
124
+ "ddot": "ddot",
125
+ }[op]
126
+ if symbol in {"dot", "sdot", "ddot"}:
127
+ text = f"{symbol}({left}, {right})"
128
+ else:
129
+ text = f"{left} {symbol} {right}"
130
+ return _pretty_wrap(text, _PRECEDENCE[op], parent_prec)
131
+ if op == "inner":
132
+ return f"inner({_pretty_render_arg(args[0])}, {_pretty_render_arg(args[1])})"
133
+ if op in {"action", "gaction"}:
134
+ return f"{op}({_pretty_render_arg(args[0])}, {_pretty_render_arg(args[1])})"
135
+ if op == "matmul_std":
136
+ return f"matmul_std({_pretty_render_arg(args[0])}, {_pretty_render_arg(args[1])})"
137
+ if op == "outer":
138
+ return f"outer({_pretty_render_arg(args[0])}, {_pretty_render_arg(args[1])})"
139
+ if op in {
140
+ "eye",
141
+ "det",
142
+ "inv",
143
+ "log",
144
+ "surface_normal",
145
+ "surface_measure",
146
+ "volume_measure",
147
+ "transpose_last2",
148
+ "einsum",
149
+ }:
150
+ rendered = ", ".join(_pretty_render_arg(arg) for arg in args)
151
+ return f"{op}({rendered})"
152
+ rendered = ", ".join(_pretty_render_arg(arg) for arg in args)
153
+ return f"{op}({rendered})"
154
+
155
+
156
+ def _as_expr(obj) -> Expr | FieldRef | ParamRef:
157
+ """Normalize inputs into Expr/FieldRef/ParamRef nodes."""
158
+ if isinstance(obj, Expr):
159
+ return obj
160
+ if isinstance(obj, FieldRef):
161
+ return obj
162
+ if isinstance(obj, ParamRef):
163
+ return obj
164
+ if isinstance(obj, (int, float, bool, str)):
165
+ return Expr("lit", obj)
166
+ if isinstance(obj, np.generic):
167
+ return Expr("lit", obj.item())
168
+ if isinstance(obj, tuple):
169
+ try:
170
+ hash(obj)
171
+ except TypeError as exc:
172
+ raise TypeError(
173
+ "Expr tuple literal must be hashable; use only immutable items."
174
+ ) from exc
175
+ return Expr("lit", obj)
176
+ raise TypeError(
177
+ "Expr literal must be a scalar or hashable tuple. "
178
+ "Arrays are not allowed; pass them via params (ParamRef/params.xxx)."
179
+ )
180
+
181
+
182
+ @dataclass(frozen=True, slots=True, init=False)
12
183
  class Expr:
13
- """Expression tree node evaluated against a FormContext."""
184
+ """Expression tree node evaluated against a FormContext.
185
+
186
+ Compile flow (recommended):
187
+ - build an Expr via operators/refs
188
+ - compile_* builds an EvalPlan (postorder nodes + index)
189
+ - eval_with_plan(plan, ctx, params, u_elem) evaluates per element
190
+
191
+ Expr.eval is a debug/single-shot path that creates a plan on demand.
192
+ """
193
+
194
+ op: OpName
195
+ args: tuple[Any, ...]
14
196
 
15
- def __init__(self, op: str, *args):
197
+ def __init__(self, op: OpName, *args):
198
+ if op not in _OP_NAMES:
199
+ raise ValueError(f"Unknown Expr op: {op!r}")
16
200
  object.__setattr__(self, "op", op)
17
201
  object.__setattr__(self, "args", args)
18
202
 
19
203
  def eval(self, ctx, params=None, u_elem=None):
204
+ """Evaluate the expression against a context (debug/single-shot path)."""
20
205
  return _eval_expr(self, ctx, params, u_elem=u_elem)
21
206
 
207
+ def children(self) -> tuple[Any, ...]:
208
+ """Return direct child nodes (Expr/FieldRef/ParamRef) for traversal."""
209
+ return tuple(arg for arg in self.args if isinstance(arg, (Expr, FieldRef, ParamRef)))
210
+
211
+ def walk(self) -> Iterator[Any]:
212
+ """Depth-first walk over nodes, including leaf FieldRef/ParamRef."""
213
+ yield self
214
+ for child in self.children():
215
+ if isinstance(child, Expr):
216
+ yield from child.walk()
217
+ else:
218
+ yield child
219
+
220
+ def postorder(self) -> Iterator[Any]:
221
+ """Postorder walk over nodes, including leaf FieldRef/ParamRef."""
222
+ for child in self.children():
223
+ if isinstance(child, Expr):
224
+ yield from child.postorder()
225
+ else:
226
+ yield child
227
+ yield self
228
+
229
+ def postorder_expr(self) -> Iterator["Expr"]:
230
+ """Postorder walk over Expr nodes only (for eval planning)."""
231
+ for arg in self.args:
232
+ if isinstance(arg, Expr):
233
+ yield from arg.postorder_expr()
234
+ yield self
235
+
22
236
  def _binop(self, other, op):
23
237
  return Expr(op, self, _as_expr(other))
24
238
 
25
239
  def __add__(self, other):
240
+ """Add expressions: `a + b`."""
26
241
  return self._binop(other, "add")
27
242
 
28
243
  def __radd__(self, other):
29
- return _as_expr(other)._binop(self, "add")
244
+ """Right-add expressions: `1 + expr`."""
245
+ return Expr("add", _as_expr(other), self)
30
246
 
31
247
  def __sub__(self, other):
248
+ """Subtract expressions: `a - b`."""
32
249
  return self._binop(other, "sub")
33
250
 
34
251
  def __rsub__(self, other):
35
- return _as_expr(other)._binop(self, "sub")
252
+ """Right-subtract expressions: `1 - expr`."""
253
+ return Expr("sub", _as_expr(other), self)
36
254
 
37
255
  def __mul__(self, other):
256
+ """Multiply expressions: `a * b`."""
38
257
  return self._binop(other, "mul")
39
258
 
40
259
  def __rmul__(self, other):
41
- return _as_expr(other)._binop(self, "mul")
260
+ """Right-multiply expressions: `2 * expr`."""
261
+ return Expr("mul", _as_expr(other), self)
42
262
 
43
263
  def __matmul__(self, other):
264
+ """Matrix product: `a @ b` (FEM-specific contraction semantics)."""
44
265
  return self._binop(other, "matmul")
45
266
 
46
267
  def __rmatmul__(self, other):
47
- return _as_expr(other)._binop(self, "matmul")
268
+ """Right-matmul: `A @ expr`."""
269
+ return Expr("matmul", _as_expr(other), self)
48
270
 
49
271
  def __or__(self, other):
50
- return self._binop(other, "inner")
272
+ """Tensor inner product: `a | b` (use .val/.grad for FieldRef)."""
273
+ if isinstance(other, FieldRef):
274
+ raise TypeError("FieldRef | FieldRef is not supported; use outer(test, trial).")
275
+ return Expr("inner", self, _as_expr(other))
51
276
 
52
277
  def __ror__(self, other):
53
- return _as_expr(other)._binop(self, "inner")
278
+ """Tensor inner product: `a | b` (use .val/.grad for FieldRef)."""
279
+ if isinstance(other, FieldRef):
280
+ raise TypeError("FieldRef | FieldRef is not supported; use outer(test, trial).")
281
+ return Expr("inner", _as_expr(other), self)
54
282
 
55
283
  def __pow__(self, power, modulo=None):
284
+ """Power: `a ** p` (no modulo support)."""
56
285
  if modulo is not None:
57
286
  raise ValueError("modulo is not supported for Expr exponentiation.")
58
287
  return Expr("pow", self, _as_expr(power))
59
288
 
60
289
  def __neg__(self):
290
+ """Unary negation: `-expr`."""
61
291
  return Expr("neg", self)
62
292
 
63
293
  @property
64
294
  def T(self):
295
+ """Transpose view: `expr.T`."""
65
296
  return Expr("transpose", self)
66
297
 
298
+ def __repr__(self) -> str:
299
+ return self.pretty()
67
300
 
68
- @dataclass(frozen=True)
69
- class FieldRef(Expr):
301
+ def pretty(self) -> str:
302
+ return _pretty_expr(self)
303
+
304
+
305
+ @dataclass(frozen=True, slots=True)
306
+ class FieldRef:
70
307
  """Symbolic reference to trial/test/unknown field, optionally by name."""
71
308
 
72
309
  role: str
73
310
  name: str | None = None
74
311
 
75
- def __init__(self, role: str, name: str | None = None):
76
- object.__setattr__(self, "role", role)
77
- object.__setattr__(self, "name", name)
78
- super().__init__("field", role, name)
79
-
80
312
  @property
81
313
  def val(self):
82
314
  return Expr("value", self)
@@ -91,12 +323,18 @@ class FieldRef(Expr):
91
323
 
92
324
  def __mul__(self, other):
93
325
  if isinstance(other, FieldRef):
94
- return Expr("outer", self, other)
326
+ raise TypeError(
327
+ "FieldRef * FieldRef is ambiguous; use outer(v, u) (test, trial), "
328
+ "action(v, s), or dot(v, q)."
329
+ )
95
330
  return Expr("mul", Expr("value", self), _as_expr(other))
96
331
 
97
332
  def __rmul__(self, other):
98
333
  if isinstance(other, FieldRef):
99
- return Expr("outer", other, self)
334
+ raise TypeError(
335
+ "FieldRef * FieldRef is ambiguous; use outer(v, u) (test, trial), "
336
+ "action(v, s), or dot(v, q)."
337
+ )
100
338
  return Expr("mul", _as_expr(other), Expr("value", self))
101
339
 
102
340
  def __add__(self, other):
@@ -113,22 +351,23 @@ class FieldRef(Expr):
113
351
 
114
352
  def __or__(self, other):
115
353
  if isinstance(other, FieldRef):
116
- return Expr("inner", self, other)
117
- return Expr("sdot", self, _as_expr(other))
354
+ raise TypeError(
355
+ "FieldRef | FieldRef is not supported; use outer(test, trial) for basis kernels."
356
+ )
357
+ return Expr("dot", self, _as_expr(other))
118
358
 
119
359
  def __ror__(self, other):
120
360
  if isinstance(other, FieldRef):
121
- return Expr("inner", other, self)
122
- return Expr("sdot", _as_expr(other), self)
361
+ raise TypeError(
362
+ "FieldRef | FieldRef is not supported; use outer(test, trial) for basis kernels."
363
+ )
364
+ return Expr("dot", _as_expr(other), self)
123
365
 
124
366
 
125
- @dataclass(frozen=True)
126
- class ParamRef(Expr):
367
+ @dataclass(frozen=True, slots=True)
368
+ class ParamRef:
127
369
  """Symbolic reference to params passed into the kernel."""
128
370
 
129
- def __init__(self):
130
- super().__init__("param")
131
-
132
371
  def __getattr__(self, name: str):
133
372
  return Expr("getattr", self, name)
134
373
 
@@ -159,6 +398,26 @@ class Params:
159
398
  return cls(**dict(zip(keys, values)))
160
399
 
161
400
 
401
+ class _ZeroField:
402
+ """Field-like object that evaluates to zeros with the same shape."""
403
+
404
+ def __init__(self, base):
405
+ self.N = jnp.zeros_like(base.N)
406
+ self.gradN = None if getattr(base, "gradN", None) is None else jnp.zeros_like(base.gradN)
407
+ self.value_dim = int(getattr(base, "value_dim", 1))
408
+ self.basis = getattr(base, "basis", None)
409
+
410
+
411
+ class _ZeroFieldNp:
412
+ """Numpy variant of a zero-valued field (for numpy backend evaluation)."""
413
+
414
+ def __init__(self, base):
415
+ self.N = np.zeros_like(base.N)
416
+ self.gradN = None if getattr(base, "gradN", None) is None else np.zeros_like(base.gradN)
417
+ self.value_dim = int(getattr(base, "value_dim", 1))
418
+ self.basis = getattr(base, "basis", None)
419
+
420
+
162
421
  def trial_ref(name: str | None = "u") -> FieldRef:
163
422
  """Create a symbolic trial field reference."""
164
423
  return FieldRef(role="trial", name=name)
@@ -179,14 +438,34 @@ def param_ref() -> ParamRef:
179
438
  return ParamRef()
180
439
 
181
440
 
182
- def _as_expr(obj) -> Expr:
183
- if isinstance(obj, Expr):
184
- return obj
185
- return Expr("lit", obj)
441
+ def zero_ref(name: str) -> FieldRef:
442
+ """Create a zero-valued field reference (shape derived from context)."""
443
+ return FieldRef("zero", name)
186
444
 
187
445
 
188
- def _eval_field(obj: Any, ctx, params):
446
+ def _eval_field(
447
+ obj: Any,
448
+ ctx: VolumeContext | SurfaceContext,
449
+ params: ParamsLike,
450
+ ) -> FormFieldLike:
189
451
  if isinstance(obj, FieldRef):
452
+ if obj.role == "zero":
453
+ if obj.name is None:
454
+ raise ValueError("zero_ref requires a named field.")
455
+ base = None
456
+ if getattr(ctx, "test_fields", None) is not None and obj.name in ctx.test_fields:
457
+ base = ctx.test_fields[obj.name]
458
+ if base is None and getattr(ctx, "trial_fields", None) is not None and obj.name in ctx.trial_fields:
459
+ base = ctx.trial_fields[obj.name]
460
+ if base is None and getattr(ctx, "fields", None) is not None and obj.name in ctx.fields:
461
+ group = ctx.fields[obj.name]
462
+ if hasattr(group, "test"):
463
+ base = group.test
464
+ elif hasattr(group, "trial"):
465
+ base = group.trial
466
+ if base is None:
467
+ raise ValueError(f"zero_ref could not resolve field '{obj.name}'.")
468
+ return _ZeroField(base)
190
469
  if obj.name is not None:
191
470
  mixed_fields = getattr(ctx, "fields", None)
192
471
  if mixed_fields is not None and obj.name in mixed_fields:
@@ -226,25 +505,88 @@ def _eval_field(obj: Any, ctx, params):
226
505
  if obj.role == "unknown":
227
506
  return getattr(ctx, "unknown", ctx.trial)
228
507
  raise ValueError(f"Unknown field role: {obj.role}")
229
- if isinstance(obj, Expr):
230
- val = obj.eval(ctx, params)
231
- if hasattr(val, "N"):
232
- return val
233
508
  raise TypeError("Expected a field reference for this operator.")
234
509
 
235
510
 
236
- def _eval_value(obj: Any, ctx, params, u_elem=None):
511
+ def _eval_field_np(
512
+ obj: Any,
513
+ ctx: VolumeContext | SurfaceContext,
514
+ params: ParamsLike,
515
+ ) -> FormFieldLike:
237
516
  if isinstance(obj, FieldRef):
238
- field = _eval_field(obj, ctx, params)
517
+ if obj.role == "zero":
518
+ if obj.name is None:
519
+ raise ValueError("zero_ref requires a named field.")
520
+ base = None
521
+ if getattr(ctx, "test_fields", None) is not None and obj.name in ctx.test_fields:
522
+ base = ctx.test_fields[obj.name]
523
+ if base is None and getattr(ctx, "trial_fields", None) is not None and obj.name in ctx.trial_fields:
524
+ base = ctx.trial_fields[obj.name]
525
+ if base is None and getattr(ctx, "fields", None) is not None and obj.name in ctx.fields:
526
+ group = ctx.fields[obj.name]
527
+ if hasattr(group, "test"):
528
+ base = group.test
529
+ elif hasattr(group, "trial"):
530
+ base = group.trial
531
+ if base is None:
532
+ raise ValueError(f"zero_ref could not resolve field '{obj.name}'.")
533
+ return _ZeroFieldNp(base)
534
+ if obj.name is not None:
535
+ mixed_fields = getattr(ctx, "fields", None)
536
+ if mixed_fields is not None and obj.name in mixed_fields:
537
+ group = mixed_fields[obj.name]
538
+ if hasattr(group, "trial") and obj.role == "trial":
539
+ return group.trial
540
+ if hasattr(group, "test") and obj.role == "test":
541
+ return group.test
542
+ if hasattr(group, "unknown") and obj.role == "unknown":
543
+ return group.unknown if group.unknown is not None else group.trial
544
+ if obj.role == "trial" and getattr(ctx, "trial_fields", None) is not None:
545
+ if obj.name in ctx.trial_fields:
546
+ return ctx.trial_fields[obj.name]
547
+ if obj.role == "test" and getattr(ctx, "test_fields", None) is not None:
548
+ if obj.name in ctx.test_fields:
549
+ return ctx.test_fields[obj.name]
550
+ if obj.role == "unknown" and getattr(ctx, "unknown_fields", None) is not None:
551
+ if obj.name in ctx.unknown_fields:
552
+ return ctx.unknown_fields[obj.name]
553
+ fields = getattr(ctx, "fields", None)
554
+ if fields is not None and obj.name in fields:
555
+ group = fields[obj.name]
556
+ if isinstance(group, dict):
557
+ if obj.role in group:
558
+ return group[obj.role]
559
+ if "field" in group:
560
+ return group["field"]
561
+ return group
562
+ if obj.role == "trial":
563
+ return ctx.trial
564
+ if obj.role == "test":
565
+ if hasattr(ctx, "test"):
566
+ return ctx.test
567
+ if hasattr(ctx, "v"):
568
+ return ctx.v
569
+ raise ValueError("Surface context is missing test field.")
239
570
  if obj.role == "unknown":
240
- return _eval_unknown_value(obj, field, u_elem)
241
- return field.N
242
- if isinstance(obj, Expr):
243
- return obj.eval(ctx, params, u_elem=u_elem)
244
- return obj
571
+ return getattr(ctx, "unknown", ctx.trial)
572
+ raise ValueError(f"Unknown field role: {obj.role}")
573
+ raise TypeError("Expected a field reference for this operator.")
574
+
245
575
 
576
+ # def _eval_value(obj: Any, ctx, params, u_elem=None):
577
+ # if isinstance(obj, FieldRef):
578
+ # field = _eval_field(obj, ctx, params)
579
+ # if obj.role == "unknown":
580
+ # return _eval_unknown_value(obj, field, u_elem)
581
+ # return field.N
582
+ # if isinstance(obj, ParamRef):
583
+ # return params
584
+ # if isinstance(obj, Expr):
585
+ # return obj.eval(ctx, params, u_elem=u_elem)
586
+ # return obj
246
587
 
247
- def _extract_unknown_elem(field_ref: FieldRef, u_elem):
588
+
589
+ def _extract_unknown_elem(field_ref: FieldRef, u_elem: UElement):
248
590
  if u_elem is None:
249
591
  raise ValueError("u_elem is required to evaluate unknown field value.")
250
592
  if isinstance(u_elem, dict):
@@ -255,7 +597,27 @@ def _extract_unknown_elem(field_ref: FieldRef, u_elem):
255
597
  return u_elem
256
598
 
257
599
 
258
- def _eval_unknown_value(field_ref: FieldRef, field, u_elem):
600
+ def _basis_outer(test: FieldRef, trial: FieldRef, ctx, params):
601
+ v_field = _eval_field(test, ctx, params)
602
+ u_field = _eval_field(trial, ctx, params)
603
+ if getattr(v_field, "value_dim", 1) != 1 or getattr(u_field, "value_dim", 1) != 1:
604
+ raise ValueError(
605
+ "inner/outer is only defined for scalar fields; use dot/action/einsum for vector/tensor cases."
606
+ )
607
+ return jnp.einsum("qi,qj->qij", v_field.N, u_field.N)
608
+
609
+
610
+ def _basis_outer_np(test: FieldRef, trial: FieldRef, ctx, params):
611
+ v_field = _eval_field_np(test, ctx, params)
612
+ u_field = _eval_field_np(trial, ctx, params)
613
+ if getattr(v_field, "value_dim", 1) != 1 or getattr(u_field, "value_dim", 1) != 1:
614
+ raise ValueError(
615
+ "inner/outer is only defined for scalar fields; use dot/action/einsum for vector/tensor cases."
616
+ )
617
+ return np.einsum("qi,qj->qij", v_field.N, u_field.N)
618
+
619
+
620
+ def _eval_unknown_value(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
259
621
  u_local = _extract_unknown_elem(field_ref, u_elem)
260
622
  value_dim = int(getattr(field, "value_dim", 1))
261
623
  if value_dim == 1:
@@ -264,7 +626,22 @@ def _eval_unknown_value(field_ref: FieldRef, field, u_elem):
264
626
  return jnp.einsum("qa,ai->qi", field.N, u_nodes)
265
627
 
266
628
 
267
- def _eval_unknown_grad(field_ref: FieldRef, field, u_elem):
629
+ def _eval_unknown_value_np(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
630
+ u_local = _extract_unknown_elem(field_ref, u_elem)
631
+ value_dim = int(getattr(field, "value_dim", 1))
632
+ u_arr = np.asarray(u_local)
633
+ if value_dim == 1:
634
+ if u_arr.ndim == 2:
635
+ return np.einsum("qa,ab->qb", field.N, u_arr)
636
+ return np.einsum("qa,a->q", field.N, u_arr)
637
+ if u_arr.ndim == 2:
638
+ u_nodes = u_arr.reshape((-1, value_dim, u_arr.shape[1]))
639
+ return np.einsum("qa,aib->qib", field.N, u_nodes)
640
+ u_nodes = u_arr.reshape((-1, value_dim))
641
+ return np.einsum("qa,ai->qi", field.N, u_nodes)
642
+
643
+
644
+ def _eval_unknown_grad(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
268
645
  u_local = _extract_unknown_elem(field_ref, u_elem)
269
646
  if u_local is None:
270
647
  raise ValueError("u_elem is required to evaluate unknown field gradient.")
@@ -275,6 +652,88 @@ def _eval_unknown_grad(field_ref: FieldRef, field, u_elem):
275
652
  return jnp.einsum("qaj,ai->qij", field.gradN, u_nodes)
276
653
 
277
654
 
655
+ def _eval_unknown_grad_np(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
656
+ u_local = _extract_unknown_elem(field_ref, u_elem)
657
+ if u_local is None:
658
+ raise ValueError("u_elem is required to evaluate unknown field gradient.")
659
+ value_dim = int(getattr(field, "value_dim", 1))
660
+ u_arr = np.asarray(u_local)
661
+ if value_dim == 1:
662
+ if u_arr.ndim == 2:
663
+ return np.einsum("qaj,ab->qjb", field.gradN, u_arr)
664
+ return np.einsum("qaj,a->qj", field.gradN, u_arr)
665
+ if u_arr.ndim == 2:
666
+ u_nodes = u_arr.reshape((-1, value_dim, u_arr.shape[1]))
667
+ return np.einsum("qaj,aib->qijb", field.gradN, u_nodes)
668
+ u_nodes = u_arr.reshape((-1, value_dim))
669
+ return np.einsum("qaj,ai->qij", field.gradN, u_nodes)
670
+
671
+
672
+ def _vector_load_form_np(field: Any, load_vec: Any) -> np.ndarray:
673
+ lv = np.asarray(load_vec)
674
+ if lv.ndim == 1:
675
+ lv = lv[None, :]
676
+ elif lv.ndim not in (2, 3):
677
+ raise ValueError("load_vec must be shape (dim,), (n_q, dim), or (n_q, dim, batch)")
678
+ if lv.shape[0] == 1:
679
+ lv = np.broadcast_to(lv, (field.N.shape[0], lv.shape[1]))
680
+ elif lv.shape[0] != field.N.shape[0]:
681
+ raise ValueError("load_vec must be shape (dim,) or (n_q, dim)")
682
+ if lv.ndim == 3:
683
+ load = field.N[..., None, None] * lv[:, None, :, :]
684
+ return load.reshape(load.shape[0], -1, load.shape[-1])
685
+ load = field.N[..., None] * lv[:, None, :]
686
+ return load.reshape(load.shape[0], -1)
687
+
688
+
689
+ def _sym_grad_np(field) -> np.ndarray:
690
+ gradN = np.asarray(field.gradN)
691
+ dofs = int(getattr(field.basis, "dofs_per_node", 3))
692
+ n_q, n_nodes, _ = gradN.shape
693
+ n_dofs = dofs * n_nodes
694
+ B = np.zeros((n_q, 6, n_dofs), dtype=gradN.dtype)
695
+ for a in range(n_nodes):
696
+ col = dofs * a
697
+ dNdx = gradN[:, a, 0]
698
+ dNdy = gradN[:, a, 1]
699
+ dNdz = gradN[:, a, 2]
700
+ B[:, 0, col + 0] = dNdx
701
+ B[:, 1, col + 1] = dNdy
702
+ B[:, 2, col + 2] = dNdz
703
+ B[:, 3, col + 0] = dNdy
704
+ B[:, 3, col + 1] = dNdx
705
+ B[:, 4, col + 1] = dNdz
706
+ B[:, 4, col + 2] = dNdy
707
+ B[:, 5, col + 0] = dNdz
708
+ B[:, 5, col + 2] = dNdx
709
+ return B
710
+
711
+
712
+ def _sym_grad_u_np(field, u_elem: Any) -> np.ndarray:
713
+ B = _sym_grad_np(field)
714
+ u_arr = np.asarray(u_elem)
715
+ if u_arr.ndim == 2:
716
+ return np.einsum("qik,kb->qib", B, u_arr)
717
+ return np.einsum("qik,k->qi", B, u_arr)
718
+
719
+
720
+ def _ddot_np(a: Any, b: Any, c: Any | None = None) -> np.ndarray:
721
+ if c is None:
722
+ return np.einsum("...ij,...ij->...", a, b)
723
+ a_t = np.swapaxes(a, -1, -2)
724
+ return np.einsum("...ik,kl,...lm->...im", a_t, b, c)
725
+
726
+
727
+ def _dot_np(a: Any, b: Any) -> np.ndarray:
728
+ if hasattr(a, "N") and getattr(a, "value_dim", None) is not None:
729
+ return _vector_load_form_np(a, b)
730
+ return np.matmul(a, b)
731
+
732
+
733
+ def _transpose_last2_np(a: Any) -> np.ndarray:
734
+ return np.swapaxes(a, -1, -2)
735
+
736
+
278
737
  def grad(field) -> Expr:
279
738
  """Return basis gradients for a scalar or vector FormField."""
280
739
  return Expr("grad", _as_expr(field))
@@ -285,6 +744,15 @@ def sym_grad(field) -> Expr:
285
744
  return Expr("sym_grad", _as_expr(field))
286
745
 
287
746
 
747
+ def outer(a, b) -> Expr:
748
+ """Outer product of scalar fields: `outer(v, u)` (test, trial)."""
749
+ if not isinstance(a, FieldRef) or not isinstance(b, FieldRef):
750
+ raise TypeError("outer expects FieldRef operands.")
751
+ if a.role != "test" or b.role != "trial":
752
+ raise TypeError("outer expects outer(test, trial).")
753
+ return Expr("outer", a, b)
754
+
755
+
288
756
  def dot(a, b) -> Expr:
289
757
  """Dot product or vector load helper."""
290
758
  return Expr("dot", _as_expr(a), _as_expr(b))
@@ -303,7 +771,7 @@ def ddot(a, b, c=None) -> Expr:
303
771
 
304
772
 
305
773
  def inner(a, b) -> Expr:
306
- """Inner product over the last axis."""
774
+ """Inner product over the last axis (tensor-level)."""
307
775
  return Expr("inner", _as_expr(a), _as_expr(b))
308
776
 
309
777
 
@@ -363,7 +831,12 @@ def transpose_last2(a) -> Expr:
363
831
 
364
832
 
365
833
  def matmul(a, b) -> Expr:
366
- """Matrix product with standard semantics (no special 3D contraction)."""
834
+ """FEM-specific batched contraction (same semantics as `@`)."""
835
+ return Expr("matmul", _as_expr(a), _as_expr(b))
836
+
837
+
838
+ def matmul_std(a, b) -> Expr:
839
+ """Standard matrix product (`jnp.matmul` semantics)."""
367
840
  return Expr("matmul_std", _as_expr(a), _as_expr(b))
368
841
 
369
842
 
@@ -374,9 +847,78 @@ def einsum(subscripts: str, *args) -> Expr:
374
847
 
375
848
  def _call_user(fn, *args, params):
376
849
  try:
850
+ sig = inspect.signature(fn)
851
+ except (TypeError, ValueError):
377
852
  return fn(*args, params)
378
- except TypeError:
379
- return fn(*args)
853
+
854
+ params_list = list(sig.parameters.values())
855
+ if any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in params_list):
856
+ return fn(*args, params)
857
+ positional = [
858
+ p
859
+ for p in params_list
860
+ if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
861
+ ]
862
+ max_positional = len(positional)
863
+ if len(args) + 1 <= max_positional:
864
+ return fn(*args, params)
865
+ return fn(*args)
866
+
867
+
868
+ @dataclass(frozen=True)
869
+ class KernelSpec:
870
+ kind: str
871
+ domain: str
872
+
873
+
874
+ class TaggedKernel:
875
+ def __init__(self, fn, spec: KernelSpec):
876
+ self._fn = fn
877
+ self._ff_spec = spec
878
+ self._ff_kind = spec.kind
879
+ self._ff_domain = spec.domain
880
+ update_wrapper(self, fn)
881
+ self.__wrapped__ = fn
882
+
883
+ def __call__(self, *args, **kwargs):
884
+ return self._fn(*args, **kwargs)
885
+
886
+ def __repr__(self) -> str:
887
+ return f"TaggedKernel(kind={self._ff_kind!r}, domain={self._ff_domain!r})"
888
+
889
+ @property
890
+ def spec(self) -> KernelSpec:
891
+ return self._ff_spec
892
+
893
+ @property
894
+ def kind(self) -> str:
895
+ return self._ff_kind
896
+
897
+ @property
898
+ def domain(self) -> str:
899
+ return self._ff_domain
900
+
901
+ def __hash__(self) -> int:
902
+ return hash(self._fn)
903
+
904
+
905
+ def _tag_form(fn, *, kind: str, domain: str):
906
+ spec = KernelSpec(kind=kind, domain=domain)
907
+ fn._ff_spec = spec
908
+ fn._ff_kind = kind
909
+ fn._ff_domain = domain
910
+ return fn
911
+
912
+
913
+ def kernel(*, kind: str, domain: str = "volume"):
914
+ """
915
+ Decorator to tag raw kernels with kind/domain metadata for assembly inference.
916
+ """
917
+ def _deco(fn):
918
+ spec = KernelSpec(kind=kind, domain=domain)
919
+ return TaggedKernel(fn, spec)
920
+
921
+ return _deco
380
922
 
381
923
 
382
924
  def compile_bilinear(fn):
@@ -387,20 +929,27 @@ def compile_bilinear(fn):
387
929
  u = trial_ref()
388
930
  v = test_ref()
389
931
  p = param_ref()
390
- try:
391
- expr = fn(u, v, p)
392
- except TypeError:
393
- expr = fn(u, v)
932
+ expr = _call_user(fn, u, v, params=p)
933
+ expr = _as_expr(expr)
934
+ if not isinstance(expr, Expr):
935
+ raise TypeError("Bilinear form must return an Expr.")
394
936
 
395
- includes_measure = _expr_contains(expr, "volume_measure")
396
- if not includes_measure:
937
+ volume_count = _count_op(expr, "volume_measure")
938
+ surface_count = _count_op(expr, "surface_measure")
939
+ if volume_count == 0:
397
940
  raise ValueError("Volume bilinear form must include dOmega().")
941
+ if volume_count > 1:
942
+ raise ValueError("Volume bilinear form must include dOmega() exactly once.")
943
+ if surface_count > 0:
944
+ raise ValueError("Volume bilinear form must not include ds().")
945
+
946
+ plan = make_eval_plan(expr)
398
947
 
399
948
  def _form(ctx, params):
400
- return _as_expr(expr).eval(ctx, params)
949
+ return eval_with_plan(plan, ctx, params)
401
950
 
402
- _form._includes_measure = includes_measure
403
- return _form
951
+ _form._includes_measure = True
952
+ return _tag_form(_form, kind="bilinear", domain="volume")
404
953
 
405
954
 
406
955
  def compile_linear(fn):
@@ -410,20 +959,27 @@ def compile_linear(fn):
410
959
  else:
411
960
  v = test_ref()
412
961
  p = param_ref()
413
- try:
414
- expr = fn(v, p)
415
- except TypeError:
416
- expr = fn(v)
962
+ expr = _call_user(fn, v, params=p)
963
+ expr = _as_expr(expr)
964
+ if not isinstance(expr, Expr):
965
+ raise TypeError("Linear form must return an Expr.")
417
966
 
418
- includes_measure = _expr_contains(expr, "volume_measure")
419
- if not includes_measure:
967
+ volume_count = _count_op(expr, "volume_measure")
968
+ surface_count = _count_op(expr, "surface_measure")
969
+ if volume_count == 0:
420
970
  raise ValueError("Volume linear form must include dOmega().")
971
+ if volume_count > 1:
972
+ raise ValueError("Volume linear form must include dOmega() exactly once.")
973
+ if surface_count > 0:
974
+ raise ValueError("Volume linear form must not include ds().")
975
+
976
+ plan = make_eval_plan(expr)
421
977
 
422
978
  def _form(ctx, params):
423
- return _as_expr(expr).eval(ctx, params)
979
+ return eval_with_plan(plan, ctx, params)
424
980
 
425
- _form._includes_measure = includes_measure
426
- return _form
981
+ _form._includes_measure = True
982
+ return _tag_form(_form, kind="linear", domain="volume")
427
983
 
428
984
 
429
985
  def _expr_contains(expr: Expr, op: str) -> bool:
@@ -434,6 +990,630 @@ def _expr_contains(expr: Expr, op: str) -> bool:
434
990
  return any(_expr_contains(arg, op) for arg in expr.args if isinstance(arg, Expr))
435
991
 
436
992
 
993
def _count_op(expr: Expr, op: str) -> int:
    """
    Count how many nodes of the expression tree have ``node.op == op``.

    Non-``Expr`` inputs count as zero; only ``Expr`` children are traversed.
    """
    if not isinstance(expr, Expr):
        return 0
    here = int(expr.op == op)
    below = sum(_count_op(child, op) for child in expr.args if isinstance(child, Expr))
    return here + below
1001
+
1002
+
1003
@dataclass(frozen=True, slots=True)
class EvalPlan:
    """
    Pre-validated, flattened form of an ``Expr`` for repeated evaluation.

    Built by ``make_eval_plan`` and consumed by ``eval_with_plan`` /
    ``eval_with_plan_numpy``. Because ``index`` is a ``dict``, the dataclass-
    generated ``__hash__`` raises ``TypeError``, so instances cannot be hashed.
    """

    # Root expression; the evaluators return the value computed for this node.
    expr: Expr
    # Nodes from ``expr.postorder_expr()``: operands precede the nodes that use them.
    nodes: tuple[Expr, ...]
    # ``id(node)`` -> position of that node's first occurrence in ``nodes``.
    index: dict[int, int]
1008
+
1009
+
1010
+ def _validate_eval_plan(nodes: tuple[Expr, ...]) -> None:
1011
+ fieldref_arg_ops = {
1012
+ "value",
1013
+ "grad",
1014
+ "sym_grad",
1015
+ "dot",
1016
+ "sdot",
1017
+ "action",
1018
+ "gaction",
1019
+ "outer",
1020
+ }
1021
+ for node in nodes:
1022
+ op = node.op
1023
+ args = node.args
1024
+ if op not in fieldref_arg_ops:
1025
+ if any(isinstance(arg, FieldRef) for arg in args):
1026
+ raise TypeError(f"{op} cannot take FieldRef directly; wrap with .val/.grad/.sym_grad.")
1027
+ if op in {"value", "grad", "sym_grad"}:
1028
+ if len(args) != 1 or not isinstance(args[0], FieldRef):
1029
+ raise TypeError(f"{op} expects FieldRef.")
1030
+ elif op in {"dot", "sdot"}:
1031
+ if len(args) != 2:
1032
+ raise TypeError(f"{op} expects two arguments.")
1033
+ if any(isinstance(arg, FieldRef) for arg in args):
1034
+ if not isinstance(args[0], FieldRef):
1035
+ raise TypeError(f"{op} expects FieldRef as the first argument.")
1036
+ if isinstance(args[1], FieldRef):
1037
+ raise TypeError(f"{op} expects an expression for the second argument; use .val/.grad.")
1038
+ elif op in {"action", "gaction"}:
1039
+ if len(args) != 2 or not isinstance(args[0], FieldRef):
1040
+ raise TypeError(f"{op} expects FieldRef as the first argument.")
1041
+ if op == "action" and isinstance(args[1], FieldRef):
1042
+ raise ValueError("action expects a scalar expression; use u.val for unknowns.")
1043
+ if op == "gaction" and isinstance(args[1], FieldRef):
1044
+ raise TypeError("gaction expects an expression for the second argument; use .grad.")
1045
+ elif op == "outer":
1046
+ if len(args) != 2 or not all(isinstance(arg, FieldRef) for arg in args):
1047
+ raise TypeError("outer expects two FieldRef operands.")
1048
+ if args[0].role != "test" or args[1].role != "trial":
1049
+ raise TypeError("outer expects outer(test, trial).")
1050
+
1051
+
1052
def make_eval_plan(expr: Expr) -> EvalPlan:
    """
    Flatten ``expr`` into a validated ``EvalPlan``.

    Nodes are taken from ``expr.postorder_expr()`` and checked by
    ``_validate_eval_plan``. The index maps each node object's ``id`` to the
    position of its first occurrence, so shared sub-expressions resolve to
    the earliest slot.
    """
    ordered = tuple(expr.postorder_expr())
    _validate_eval_plan(ordered)

    first_slot: dict[int, int] = {}
    for pos, node in enumerate(ordered):
        key = id(node)
        if key not in first_slot:
            first_slot[key] = pos

    return EvalPlan(expr=expr, nodes=ordered, index=first_slot)
1059
+
1060
+
1061
def eval_with_plan(
    plan: EvalPlan,
    ctx: VolumeContext | SurfaceContext,
    params: ParamsLike,
    u_elem: UElement | None = None,
):
    """
    Evaluate an ``EvalPlan`` with JAX (``jnp``) operations.

    Nodes are visited in the order stored in ``plan.nodes`` (post-order), so
    every operand of a node has already been computed when the node is reached.
    Results are kept in ``vals``; ``plan.index`` maps ``id(node)`` to its slot.

    Parameters
    ----------
    plan : EvalPlan
        Plan produced by ``make_eval_plan``.
    ctx : VolumeContext | SurfaceContext
        Element context supplying quadrature weights, Jacobians, fields and,
        for surfaces, normals.
    params : ParamsLike
        User parameters; returned wherever a ``ParamRef`` operand appears.
    u_elem : UElement | None
        Element-local unknowns; required when the expression evaluates
        unknown fields.

    Returns
    -------
    The value computed for ``plan.expr``.

    Raises
    ------
    TypeError, ValueError
        For invalid operands, missing context data, or an unknown op.
    """
    nodes = plan.nodes
    index = plan.index
    vals: list[Any] = [None] * len(nodes)

    def get(obj):
        # Resolve an operand: Expr -> already-computed value, ParamRef -> params,
        # anything else is treated as a literal.
        if isinstance(obj, Expr):
            return vals[index[id(obj)]]
        if isinstance(obj, FieldRef):
            raise TypeError(
                "FieldRef must be wrapped with .val/.grad/.sym_grad or used as the first arg of dot/action."
            )
        if isinstance(obj, ParamRef):
            return params
        return obj

    for i, node in enumerate(nodes):
        op = node.op
        args = node.args

        if op == "lit":
            vals[i] = args[0]
            continue
        if op == "getattr":
            base = get(args[0])
            name = args[1]
            if isinstance(base, dict):
                vals[i] = base[name]
            else:
                vals[i] = getattr(base, name)
            continue
        if op == "value":
            ref = args[0]
            assert isinstance(ref, FieldRef)
            field = _eval_field(ref, ctx, params)
            if ref.role == "unknown":
                vals[i] = _eval_unknown_value(ref, field, u_elem)
            else:
                vals[i] = field.N
            continue
        if op == "grad":
            ref = args[0]
            assert isinstance(ref, FieldRef)
            field = _eval_field(ref, ctx, params)
            if ref.role == "unknown":
                vals[i] = _eval_unknown_grad(ref, field, u_elem)
            else:
                vals[i] = field.gradN
            continue
        if op == "pow":
            base = get(args[0])
            exp = get(args[1])
            vals[i] = base**exp
            continue
        if op == "eye":
            vals[i] = jnp.eye(int(args[0]))
            continue
        if op == "det":
            vals[i] = jnp.linalg.det(get(args[0]))
            continue
        if op == "inv":
            vals[i] = jnp.linalg.inv(get(args[0]))
            continue
        if op == "transpose":
            vals[i] = jnp.swapaxes(get(args[0]), -1, -2)
            continue
        if op == "log":
            vals[i] = jnp.log(get(args[0]))
            continue
        if op == "surface_normal":
            normal = getattr(ctx, "normal", None)
            if normal is None:
                raise ValueError("surface normal is not available in context")
            vals[i] = normal
            continue
        if op == "surface_measure":
            if not hasattr(ctx, "w") or not hasattr(ctx, "detJ"):
                raise TypeError("surface measure requires SurfaceContext.")
            vals[i] = ctx.w * ctx.detJ
            continue
        if op == "volume_measure":
            if not hasattr(ctx, "w") or not hasattr(ctx, "test"):
                raise TypeError("volume measure requires VolumeContext.")
            vals[i] = ctx.w * ctx.test.detJ
            continue
        if op == "sym_grad":
            ref = args[0]
            assert isinstance(ref, FieldRef)
            field = _eval_field(ref, ctx, params)
            if ref.role == "unknown":
                if u_elem is None:
                    raise ValueError("u_elem is required to evaluate unknown sym_grad.")
                u_local = _extract_unknown_elem(ref, u_elem)
                vals[i] = _ops.sym_grad_u(field, u_local)
            else:
                vals[i] = _ops.sym_grad(field)
            continue
        if op == "outer":
            a, b = args
            if not isinstance(a, FieldRef) or not isinstance(b, FieldRef):
                raise TypeError("outer expects FieldRef operands.")
            # Operand order is (test, trial); see _validate_eval_plan.
            vals[i] = _basis_outer(a, b, ctx, params)
            continue
        if op == "add":
            vals[i] = get(args[0]) + get(args[1])
            continue
        if op == "sub":
            vals[i] = get(args[0]) - get(args[1])
            continue
        if op == "mul":
            a = get(args[0])
            b = get(args[1])
            if hasattr(a, "ndim") and hasattr(b, "ndim"):
                # Align a 1-D operand with the leading axis of the other operand
                # instead of NumPy's default trailing-axis broadcasting.
                if a.ndim == 1 and b.ndim == 2 and a.shape[0] == b.shape[0]:
                    a = a[:, None]
                elif b.ndim == 1 and a.ndim == 2 and b.shape[0] == a.shape[0]:
                    b = b[:, None]
                elif a.ndim >= 2 and b.ndim == 1 and a.shape[0] == b.shape[0]:
                    b = b.reshape((b.shape[0],) + (1,) * (a.ndim - 1))
                elif b.ndim >= 2 and a.ndim == 1 and b.shape[0] == a.shape[0]:
                    a = a.reshape((a.shape[0],) + (1,) * (b.ndim - 1))
            vals[i] = a * b
            continue
        if op == "matmul":
            a = get(args[0])
            b = get(args[1])
            if (
                hasattr(a, "ndim")
                and hasattr(b, "ndim")
                and a.ndim == 3
                and b.ndim == 3
                and a.shape[0] == b.shape[0]
                and a.shape[-1] == b.shape[-1]
            ):
                vals[i] = jnp.einsum("qia,qja->qij", a, b)
            else:
                raise TypeError(
                    "Expr '@' (matmul) is FEM-specific; use matmul_std(a, b) for standard matmul."
                )
            continue
        if op == "matmul_std":
            a = get(args[0])
            b = get(args[1])
            vals[i] = jnp.matmul(a, b)
            continue
        if op == "neg":
            vals[i] = -get(args[0])
            continue
        if op in ("dot", "sdot"):
            # dot and sdot share an implementation in this evaluator.
            ref = args[0]
            if isinstance(ref, FieldRef):
                vals[i] = _ops.dot(_eval_field(ref, ctx, params), get(args[1]))
            else:
                a = get(args[0])
                b = get(args[1])
                if (
                    hasattr(a, "ndim")
                    and hasattr(b, "ndim")
                    and a.ndim == 3
                    and b.ndim == 3
                    and a.shape[-1] == b.shape[-1]
                ):
                    vals[i] = jnp.einsum("qia,qja->qij", a, b)
                else:
                    vals[i] = jnp.matmul(a, b)
            continue
        if op == "ddot":
            if len(args) == 2:
                a = get(args[0])
                b = get(args[1])
                if (
                    hasattr(a, "ndim")
                    and hasattr(b, "ndim")
                    and a.ndim == 3
                    and b.ndim == 3
                    and a.shape[0] == b.shape[0]
                    and a.shape[1] == b.shape[1]
                ):
                    vals[i] = jnp.einsum("qik,qim->qkm", a, b)
                else:
                    vals[i] = _ops.ddot(a, b)
            else:
                vals[i] = _ops.ddot(get(args[0]), get(args[1]), get(args[2]))
            continue
        if op == "inner":
            a = get(args[0])
            b = get(args[1])
            vals[i] = jnp.einsum("...i,...i->...", a, b)
            continue
        if op == "action":
            ref = args[0]
            assert isinstance(ref, FieldRef)
            if isinstance(args[1], FieldRef):
                raise ValueError("action expects a scalar expression; use u.val for unknowns.")
            v_field = _eval_field(ref, ctx, params)
            s = get(args[1])
            value_dim = int(getattr(v_field, "value_dim", 1))
            # action maps a test field with a scalar/vector expression into nodal space.
            if value_dim == 1:
                if v_field.N.ndim != 2:
                    raise ValueError("action expects scalar test field with N shape (q, ndofs).")
                if hasattr(s, "ndim") and s.ndim not in (0, 1):
                    raise ValueError("action expects scalar s with shape (q,) or scalar.")
                if hasattr(s, "ndim") and s.ndim == 1:
                    # A per-quadrature-point scalar must align with the q axis of
                    # N (q, ndofs). Plain broadcasting would pair it with the ndofs
                    # axis and silently scale the wrong entries when q == ndofs.
                    s = s[:, None]
                vals[i] = v_field.N * s
            else:
                if hasattr(s, "ndim") and s.ndim not in (1, 2):
                    raise ValueError("action expects vector s with shape (q, dim) or (dim,).")
                vals[i] = _ops.dot(v_field, s)
            continue
        if op == "gaction":
            ref = args[0]
            assert isinstance(ref, FieldRef)
            v_field = _eval_field(ref, ctx, params)
            q = get(args[1])
            # gaction maps a flux-like expression to nodal space via test gradients.
            if v_field.gradN.ndim != 3:
                raise ValueError("gaction expects test gradient with shape (q, ndofs, dim).")
            if not hasattr(q, "ndim"):
                raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
            if q.ndim == 2:
                vals[i] = jnp.einsum("qaj,qj->qa", v_field.gradN, q)
            elif q.ndim == 3:
                if int(getattr(v_field, "value_dim", 1)) == 1:
                    raise ValueError("gaction tensor flux requires vector test field.")
                vals[i] = jnp.einsum("qij,qaj->qai", q, v_field.gradN).reshape(q.shape[0], -1)
            else:
                raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
            continue
        if op == "transpose_last2":
            vals[i] = _ops.transpose_last2(get(args[0]))
            continue
        if op == "einsum":
            subscripts = args[0]
            operands = [
                (jnp.asarray(arg) if isinstance(arg, tuple) else arg)
                for arg in (get(arg) for arg in args[1:])
            ]
            vals[i] = jnp.einsum(subscripts, *operands)
            continue

        raise ValueError(f"Unknown Expr op: {op}")

    return vals[index[id(plan.expr)]]
1328
+
1329
+
1330
def eval_with_plan_numpy(
    plan: EvalPlan,
    ctx: VolumeContext | SurfaceContext,
    params: ParamsLike,
    u_elem: UElement | None = None,
):
    """
    Evaluate an ``EvalPlan`` with NumPy operations.

    NumPy counterpart of ``eval_with_plan``. It uses ``np`` and the ``*_np``
    helpers (``_eval_field_np``, ``_dot_np``, ``_ddot_np``, ...). For the
    ``einsum`` op, an ellipsis is appended to any input term whose operand has
    more axes than the term names, so those extra trailing axes are carried
    through to the result.

    Parameters
    ----------
    plan : EvalPlan
        Plan produced by ``make_eval_plan``.
    ctx : VolumeContext | SurfaceContext
        Element context supplying quadrature weights, Jacobians, fields and,
        for surfaces, normals.
    params : ParamsLike
        User parameters; returned wherever a ``ParamRef`` operand appears.
    u_elem : UElement | None
        Element-local unknowns; required when the expression evaluates
        unknown fields.

    Returns
    -------
    The value computed for ``plan.expr``.

    Raises
    ------
    TypeError, ValueError
        For invalid operands, missing context data, or an unknown op.
    """
    nodes = plan.nodes
    index = plan.index
    vals: list[Any] = [None] * len(nodes)

    def get(obj):
        # Resolve an operand: Expr -> already-computed value, ParamRef -> params,
        # anything else is treated as a literal.
        if isinstance(obj, Expr):
            return vals[index[id(obj)]]
        if isinstance(obj, FieldRef):
            raise TypeError(
                "FieldRef must be wrapped with .val/.grad/.sym_grad or used as the first arg of dot/action."
            )
        if isinstance(obj, ParamRef):
            return params
        return obj

    for i, node in enumerate(nodes):
        op = node.op
        args = node.args

        if op == "lit":
            vals[i] = args[0]
            continue
        if op == "getattr":
            base = get(args[0])
            name = args[1]
            if isinstance(base, dict):
                vals[i] = base[name]
            else:
                vals[i] = getattr(base, name)
            continue
        if op == "value":
            ref = args[0]
            assert isinstance(ref, FieldRef)
            field = _eval_field_np(ref, ctx, params)
            if ref.role == "unknown":
                vals[i] = _eval_unknown_value_np(ref, field, u_elem)
            else:
                vals[i] = field.N
            continue
        if op == "grad":
            ref = args[0]
            assert isinstance(ref, FieldRef)
            field = _eval_field_np(ref, ctx, params)
            if ref.role == "unknown":
                vals[i] = _eval_unknown_grad_np(ref, field, u_elem)
            else:
                vals[i] = field.gradN
            continue
        if op == "pow":
            base = get(args[0])
            exp = get(args[1])
            vals[i] = base**exp
            continue
        if op == "eye":
            vals[i] = np.eye(int(args[0]))
            continue
        if op == "det":
            vals[i] = np.linalg.det(get(args[0]))
            continue
        if op == "inv":
            vals[i] = np.linalg.inv(get(args[0]))
            continue
        if op == "transpose":
            vals[i] = np.swapaxes(get(args[0]), -1, -2)
            continue
        if op == "log":
            vals[i] = np.log(get(args[0]))
            continue
        if op == "surface_normal":
            normal = getattr(ctx, "normal", None)
            if normal is None:
                raise ValueError("surface normal is not available in context")
            vals[i] = normal
            continue
        if op == "surface_measure":
            if not hasattr(ctx, "w") or not hasattr(ctx, "detJ"):
                raise TypeError("surface measure requires SurfaceContext.")
            vals[i] = ctx.w * ctx.detJ
            continue
        if op == "volume_measure":
            if not hasattr(ctx, "w") or not hasattr(ctx, "test"):
                raise TypeError("volume measure requires VolumeContext.")
            vals[i] = ctx.w * ctx.test.detJ
            continue
        if op == "sym_grad":
            ref = args[0]
            assert isinstance(ref, FieldRef)
            field = _eval_field_np(ref, ctx, params)
            if ref.role == "unknown":
                if u_elem is None:
                    raise ValueError("u_elem is required to evaluate unknown sym_grad.")
                u_local = _extract_unknown_elem(ref, u_elem)
                vals[i] = _sym_grad_u_np(field, u_local)
            else:
                vals[i] = _sym_grad_np(field)
            continue
        if op == "outer":
            a, b = args
            if not isinstance(a, FieldRef) or not isinstance(b, FieldRef):
                raise TypeError("outer expects FieldRef operands.")
            # Operand order is (test, trial); see _validate_eval_plan.
            vals[i] = _basis_outer_np(a, b, ctx, params)
            continue
        if op == "add":
            vals[i] = get(args[0]) + get(args[1])
            continue
        if op == "sub":
            vals[i] = get(args[0]) - get(args[1])
            continue
        if op == "mul":
            a = get(args[0])
            b = get(args[1])
            if hasattr(a, "ndim") and hasattr(b, "ndim"):
                # Align a 1-D operand with the leading axis of the other operand
                # instead of NumPy's default trailing-axis broadcasting.
                if a.ndim == 1 and b.ndim == 2 and a.shape[0] == b.shape[0]:
                    a = a[:, None]
                elif b.ndim == 1 and a.ndim == 2 and b.shape[0] == a.shape[0]:
                    b = b[:, None]
                elif a.ndim >= 2 and b.ndim == 1 and a.shape[0] == b.shape[0]:
                    b = b.reshape((b.shape[0],) + (1,) * (a.ndim - 1))
                elif b.ndim >= 2 and a.ndim == 1 and b.shape[0] == a.shape[0]:
                    a = a.reshape((a.shape[0],) + (1,) * (b.ndim - 1))
            vals[i] = a * b
            continue
        if op == "matmul":
            a = get(args[0])
            b = get(args[1])
            if (
                hasattr(a, "ndim")
                and hasattr(b, "ndim")
                and a.ndim == 3
                and b.ndim == 3
                and a.shape[0] == b.shape[0]
                and a.shape[-1] == b.shape[-1]
            ):
                vals[i] = np.einsum("qia,qja->qij", a, b)
            else:
                raise TypeError(
                    "Expr '@' (matmul) is FEM-specific; use matmul_std(a, b) for standard matmul."
                )
            continue
        if op == "matmul_std":
            a = get(args[0])
            b = get(args[1])
            vals[i] = np.matmul(a, b)
            continue
        if op == "neg":
            vals[i] = -get(args[0])
            continue
        if op in ("dot", "sdot"):
            # dot and sdot share an implementation in this evaluator.
            # NOTE(review): for Expr operands this keeps trailing axes
            # ("qi...,qj...->qij...") and requires a.shape[1] == b.shape[1];
            # the JAX evaluator instead contracts the last axis ("qia,qja->qij").
            # Confirm the divergence is intentional.
            ref = args[0]
            if isinstance(ref, FieldRef):
                vals[i] = _dot_np(_eval_field_np(ref, ctx, params), get(args[1]))
            else:
                a = get(args[0])
                b = get(args[1])
                if (
                    hasattr(a, "ndim")
                    and hasattr(b, "ndim")
                    and a.ndim >= 3
                    and b.ndim >= 3
                    and a.shape[0] == b.shape[0]
                    and a.shape[1] == b.shape[1]
                ):
                    vals[i] = np.einsum("qi...,qj...->qij...", a, b)
                else:
                    vals[i] = np.matmul(a, b)
            continue
        if op == "ddot":
            if len(args) == 2:
                a = get(args[0])
                b = get(args[1])
                if (
                    hasattr(a, "ndim")
                    and hasattr(b, "ndim")
                    and a.ndim == 3
                    and b.ndim == 3
                    and a.shape[0] == b.shape[0]
                    and a.shape[1] == b.shape[1]
                ):
                    vals[i] = np.einsum("qik,qim->qkm", a, b)
                else:
                    vals[i] = _ddot_np(a, b)
            else:
                vals[i] = _ddot_np(get(args[0]), get(args[1]), get(args[2]))
            continue
        if op == "inner":
            a = get(args[0])
            b = get(args[1])
            vals[i] = np.einsum("...i,...i->...", a, b)
            continue
        if op == "action":
            ref = args[0]
            assert isinstance(ref, FieldRef)
            if isinstance(args[1], FieldRef):
                raise ValueError("action expects a scalar expression; use u.val for unknowns.")
            v_field = _eval_field_np(ref, ctx, params)
            s = get(args[1])
            value_dim = int(getattr(v_field, "value_dim", 1))
            if value_dim == 1:
                if v_field.N.ndim != 2:
                    raise ValueError("action expects scalar test field with N shape (q, ndofs).")
                if hasattr(s, "ndim") and s.ndim not in (0, 1):
                    raise ValueError("action expects scalar s with shape (q,) or scalar.")
                if hasattr(s, "ndim") and s.ndim == 1:
                    # A per-quadrature-point scalar must align with the q axis of
                    # N (q, ndofs). Plain broadcasting would pair it with the ndofs
                    # axis and silently scale the wrong entries when q == ndofs.
                    s = s[:, None]
                vals[i] = v_field.N * s
            else:
                if hasattr(s, "ndim") and s.ndim not in (1, 2):
                    raise ValueError("action expects vector s with shape (q, dim) or (dim,).")
                vals[i] = _dot_np(v_field, s)
            continue
        if op == "gaction":
            ref = args[0]
            assert isinstance(ref, FieldRef)
            v_field = _eval_field_np(ref, ctx, params)
            q = get(args[1])
            if v_field.gradN.ndim != 3:
                raise ValueError("gaction expects test gradient with shape (q, ndofs, dim).")
            if not hasattr(q, "ndim"):
                raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
            if q.ndim == 2:
                vals[i] = np.einsum("qaj,qj->qa", v_field.gradN, q)
            elif q.ndim == 3:
                if int(getattr(v_field, "value_dim", 1)) == 1:
                    raise ValueError("gaction tensor flux requires vector test field.")
                vals[i] = np.einsum("qij,qaj->qai", q, v_field.gradN).reshape(q.shape[0], -1)
            else:
                raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
            continue
        if op == "transpose_last2":
            vals[i] = _transpose_last2_np(get(args[0]))
            continue
        if op == "einsum":
            subscripts = args[0]
            operands = [
                (np.asarray(arg) if isinstance(arg, tuple) else arg)
                for arg in (get(arg) for arg in args[1:])
            ]
            if "..." not in subscripts:
                # Operands may carry more axes than their subscript term names
                # (presumably trailing batch axes). Append "..." to such terms and
                # to an explicit output term so np.einsum carries them through.
                parts = subscripts.split("->")
                in_terms = parts[0].split(",")
                out_term = parts[1] if len(parts) > 1 else None
                updated_terms = []
                has_extra = False
                for term, opnd in zip(in_terms, operands):
                    if hasattr(opnd, "ndim") and opnd.ndim > len(term):
                        has_extra = True
                        updated_terms.append(term + "...")
                    else:
                        updated_terms.append(term)
                if has_extra:
                    if out_term is not None:
                        out_term = out_term + "..."
                        subscripts = ",".join(updated_terms) + "->" + out_term
                    else:
                        # Implicit-output mode: numpy places broadcast axes itself.
                        subscripts = ",".join(updated_terms)
            vals[i] = np.einsum(subscripts, *operands)
            continue

        raise ValueError(f"Unknown Expr op: {op}")

    return vals[index[id(plan.expr)]]
1615
+
1616
+
437
1617
  def compile_surface_linear(fn):
438
1618
  """get_compiled a surface linear form into a kernel (ctx, params) -> ndarray."""
439
1619
  if isinstance(fn, Expr):
@@ -441,27 +1621,60 @@ def compile_surface_linear(fn):
441
1621
  else:
442
1622
  v = test_ref()
443
1623
  p = param_ref()
444
- expr = None
445
- try:
446
- expr = fn(v, p)
447
- except TypeError:
448
- try:
449
- expr = fn(v)
450
- except TypeError:
451
- expr = None
1624
+ expr = _call_user(fn, v, params=p)
452
1625
 
1626
+ expr = _as_expr(expr)
453
1627
  if not isinstance(expr, Expr):
454
1628
  raise ValueError("Surface linear form must return an Expr; use ds() in the expression.")
455
1629
 
456
- includes_measure = _expr_contains(expr, "surface_measure")
457
- if not includes_measure:
1630
+ surface_count = _count_op(expr, "surface_measure")
1631
+ volume_count = _count_op(expr, "volume_measure")
1632
+ if surface_count == 0:
458
1633
  raise ValueError("Surface linear form must include ds().")
1634
+ if surface_count > 1:
1635
+ raise ValueError("Surface linear form must include ds() exactly once.")
1636
+ if volume_count > 0:
1637
+ raise ValueError("Surface linear form must not include dOmega().")
1638
+
1639
+ plan = make_eval_plan(expr)
1640
+
1641
+ def _form(ctx, params):
1642
+ return eval_with_plan(plan, ctx, params)
1643
+
1644
+ _form._includes_measure = True # type: ignore[attr-defined]
1645
+ return _tag_form(_form, kind="linear", domain="surface")
1646
+
1647
+
1648
def compile_surface_bilinear(fn):
    """
    Compile a surface bilinear form into a kernel ``(ctx, params) -> ndarray``.

    ``fn`` is either an ``Expr`` or a callable invoked with (trial, test) refs
    and ``params``. The resulting expression must contain ``ds()`` exactly once
    and no ``dOmega()``. The kernel is tagged ``kind="bilinear"``,
    ``domain="surface"``.

    Raises
    ------
    ValueError
        If the form does not produce an ``Expr`` or has invalid measures.
    """
    if isinstance(fn, Expr):
        expr = fn
    else:
        test_sym = test_ref()
        trial_sym = trial_ref()
        param_sym = param_ref()
        expr = _call_user(fn, trial_sym, test_sym, params=param_sym)

    expr = _as_expr(expr)
    if not isinstance(expr, Expr):
        raise ValueError("Surface bilinear form must return an Expr; use ds() in the expression.")

    n_surface = _count_op(expr, "surface_measure")
    n_volume = _count_op(expr, "volume_measure")
    measure_checks = (
        (n_surface == 0, "Surface bilinear form must include ds()."),
        (n_surface > 1, "Surface bilinear form must include ds() exactly once."),
        (n_volume > 0, "Surface bilinear form must not include dOmega()."),
    )
    for failed, message in measure_checks:
        if failed:
            raise ValueError(message)

    compiled_plan = make_eval_plan(expr)

    def _form(ctx, params):
        return eval_with_plan(compiled_plan, ctx, params)

    _form._includes_measure = True  # type: ignore[attr-defined]
    return _tag_form(_form, kind="bilinear", domain="surface")
465
1678
 
466
1679
 
467
1680
  class LinearForm:
@@ -524,25 +1737,33 @@ def compile_residual(fn):
524
1737
  v = test_ref()
525
1738
  u = unknown_ref()
526
1739
  p = param_ref()
527
- try:
528
- expr = fn(v, u, p)
529
- except TypeError:
530
- expr = fn(v, u)
1740
+ expr = _call_user(fn, v, u, params=p)
1741
+ expr = _as_expr(expr)
1742
+ if not isinstance(expr, Expr):
1743
+ raise TypeError("Residual form must return an Expr.")
531
1744
 
532
- includes_measure = _expr_contains(expr, "volume_measure")
533
- if not includes_measure:
1745
+ volume_count = _count_op(expr, "volume_measure")
1746
+ surface_count = _count_op(expr, "surface_measure")
1747
+ if volume_count == 0:
534
1748
  raise ValueError("Volume residual form must include dOmega().")
1749
+ if volume_count > 1:
1750
+ raise ValueError("Volume residual form must include dOmega() exactly once.")
1751
+ if surface_count > 0:
1752
+ raise ValueError("Volume residual form must not include ds().")
1753
+
1754
+ plan = make_eval_plan(expr)
535
1755
 
536
1756
  def _form(ctx, u_elem, params):
537
- return _as_expr(expr).eval(ctx, params, u_elem=u_elem)
1757
+ return eval_with_plan(plan, ctx, params, u_elem=u_elem)
538
1758
 
539
- _form._includes_measure = includes_measure
540
- return _form
1759
+ _form._includes_measure = True
1760
+ return _tag_form(_form, kind="residual", domain="volume")
541
1761
 
542
1762
 
543
1763
  def compile_mixed_residual(residuals: dict[str, Callable]):
544
1764
  """get_compiled mixed residuals keyed by field name."""
545
1765
  compiled = {}
1766
+ plans = {}
546
1767
  includes_measure = {}
547
1768
  for name, fn in residuals.items():
548
1769
  if isinstance(fn, Expr):
@@ -551,20 +1772,174 @@ def compile_mixed_residual(residuals: dict[str, Callable]):
551
1772
  v = test_ref(name)
552
1773
  u = unknown_ref(name)
553
1774
  p = param_ref()
554
- try:
555
- expr = fn(v, u, p)
556
- except TypeError:
557
- expr = fn(v, u)
558
- compiled[name] = _as_expr(expr)
559
- includes_measure[name] = _expr_contains(compiled[name], "volume_measure")
560
- if not includes_measure[name]:
1775
+ expr = _call_user(fn, v, u, params=p)
1776
+ expr = _as_expr(expr)
1777
+ if not isinstance(expr, Expr):
1778
+ raise TypeError(f"Mixed residual '{name}' must return an Expr.")
1779
+ compiled[name] = expr
1780
+ plans[name] = make_eval_plan(expr)
1781
+ volume_count = _count_op(compiled[name], "volume_measure")
1782
+ surface_count = _count_op(compiled[name], "surface_measure")
1783
+ includes_measure[name] = volume_count == 1
1784
+ if volume_count == 0:
561
1785
  raise ValueError(f"Mixed residual '{name}' must include dOmega().")
1786
+ if volume_count > 1:
1787
+ raise ValueError(f"Mixed residual '{name}' must include dOmega() exactly once.")
1788
+ if surface_count > 0:
1789
+ raise ValueError(f"Mixed residual '{name}' must not include ds().")
1790
+
1791
+ class _MixedContextView:
1792
+ def __init__(self, ctx, field_name: str):
1793
+ self._ctx = ctx
1794
+ self.fields = ctx.fields
1795
+ self.x_q = ctx.x_q
1796
+ self.w = ctx.w
1797
+ self.elem_id = ctx.elem_id
1798
+ self.trial_fields = ctx.trial_fields
1799
+ self.test_fields = ctx.test_fields
1800
+ self.unknown_fields = ctx.unknown_fields
1801
+ self.unknown = ctx.unknown
1802
+
1803
+ pair = ctx.fields[field_name]
1804
+ self.test = pair.test
1805
+ self.trial = pair.trial
1806
+ self.v = pair.test
1807
+ self.u = pair.trial
1808
+
1809
+ if hasattr(ctx, "normal"):
1810
+ self.normal = ctx.normal
1811
+
1812
+ def __getattr__(self, name: str):
1813
+ return getattr(self._ctx, name)
562
1814
 
563
1815
  def _form(ctx, u_elem, params):
564
- return {name: expr.eval(ctx, params, u_elem=u_elem) for name, expr in compiled.items()}
1816
+ return {
1817
+ name: eval_with_plan(plan, _MixedContextView(ctx, name), params, u_elem=u_elem)
1818
+ for name, plan in plans.items()
1819
+ }
565
1820
 
566
1821
  _form._includes_measure = includes_measure
567
- return _form
1822
+ return _tag_form(_form, kind="residual", domain="volume")
1823
+
1824
+
1825
def compile_mixed_surface_residual(residuals: dict[str, Callable]):
    """
    Compile mixed surface residuals keyed by field name.

    Each entry is an ``Expr`` or a callable invoked with the field's test and
    unknown refs plus ``params``. Every residual must contain ``ds()`` exactly
    once and no ``dOmega()``. The returned kernel ``(ctx, u_elem, params)``
    yields a dict of per-field values. Each field is evaluated with JAX against
    a per-field view of ``ctx``.
    """

    class _MixedContextView:
        # Per-field view of a mixed surface context: ``test``/``trial`` (and the
        # aliases ``v``/``u``) point at ``ctx.fields[field_name]``. Any other
        # attribute falls through to the wrapped context.
        def __init__(self, ctx, field_name: str):
            self._ctx = ctx
            self.fields = ctx.fields
            self.x_q = ctx.x_q
            self.w = ctx.w
            self.detJ = ctx.detJ
            self.normal = getattr(ctx, "normal", None)
            self.trial_fields = ctx.trial_fields
            self.test_fields = ctx.test_fields
            self.unknown_fields = ctx.unknown_fields
            self.unknown = getattr(ctx, "unknown", None)

            field_pair = ctx.fields[field_name]
            self.test = self.v = field_pair.test
            self.trial = self.u = field_pair.trial

        def __getattr__(self, name: str):
            return getattr(self._ctx, name)

    plans = {}
    includes_measure = {}
    for name, fn in residuals.items():
        if isinstance(fn, Expr):
            expr = fn
        else:
            expr = _call_user(fn, test_ref(name), unknown_ref(name), params=param_ref())
        expr = _as_expr(expr)
        if not isinstance(expr, Expr):
            raise TypeError(f"Mixed surface residual '{name}' must return an Expr.")

        plans[name] = make_eval_plan(expr)
        n_volume = _count_op(expr, "volume_measure")
        n_surface = _count_op(expr, "surface_measure")
        includes_measure[name] = n_surface == 1
        measure_checks = (
            (n_surface == 0, f"Mixed surface residual '{name}' must include ds()."),
            (n_surface > 1, f"Mixed surface residual '{name}' must include ds() exactly once."),
            (n_volume > 0, f"Mixed surface residual '{name}' must not include dOmega()."),
        )
        for failed, message in measure_checks:
            if failed:
                raise ValueError(message)

    def _form(ctx, u_elem, params):
        results = {}
        for field_name, field_plan in plans.items():
            view = _MixedContextView(ctx, field_name)
            results[field_name] = eval_with_plan(field_plan, view, params, u_elem=u_elem)
        return results

    _form._includes_measure = includes_measure
    return _tag_form(_form, kind="residual", domain="surface")
1883
+
1884
+
1885
def compile_mixed_surface_residual_numpy(residuals: dict[str, Callable]):
    """
    Mixed surface residual compiled for NumPy evaluation.

    Same contract as ``compile_mixed_surface_residual``: each residual must
    contain ``ds()`` exactly once and no ``dOmega()``. The difference is that
    each field is evaluated with ``eval_with_plan_numpy``.
    """

    class _MixedContextView:
        # Per-field view of a mixed surface context: ``test``/``trial`` (and the
        # aliases ``v``/``u``) point at ``ctx.fields[field_name]``. Any other
        # attribute falls through to the wrapped context.
        def __init__(self, ctx, field_name: str):
            self._ctx = ctx
            self.fields = ctx.fields
            self.x_q = ctx.x_q
            self.w = ctx.w
            self.detJ = ctx.detJ
            self.normal = getattr(ctx, "normal", None)
            self.trial_fields = ctx.trial_fields
            self.test_fields = ctx.test_fields
            self.unknown_fields = ctx.unknown_fields
            self.unknown = getattr(ctx, "unknown", None)

            field_pair = ctx.fields[field_name]
            self.test = self.v = field_pair.test
            self.trial = self.u = field_pair.trial

        def __getattr__(self, name: str):
            return getattr(self._ctx, name)

    plans = {}
    includes_measure = {}
    for name, fn in residuals.items():
        if isinstance(fn, Expr):
            expr = fn
        else:
            expr = _call_user(fn, test_ref(name), unknown_ref(name), params=param_ref())
        expr = _as_expr(expr)
        if not isinstance(expr, Expr):
            raise TypeError(f"Mixed surface residual '{name}' must return an Expr.")

        plans[name] = make_eval_plan(expr)
        n_volume = _count_op(expr, "volume_measure")
        n_surface = _count_op(expr, "surface_measure")
        includes_measure[name] = n_surface == 1
        measure_checks = (
            (n_surface == 0, f"Mixed surface residual '{name}' must include ds()."),
            (n_surface > 1, f"Mixed surface residual '{name}' must include ds() exactly once."),
            (n_volume > 0, f"Mixed surface residual '{name}' must not include dOmega()."),
        )
        for failed, message in measure_checks:
            if failed:
                raise ValueError(message)

    def _form(ctx, u_elem, params):
        results = {}
        for field_name, field_plan in plans.items():
            view = _MixedContextView(ctx, field_name)
            results[field_name] = eval_with_plan_numpy(field_plan, view, params, u_elem=u_elem)
        return results

    _form._includes_measure = includes_measure
    return _tag_form(_form, kind="residual", domain="surface")
568
1943
 
569
1944
 
570
1945
  class MixedWeakForm:
@@ -579,220 +1954,28 @@ class MixedWeakForm:
579
1954
  return compile_mixed_residual(self.residuals)
580
1955
 
581
1956
 
582
- def _eval_expr(expr: Expr, ctx, params, u_elem=None):
583
- op = expr.op
584
- args = expr.args
1957
+ def make_mixed_residuals(residuals: dict[str, Callable] | None = None, **kwargs) -> dict[str, Callable]:
1958
+ """
1959
+ Helper to build mixed residual dictionaries.
585
1960
 
586
- if op == "lit":
587
- return args[0]
588
- if op == "param":
589
- return params
590
- if op == "getattr":
591
- base = _eval_value(args[0], ctx, params, u_elem=u_elem)
592
- name = args[1]
593
- if isinstance(base, dict):
594
- return base[name]
595
- return getattr(base, name)
596
- if op == "field":
597
- role, name = args
598
- if name is not None:
599
- if role == "trial" and getattr(ctx, "trial_fields", None) is not None:
600
- if name in ctx.trial_fields:
601
- return ctx.trial_fields[name]
602
- if role == "test" and getattr(ctx, "test_fields", None) is not None:
603
- if name in ctx.test_fields:
604
- return ctx.test_fields[name]
605
- if role == "unknown" and getattr(ctx, "unknown_fields", None) is not None:
606
- if name in ctx.unknown_fields:
607
- return ctx.unknown_fields[name]
608
- fields = getattr(ctx, "fields", None)
609
- if fields is not None and name in fields:
610
- group = fields[name]
611
- if isinstance(group, dict):
612
- if role in group:
613
- return group[role]
614
- if "field" in group:
615
- return group["field"]
616
- return group
617
- if role == "trial":
618
- return ctx.trial
619
- if role == "test":
620
- return ctx.test
621
- if role == "unknown":
622
- return getattr(ctx, "unknown", ctx.trial)
623
- raise ValueError(f"Unknown field role: {role}")
624
- if op == "value":
625
- field = _eval_field(args[0], ctx, params)
626
- if isinstance(args[0], FieldRef) and args[0].role == "unknown":
627
- return _eval_unknown_value(args[0], field, u_elem)
628
- return field.N
629
- if op == "grad":
630
- field = _eval_field(args[0], ctx, params)
631
- if isinstance(args[0], FieldRef) and args[0].role == "unknown":
632
- return _eval_unknown_grad(args[0], field, u_elem)
633
- return field.gradN
634
- if op == "pow":
635
- base = _eval_value(args[0], ctx, params, u_elem=u_elem)
636
- exp = _eval_value(args[1], ctx, params, u_elem=u_elem)
637
- return base**exp
638
- if op == "eye":
639
- return jnp.eye(int(args[0]))
640
- if op == "det":
641
- return jnp.linalg.det(_eval_value(args[0], ctx, params, u_elem=u_elem))
642
- if op == "inv":
643
- return jnp.linalg.inv(_eval_value(args[0], ctx, params, u_elem=u_elem))
644
- if op == "transpose":
645
- return jnp.swapaxes(_eval_value(args[0], ctx, params, u_elem=u_elem), -1, -2)
646
- if op == "log":
647
- return jnp.log(_eval_value(args[0], ctx, params, u_elem=u_elem))
648
- if op == "surface_normal":
649
- normal = getattr(ctx, "normal", None)
650
- if normal is None:
651
- raise ValueError("surface normal is not available in context")
652
- return normal
653
- if op == "surface_measure":
654
- if not hasattr(ctx, "w") or not hasattr(ctx, "detJ"):
655
- raise ValueError("surface measure requires surface context with w and detJ.")
656
- return ctx.w * ctx.detJ
657
- if op == "volume_measure":
658
- if not hasattr(ctx, "w") or not hasattr(ctx, "test"):
659
- raise ValueError("volume measure requires FormContext with w and test.detJ.")
660
- return ctx.w * ctx.test.detJ
661
- if op == "sym_grad":
662
- field = _eval_field(args[0], ctx, params)
663
- if isinstance(args[0], FieldRef) and args[0].role == "unknown":
664
- if u_elem is None:
665
- raise ValueError("u_elem is required to evaluate unknown sym_grad.")
666
- u_local = _extract_unknown_elem(args[0], u_elem)
667
- return _ops.sym_grad_u(field, u_local)
668
- return _ops.sym_grad(field)
669
- if op == "outer":
670
- a, b = args
671
- if not isinstance(a, FieldRef) or not isinstance(b, FieldRef):
672
- raise TypeError("outer expects FieldRef operands.")
673
- if a.role == b.role:
674
- raise ValueError("outer requires one trial and one test field.")
675
- test = a if a.role == "test" else b
676
- trial = b if a.role == "test" else a
677
- v_field = _eval_field(test, ctx, params)
678
- u_field = _eval_field(trial, ctx, params)
679
- if getattr(v_field, "value_dim", 1) != 1 or getattr(u_field, "value_dim", 1) != 1:
680
- raise ValueError("u*v is only defined for scalar fields; use dot/inner for vectors.")
681
- vN = v_field.N
682
- uN = u_field.N
683
- return jnp.einsum("qi,qj->qij", vN, uN)
684
- if op == "add":
685
- return _eval_value(args[0], ctx, params, u_elem=u_elem) + _eval_value(args[1], ctx, params, u_elem=u_elem)
686
- if op == "sub":
687
- return _eval_value(args[0], ctx, params, u_elem=u_elem) - _eval_value(args[1], ctx, params, u_elem=u_elem)
688
- if op == "mul":
689
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
690
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
691
- if hasattr(a, "ndim") and hasattr(b, "ndim"):
692
- if a.ndim == 1 and b.ndim == 2 and a.shape[0] == b.shape[0]:
693
- a = a[:, None]
694
- elif b.ndim == 1 and a.ndim == 2 and b.shape[0] == a.shape[0]:
695
- b = b[:, None]
696
- elif a.ndim >= 2 and b.ndim == 1 and a.shape[0] == b.shape[0]:
697
- b = b.reshape((b.shape[0],) + (1,) * (a.ndim - 1))
698
- elif b.ndim >= 2 and a.ndim == 1 and b.shape[0] == a.shape[0]:
699
- a = a.reshape((a.shape[0],) + (1,) * (b.ndim - 1))
700
- return a * b
701
- if op == "matmul":
702
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
703
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
704
- if (
705
- hasattr(a, "ndim")
706
- and hasattr(b, "ndim")
707
- and a.ndim == 3
708
- and b.ndim == 3
709
- and a.shape[0] == b.shape[0]
710
- and a.shape[-1] == b.shape[-1]
711
- ):
712
- return jnp.einsum("qia,qja->qij", a, b)
713
- return a @ b
714
- if op == "matmul_std":
715
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
716
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
717
- return jnp.matmul(a, b)
718
- if op == "neg":
719
- return -_eval_value(args[0], ctx, params, u_elem=u_elem)
720
- if op == "dot":
721
- if isinstance(args[0], FieldRef):
722
- return _ops.dot(_eval_field(args[0], ctx, params), _eval_value(args[1], ctx, params, u_elem=u_elem))
723
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
724
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
725
- if hasattr(a, "ndim") and hasattr(b, "ndim") and a.ndim == 3 and b.ndim == 3 and a.shape[-1] == b.shape[-1]:
726
- return jnp.einsum("qia,qja->qij", a, b)
727
- return jnp.matmul(a, b)
728
- if op == "sdot":
729
- if isinstance(args[0], FieldRef):
730
- return _ops.dot(_eval_field(args[0], ctx, params), _eval_value(args[1], ctx, params, u_elem=u_elem))
731
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
732
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
733
- if hasattr(a, "ndim") and hasattr(b, "ndim") and a.ndim == 3 and b.ndim == 3 and a.shape[-1] == b.shape[-1]:
734
- return jnp.einsum("qia,qja->qij", a, b)
735
- return jnp.matmul(a, b)
736
- if op == "ddot":
737
- if len(args) == 2:
738
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
739
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
740
- if (
741
- hasattr(a, "ndim")
742
- and hasattr(b, "ndim")
743
- and a.ndim == 3
744
- and b.ndim == 3
745
- and a.shape[0] == b.shape[0]
746
- and a.shape[1] == b.shape[1]
747
- ):
748
- return jnp.einsum("qik,qim->qkm", a, b)
749
- return _ops.ddot(a, b)
750
- return _ops.ddot(
751
- _eval_value(args[0], ctx, params, u_elem=u_elem),
752
- _eval_value(args[1], ctx, params, u_elem=u_elem),
753
- _eval_value(args[2], ctx, params, u_elem=u_elem),
754
- )
755
- if op == "inner":
756
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
757
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
758
- return jnp.einsum("...i,...i->...", a, b)
759
- if op == "action":
760
- if isinstance(args[1], FieldRef):
761
- raise ValueError("action expects a scalar expression; use u.val for unknowns.")
762
- v_field = _eval_field(args[0], ctx, params)
763
- s = _eval_value(args[1], ctx, params, u_elem=u_elem)
764
- value_dim = int(getattr(v_field, "value_dim", 1))
765
- if value_dim == 1:
766
- if v_field.N.ndim != 2:
767
- raise ValueError("action expects scalar test field with N shape (q, ndofs).")
768
- if hasattr(s, "ndim") and s.ndim not in (0, 1):
769
- raise ValueError("action expects scalar s with shape (q,) or scalar.")
770
- return v_field.N * s
771
- if hasattr(s, "ndim") and s.ndim not in (1, 2):
772
- raise ValueError("action expects vector s with shape (q, dim) or (dim,).")
773
- return _ops.dot(v_field, s)
774
- if op == "gaction":
775
- v_field = _eval_field(args[0], ctx, params)
776
- q = _eval_value(args[1], ctx, params, u_elem=u_elem)
777
- if v_field.gradN.ndim != 3:
778
- raise ValueError("gaction expects test gradient with shape (q, ndofs, dim).")
779
- if not hasattr(q, "ndim"):
780
- raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
781
- if q.ndim == 2:
782
- return jnp.einsum("qaj,qj->qa", v_field.gradN, q)
783
- if q.ndim == 3:
784
- if int(getattr(v_field, "value_dim", 1)) == 1:
785
- raise ValueError("gaction tensor flux requires vector test field.")
786
- return jnp.einsum("qij,qaj->qai", q, v_field.gradN).reshape(q.shape[0], -1)
787
- raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
788
- if op == "transpose_last2":
789
- return _ops.transpose_last2(_eval_value(args[0], ctx, params, u_elem=u_elem))
790
- if op == "einsum":
791
- subscripts = args[0]
792
- operands = [_eval_value(arg, ctx, params, u_elem=u_elem) for arg in args[1:]]
793
- return jnp.einsum(subscripts, *operands)
794
-
795
- raise ValueError(f"Unknown Expr op: {op}")
1961
+ Example:
1962
+ res = make_mixed_residuals(u=res_u, p=res_p)
1963
+ """
1964
+ if residuals is not None and kwargs:
1965
+ raise ValueError("Pass either residuals dict or keyword residuals, not both.")
1966
+ if residuals is None:
1967
+ return dict(kwargs)
1968
+ return dict(residuals)
1969
+
1970
+
1971
+ def _eval_expr(
1972
+ expr: Expr,
1973
+ ctx: VolumeContext | SurfaceContext,
1974
+ params: ParamsLike,
1975
+ u_elem: UElement | None = None,
1976
+ ):
1977
+ plan = make_eval_plan(expr)
1978
+ return eval_with_plan(plan, ctx, params, u_elem=u_elem)
796
1979
 
797
1980
 
798
1981
  __all__ = [
@@ -802,16 +1985,23 @@ __all__ = [
802
1985
  "trial_ref",
803
1986
  "test_ref",
804
1987
  "unknown_ref",
1988
+ "zero_ref",
805
1989
  "param_ref",
806
1990
  "Params",
807
1991
  "MixedWeakForm",
1992
+ "make_mixed_residuals",
1993
+ "kernel",
808
1994
  "ResidualForm",
809
1995
  "compile_bilinear",
810
1996
  "compile_linear",
811
1997
  "compile_residual",
1998
+ "compile_surface_bilinear",
1999
+ "compile_mixed_surface_residual",
2000
+ "compile_mixed_surface_residual_numpy",
812
2001
  "compile_mixed_residual",
813
2002
  "grad",
814
2003
  "sym_grad",
2004
+ "outer",
815
2005
  "dot",
816
2006
  "ddot",
817
2007
  "inner",
@@ -824,5 +2014,6 @@ __all__ = [
824
2014
  "log",
825
2015
  "transpose_last2",
826
2016
  "matmul",
2017
+ "matmul_std",
827
2018
  "einsum",
828
2019
  ]