fluxfem 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fluxfem/core/weakform.py CHANGED
@@ -1,82 +1,312 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Any, Callable
4
+ from typing import Any, Callable, Iterator, Literal, get_args
5
+ import inspect
6
+
7
+ import numpy as np
5
8
 
6
9
  import jax.numpy as jnp
7
10
  import jax
8
11
 
9
12
  from ..physics import operators as _ops
13
+ from .context_types import FormFieldLike, ParamsLike, SurfaceContext, UElement, VolumeContext
10
14
 
11
15
 
16
+ OpName = Literal[
17
+ "lit",
18
+ "getattr",
19
+ "value",
20
+ "grad",
21
+ "pow",
22
+ "eye",
23
+ "det",
24
+ "inv",
25
+ "transpose",
26
+ "log",
27
+ "surface_normal",
28
+ "surface_measure",
29
+ "volume_measure",
30
+ "sym_grad",
31
+ "outer",
32
+ "add",
33
+ "sub",
34
+ "mul",
35
+ "matmul",
36
+ "matmul_std",
37
+ "neg",
38
+ "dot",
39
+ "sdot",
40
+ "ddot",
41
+ "inner",
42
+ "action",
43
+ "gaction",
44
+ "transpose_last2",
45
+ "einsum",
46
+ ]
47
+
48
+ # Use OpName as the single source of truth for valid ops.
49
+ _OP_NAMES: frozenset[str] = frozenset(get_args(OpName))
50
+
51
+
52
+ _PRECEDENCE: dict[str, int] = {
53
+ "add": 10,
54
+ "sub": 10,
55
+ "mul": 20,
56
+ "matmul": 20,
57
+ "matmul_std": 20,
58
+ "inner": 20,
59
+ "dot": 20,
60
+ "sdot": 20,
61
+ "ddot": 20,
62
+ "pow": 30,
63
+ "neg": 40,
64
+ "transpose": 50,
65
+ }
66
+
67
+
68
+ def _pretty_render_arg(arg, prec: int | None = None) -> str:
69
+ if isinstance(arg, Expr):
70
+ return _pretty_expr(arg, prec or 0)
71
+ if isinstance(arg, FieldRef):
72
+ if arg.name is None:
73
+ return f"{arg.role}"
74
+ return f"{arg.role}:{arg.name}"
75
+ if isinstance(arg, ParamRef):
76
+ return "param"
77
+ return repr(arg)
78
+
79
+
80
+ def _pretty_wrap(text: str, prec: int, parent_prec: int) -> str:
81
+ if prec < parent_prec:
82
+ return f"({text})"
83
+ return text
84
+
85
+
86
+ def _pretty_expr(expr: Expr, parent_prec: int = 0) -> str:
87
+ op = expr.op
88
+ args = expr.args
89
+
90
+ if op == "lit":
91
+ return repr(args[0])
92
+ if op == "getattr":
93
+ base = _pretty_render_arg(args[0], _PRECEDENCE.get("transpose", 50))
94
+ return f"{base}.{args[1]}"
95
+ if op == "value":
96
+ return f"val({_pretty_render_arg(args[0])})"
97
+ if op == "grad":
98
+ return f"grad({_pretty_render_arg(args[0])})"
99
+ if op == "sym_grad":
100
+ return f"sym_grad({_pretty_render_arg(args[0])})"
101
+ if op == "neg":
102
+ inner = _pretty_render_arg(args[0], _PRECEDENCE["neg"])
103
+ return _pretty_wrap(f"-{inner}", _PRECEDENCE["neg"], parent_prec)
104
+ if op == "transpose":
105
+ inner = _pretty_render_arg(args[0], _PRECEDENCE["transpose"])
106
+ return _pretty_wrap(f"{inner}.T", _PRECEDENCE["transpose"], parent_prec)
107
+ if op == "pow":
108
+ base = _pretty_render_arg(args[0], _PRECEDENCE["pow"])
109
+ exp = _pretty_render_arg(args[1], _PRECEDENCE["pow"] + 1)
110
+ return _pretty_wrap(f"{base}**{exp}", _PRECEDENCE["pow"], parent_prec)
111
+ if op in {"add", "sub", "mul", "matmul", "dot", "sdot", "ddot"}:
112
+ left = _pretty_render_arg(args[0], _PRECEDENCE[op])
113
+ right = _pretty_render_arg(args[1], _PRECEDENCE[op] + 1)
114
+ symbol = {
115
+ "add": "+",
116
+ "sub": "-",
117
+ "mul": "*",
118
+ "matmul": "@",
119
+ "inner": "|",
120
+ "dot": "dot",
121
+ "sdot": "sdot",
122
+ "ddot": "ddot",
123
+ }[op]
124
+ if symbol in {"dot", "sdot", "ddot"}:
125
+ text = f"{symbol}({left}, {right})"
126
+ else:
127
+ text = f"{left} {symbol} {right}"
128
+ return _pretty_wrap(text, _PRECEDENCE[op], parent_prec)
129
+ if op == "inner":
130
+ return f"inner({_pretty_render_arg(args[0])}, {_pretty_render_arg(args[1])})"
131
+ if op in {"action", "gaction"}:
132
+ return f"{op}({_pretty_render_arg(args[0])}, {_pretty_render_arg(args[1])})"
133
+ if op == "matmul_std":
134
+ return f"matmul_std({_pretty_render_arg(args[0])}, {_pretty_render_arg(args[1])})"
135
+ if op == "outer":
136
+ return f"outer({_pretty_render_arg(args[0])}, {_pretty_render_arg(args[1])})"
137
+ if op in {
138
+ "eye",
139
+ "det",
140
+ "inv",
141
+ "log",
142
+ "surface_normal",
143
+ "surface_measure",
144
+ "volume_measure",
145
+ "transpose_last2",
146
+ "einsum",
147
+ }:
148
+ rendered = ", ".join(_pretty_render_arg(arg) for arg in args)
149
+ return f"{op}({rendered})"
150
+ rendered = ", ".join(_pretty_render_arg(arg) for arg in args)
151
+ return f"{op}({rendered})"
152
+
153
+
154
+ def _as_expr(obj) -> Expr | FieldRef | ParamRef:
155
+ """Normalize inputs into Expr/FieldRef/ParamRef nodes."""
156
+ if isinstance(obj, Expr):
157
+ return obj
158
+ if isinstance(obj, FieldRef):
159
+ return obj
160
+ if isinstance(obj, ParamRef):
161
+ return obj
162
+ if isinstance(obj, (int, float, bool, str)):
163
+ return Expr("lit", obj)
164
+ if isinstance(obj, np.generic):
165
+ return Expr("lit", obj.item())
166
+ if isinstance(obj, tuple):
167
+ try:
168
+ hash(obj)
169
+ except TypeError as exc:
170
+ raise TypeError(
171
+ "Expr tuple literal must be hashable; use only immutable items."
172
+ ) from exc
173
+ return Expr("lit", obj)
174
+ raise TypeError(
175
+ "Expr literal must be a scalar or hashable tuple. "
176
+ "Arrays are not allowed; pass them via params (ParamRef/params.xxx)."
177
+ )
178
+
179
+
180
+ @dataclass(frozen=True, slots=True, init=False)
12
181
  class Expr:
13
- """Expression tree node evaluated against a FormContext."""
182
+ """Expression tree node evaluated against a FormContext.
183
+
184
+ Compile flow (recommended):
185
+ - build an Expr via operators/refs
186
+ - compile_* builds an EvalPlan (postorder nodes + index)
187
+ - eval_with_plan(plan, ctx, params, u_elem) evaluates per element
14
188
 
15
- def __init__(self, op: str, *args):
189
+ Expr.eval is a debug/single-shot path that creates a plan on demand.
190
+ """
191
+
192
+ op: OpName
193
+ args: tuple[Any, ...]
194
+
195
+ def __init__(self, op: OpName, *args):
196
+ if op not in _OP_NAMES:
197
+ raise ValueError(f"Unknown Expr op: {op!r}")
16
198
  object.__setattr__(self, "op", op)
17
199
  object.__setattr__(self, "args", args)
18
200
 
19
201
  def eval(self, ctx, params=None, u_elem=None):
202
+ """Evaluate the expression against a context (debug/single-shot path)."""
20
203
  return _eval_expr(self, ctx, params, u_elem=u_elem)
21
204
 
205
+ def children(self) -> tuple[Any, ...]:
206
+ """Return direct child nodes (Expr/FieldRef/ParamRef) for traversal."""
207
+ return tuple(arg for arg in self.args if isinstance(arg, (Expr, FieldRef, ParamRef)))
208
+
209
+ def walk(self) -> Iterator[Any]:
210
+ """Depth-first walk over nodes, including leaf FieldRef/ParamRef."""
211
+ yield self
212
+ for child in self.children():
213
+ if isinstance(child, Expr):
214
+ yield from child.walk()
215
+ else:
216
+ yield child
217
+
218
+ def postorder(self) -> Iterator[Any]:
219
+ """Postorder walk over nodes, including leaf FieldRef/ParamRef."""
220
+ for child in self.children():
221
+ if isinstance(child, Expr):
222
+ yield from child.postorder()
223
+ else:
224
+ yield child
225
+ yield self
226
+
227
+ def postorder_expr(self) -> Iterator["Expr"]:
228
+ """Postorder walk over Expr nodes only (for eval planning)."""
229
+ for arg in self.args:
230
+ if isinstance(arg, Expr):
231
+ yield from arg.postorder_expr()
232
+ yield self
233
+
22
234
  def _binop(self, other, op):
23
235
  return Expr(op, self, _as_expr(other))
24
236
 
25
237
  def __add__(self, other):
238
+ """Add expressions: `a + b`."""
26
239
  return self._binop(other, "add")
27
240
 
28
241
  def __radd__(self, other):
29
- return _as_expr(other)._binop(self, "add")
242
+ """Right-add expressions: `1 + expr`."""
243
+ return Expr("add", _as_expr(other), self)
30
244
 
31
245
  def __sub__(self, other):
246
+ """Subtract expressions: `a - b`."""
32
247
  return self._binop(other, "sub")
33
248
 
34
249
  def __rsub__(self, other):
35
- return _as_expr(other)._binop(self, "sub")
250
+ """Right-subtract expressions: `1 - expr`."""
251
+ return Expr("sub", _as_expr(other), self)
36
252
 
37
253
  def __mul__(self, other):
254
+ """Multiply expressions: `a * b`."""
38
255
  return self._binop(other, "mul")
39
256
 
40
257
  def __rmul__(self, other):
41
- return _as_expr(other)._binop(self, "mul")
258
+ """Right-multiply expressions: `2 * expr`."""
259
+ return Expr("mul", _as_expr(other), self)
42
260
 
43
261
  def __matmul__(self, other):
262
+ """Matrix product: `a @ b` (FEM-specific contraction semantics)."""
44
263
  return self._binop(other, "matmul")
45
264
 
46
265
  def __rmatmul__(self, other):
47
- return _as_expr(other)._binop(self, "matmul")
266
+ """Right-matmul: `A @ expr`."""
267
+ return Expr("matmul", _as_expr(other), self)
48
268
 
49
269
  def __or__(self, other):
50
- return self._binop(other, "inner")
270
+ """Tensor inner product: `a | b` (use .val/.grad for FieldRef)."""
271
+ if isinstance(other, FieldRef):
272
+ raise TypeError("FieldRef | FieldRef is not supported; use outer(test, trial).")
273
+ return Expr("inner", self, _as_expr(other))
51
274
 
52
275
  def __ror__(self, other):
53
- return _as_expr(other)._binop(self, "inner")
276
+ """Tensor inner product: `a | b` (use .val/.grad for FieldRef)."""
277
+ if isinstance(other, FieldRef):
278
+ raise TypeError("FieldRef | FieldRef is not supported; use outer(test, trial).")
279
+ return Expr("inner", _as_expr(other), self)
54
280
 
55
281
  def __pow__(self, power, modulo=None):
282
+ """Power: `a ** p` (no modulo support)."""
56
283
  if modulo is not None:
57
284
  raise ValueError("modulo is not supported for Expr exponentiation.")
58
285
  return Expr("pow", self, _as_expr(power))
59
286
 
60
287
  def __neg__(self):
288
+ """Unary negation: `-expr`."""
61
289
  return Expr("neg", self)
62
290
 
63
291
  @property
64
292
  def T(self):
293
+ """Transpose view: `expr.T`."""
65
294
  return Expr("transpose", self)
66
295
 
296
+ def __repr__(self) -> str:
297
+ return self.pretty()
298
+
299
+ def pretty(self) -> str:
300
+ return _pretty_expr(self)
301
+
67
302
 
68
- @dataclass(frozen=True)
69
- class FieldRef(Expr):
303
+ @dataclass(frozen=True, slots=True)
304
+ class FieldRef:
70
305
  """Symbolic reference to trial/test/unknown field, optionally by name."""
71
306
 
72
307
  role: str
73
308
  name: str | None = None
74
309
 
75
- def __init__(self, role: str, name: str | None = None):
76
- object.__setattr__(self, "role", role)
77
- object.__setattr__(self, "name", name)
78
- super().__init__("field", role, name)
79
-
80
310
  @property
81
311
  def val(self):
82
312
  return Expr("value", self)
@@ -91,12 +321,18 @@ class FieldRef(Expr):
91
321
 
92
322
  def __mul__(self, other):
93
323
  if isinstance(other, FieldRef):
94
- return Expr("outer", self, other)
324
+ raise TypeError(
325
+ "FieldRef * FieldRef is ambiguous; use outer(v, u) (test, trial), "
326
+ "action(v, s), or dot(v, q)."
327
+ )
95
328
  return Expr("mul", Expr("value", self), _as_expr(other))
96
329
 
97
330
  def __rmul__(self, other):
98
331
  if isinstance(other, FieldRef):
99
- return Expr("outer", other, self)
332
+ raise TypeError(
333
+ "FieldRef * FieldRef is ambiguous; use outer(v, u) (test, trial), "
334
+ "action(v, s), or dot(v, q)."
335
+ )
100
336
  return Expr("mul", _as_expr(other), Expr("value", self))
101
337
 
102
338
  def __add__(self, other):
@@ -113,22 +349,23 @@ class FieldRef(Expr):
113
349
 
114
350
  def __or__(self, other):
115
351
  if isinstance(other, FieldRef):
116
- return Expr("inner", self, other)
117
- return Expr("sdot", self, _as_expr(other))
352
+ raise TypeError(
353
+ "FieldRef | FieldRef is not supported; use outer(test, trial) for basis kernels."
354
+ )
355
+ return Expr("dot", self, _as_expr(other))
118
356
 
119
357
  def __ror__(self, other):
120
358
  if isinstance(other, FieldRef):
121
- return Expr("inner", other, self)
122
- return Expr("sdot", _as_expr(other), self)
359
+ raise TypeError(
360
+ "FieldRef | FieldRef is not supported; use outer(test, trial) for basis kernels."
361
+ )
362
+ return Expr("dot", _as_expr(other), self)
123
363
 
124
364
 
125
- @dataclass(frozen=True)
126
- class ParamRef(Expr):
365
+ @dataclass(frozen=True, slots=True)
366
+ class ParamRef:
127
367
  """Symbolic reference to params passed into the kernel."""
128
368
 
129
- def __init__(self):
130
- super().__init__("param")
131
-
132
369
  def __getattr__(self, name: str):
133
370
  return Expr("getattr", self, name)
134
371
 
@@ -179,13 +416,11 @@ def param_ref() -> ParamRef:
179
416
  return ParamRef()
180
417
 
181
418
 
182
- def _as_expr(obj) -> Expr:
183
- if isinstance(obj, Expr):
184
- return obj
185
- return Expr("lit", obj)
186
-
187
-
188
- def _eval_field(obj: Any, ctx, params):
419
+ def _eval_field(
420
+ obj: Any,
421
+ ctx: VolumeContext | SurfaceContext,
422
+ params: ParamsLike,
423
+ ) -> FormFieldLike:
189
424
  if isinstance(obj, FieldRef):
190
425
  if obj.name is not None:
191
426
  mixed_fields = getattr(ctx, "fields", None)
@@ -226,25 +461,23 @@ def _eval_field(obj: Any, ctx, params):
226
461
  if obj.role == "unknown":
227
462
  return getattr(ctx, "unknown", ctx.trial)
228
463
  raise ValueError(f"Unknown field role: {obj.role}")
229
- if isinstance(obj, Expr):
230
- val = obj.eval(ctx, params)
231
- if hasattr(val, "N"):
232
- return val
233
464
  raise TypeError("Expected a field reference for this operator.")
234
465
 
235
466
 
236
- def _eval_value(obj: Any, ctx, params, u_elem=None):
237
- if isinstance(obj, FieldRef):
238
- field = _eval_field(obj, ctx, params)
239
- if obj.role == "unknown":
240
- return _eval_unknown_value(obj, field, u_elem)
241
- return field.N
242
- if isinstance(obj, Expr):
243
- return obj.eval(ctx, params, u_elem=u_elem)
244
- return obj
467
+ # def _eval_value(obj: Any, ctx, params, u_elem=None):
468
+ # if isinstance(obj, FieldRef):
469
+ # field = _eval_field(obj, ctx, params)
470
+ # if obj.role == "unknown":
471
+ # return _eval_unknown_value(obj, field, u_elem)
472
+ # return field.N
473
+ # if isinstance(obj, ParamRef):
474
+ # return params
475
+ # if isinstance(obj, Expr):
476
+ # return obj.eval(ctx, params, u_elem=u_elem)
477
+ # return obj
245
478
 
246
479
 
247
- def _extract_unknown_elem(field_ref: FieldRef, u_elem):
480
+ def _extract_unknown_elem(field_ref: FieldRef, u_elem: UElement):
248
481
  if u_elem is None:
249
482
  raise ValueError("u_elem is required to evaluate unknown field value.")
250
483
  if isinstance(u_elem, dict):
@@ -255,7 +488,17 @@ def _extract_unknown_elem(field_ref: FieldRef, u_elem):
255
488
  return u_elem
256
489
 
257
490
 
258
- def _eval_unknown_value(field_ref: FieldRef, field, u_elem):
491
+ def _basis_outer(test: FieldRef, trial: FieldRef, ctx, params):
492
+ v_field = _eval_field(test, ctx, params)
493
+ u_field = _eval_field(trial, ctx, params)
494
+ if getattr(v_field, "value_dim", 1) != 1 or getattr(u_field, "value_dim", 1) != 1:
495
+ raise ValueError(
496
+ "inner/outer is only defined for scalar fields; use dot/action/einsum for vector/tensor cases."
497
+ )
498
+ return jnp.einsum("qi,qj->qij", v_field.N, u_field.N)
499
+
500
+
501
+ def _eval_unknown_value(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
259
502
  u_local = _extract_unknown_elem(field_ref, u_elem)
260
503
  value_dim = int(getattr(field, "value_dim", 1))
261
504
  if value_dim == 1:
@@ -264,7 +507,7 @@ def _eval_unknown_value(field_ref: FieldRef, field, u_elem):
264
507
  return jnp.einsum("qa,ai->qi", field.N, u_nodes)
265
508
 
266
509
 
267
- def _eval_unknown_grad(field_ref: FieldRef, field, u_elem):
510
+ def _eval_unknown_grad(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
268
511
  u_local = _extract_unknown_elem(field_ref, u_elem)
269
512
  if u_local is None:
270
513
  raise ValueError("u_elem is required to evaluate unknown field gradient.")
@@ -285,6 +528,15 @@ def sym_grad(field) -> Expr:
285
528
  return Expr("sym_grad", _as_expr(field))
286
529
 
287
530
 
531
+ def outer(a, b) -> Expr:
532
+ """Outer product of scalar fields: `outer(v, u)` (test, trial)."""
533
+ if not isinstance(a, FieldRef) or not isinstance(b, FieldRef):
534
+ raise TypeError("outer expects FieldRef operands.")
535
+ if a.role != "test" or b.role != "trial":
536
+ raise TypeError("outer expects outer(test, trial).")
537
+ return Expr("outer", a, b)
538
+
539
+
288
540
  def dot(a, b) -> Expr:
289
541
  """Dot product or vector load helper."""
290
542
  return Expr("dot", _as_expr(a), _as_expr(b))
@@ -303,7 +555,7 @@ def ddot(a, b, c=None) -> Expr:
303
555
 
304
556
 
305
557
  def inner(a, b) -> Expr:
306
- """Inner product over the last axis."""
558
+ """Inner product over the last axis (tensor-level)."""
307
559
  return Expr("inner", _as_expr(a), _as_expr(b))
308
560
 
309
561
 
@@ -363,7 +615,12 @@ def transpose_last2(a) -> Expr:
363
615
 
364
616
 
365
617
  def matmul(a, b) -> Expr:
366
- """Matrix product with standard semantics (no special 3D contraction)."""
618
+ """FEM-specific batched contraction (same semantics as `@`)."""
619
+ return Expr("matmul", _as_expr(a), _as_expr(b))
620
+
621
+
622
+ def matmul_std(a, b) -> Expr:
623
+ """Standard matrix product (`jnp.matmul` semantics)."""
367
624
  return Expr("matmul_std", _as_expr(a), _as_expr(b))
368
625
 
369
626
 
@@ -374,9 +631,22 @@ def einsum(subscripts: str, *args) -> Expr:
374
631
 
375
632
  def _call_user(fn, *args, params):
376
633
  try:
634
+ sig = inspect.signature(fn)
635
+ except (TypeError, ValueError):
636
+ return fn(*args, params)
637
+
638
+ params_list = list(sig.parameters.values())
639
+ if any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in params_list):
377
640
  return fn(*args, params)
378
- except TypeError:
379
- return fn(*args)
641
+ positional = [
642
+ p
643
+ for p in params_list
644
+ if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
645
+ ]
646
+ max_positional = len(positional)
647
+ if len(args) + 1 <= max_positional:
648
+ return fn(*args, params)
649
+ return fn(*args)
380
650
 
381
651
 
382
652
  def compile_bilinear(fn):
@@ -387,19 +657,26 @@ def compile_bilinear(fn):
387
657
  u = trial_ref()
388
658
  v = test_ref()
389
659
  p = param_ref()
390
- try:
391
- expr = fn(u, v, p)
392
- except TypeError:
393
- expr = fn(u, v)
660
+ expr = _call_user(fn, u, v, params=p)
661
+ expr = _as_expr(expr)
662
+ if not isinstance(expr, Expr):
663
+ raise TypeError("Bilinear form must return an Expr.")
394
664
 
395
- includes_measure = _expr_contains(expr, "volume_measure")
396
- if not includes_measure:
665
+ volume_count = _count_op(expr, "volume_measure")
666
+ surface_count = _count_op(expr, "surface_measure")
667
+ if volume_count == 0:
397
668
  raise ValueError("Volume bilinear form must include dOmega().")
669
+ if volume_count > 1:
670
+ raise ValueError("Volume bilinear form must include dOmega() exactly once.")
671
+ if surface_count > 0:
672
+ raise ValueError("Volume bilinear form must not include ds().")
673
+
674
+ plan = make_eval_plan(expr)
398
675
 
399
676
  def _form(ctx, params):
400
- return _as_expr(expr).eval(ctx, params)
677
+ return eval_with_plan(plan, ctx, params)
401
678
 
402
- _form._includes_measure = includes_measure
679
+ _form._includes_measure = True
403
680
  return _form
404
681
 
405
682
 
@@ -410,19 +687,26 @@ def compile_linear(fn):
410
687
  else:
411
688
  v = test_ref()
412
689
  p = param_ref()
413
- try:
414
- expr = fn(v, p)
415
- except TypeError:
416
- expr = fn(v)
690
+ expr = _call_user(fn, v, params=p)
691
+ expr = _as_expr(expr)
692
+ if not isinstance(expr, Expr):
693
+ raise TypeError("Linear form must return an Expr.")
417
694
 
418
- includes_measure = _expr_contains(expr, "volume_measure")
419
- if not includes_measure:
695
+ volume_count = _count_op(expr, "volume_measure")
696
+ surface_count = _count_op(expr, "surface_measure")
697
+ if volume_count == 0:
420
698
  raise ValueError("Volume linear form must include dOmega().")
699
+ if volume_count > 1:
700
+ raise ValueError("Volume linear form must include dOmega() exactly once.")
701
+ if surface_count > 0:
702
+ raise ValueError("Volume linear form must not include ds().")
703
+
704
+ plan = make_eval_plan(expr)
421
705
 
422
706
  def _form(ctx, params):
423
- return _as_expr(expr).eval(ctx, params)
707
+ return eval_with_plan(plan, ctx, params)
424
708
 
425
- _form._includes_measure = includes_measure
709
+ _form._includes_measure = True
426
710
  return _form
427
711
 
428
712
 
@@ -434,6 +718,340 @@ def _expr_contains(expr: Expr, op: str) -> bool:
434
718
  return any(_expr_contains(arg, op) for arg in expr.args if isinstance(arg, Expr))
435
719
 
436
720
 
721
+ def _count_op(expr: Expr, op: str) -> int:
722
+ if not isinstance(expr, Expr):
723
+ return 0
724
+ count = 1 if expr.op == op else 0
725
+ for arg in expr.args:
726
+ if isinstance(arg, Expr):
727
+ count += _count_op(arg, op)
728
+ return count
729
+
730
+
731
+ @dataclass(frozen=True, slots=True)
732
+ class EvalPlan:
733
+ expr: Expr
734
+ nodes: tuple[Expr, ...]
735
+ index: dict[int, int]
736
+
737
+
738
+ def _validate_eval_plan(nodes: tuple[Expr, ...]) -> None:
739
+ fieldref_arg_ops = {
740
+ "value",
741
+ "grad",
742
+ "sym_grad",
743
+ "dot",
744
+ "sdot",
745
+ "action",
746
+ "gaction",
747
+ "outer",
748
+ }
749
+ for node in nodes:
750
+ op = node.op
751
+ args = node.args
752
+ if op not in fieldref_arg_ops:
753
+ if any(isinstance(arg, FieldRef) for arg in args):
754
+ raise TypeError(f"{op} cannot take FieldRef directly; wrap with .val/.grad/.sym_grad.")
755
+ if op in {"value", "grad", "sym_grad"}:
756
+ if len(args) != 1 or not isinstance(args[0], FieldRef):
757
+ raise TypeError(f"{op} expects FieldRef.")
758
+ elif op in {"dot", "sdot"}:
759
+ if len(args) != 2:
760
+ raise TypeError(f"{op} expects two arguments.")
761
+ if any(isinstance(arg, FieldRef) for arg in args):
762
+ if not isinstance(args[0], FieldRef):
763
+ raise TypeError(f"{op} expects FieldRef as the first argument.")
764
+ if isinstance(args[1], FieldRef):
765
+ raise TypeError(f"{op} expects an expression for the second argument; use .val/.grad.")
766
+ elif op in {"action", "gaction"}:
767
+ if len(args) != 2 or not isinstance(args[0], FieldRef):
768
+ raise TypeError(f"{op} expects FieldRef as the first argument.")
769
+ if op == "action" and isinstance(args[1], FieldRef):
770
+ raise ValueError("action expects a scalar expression; use u.val for unknowns.")
771
+ if op == "gaction" and isinstance(args[1], FieldRef):
772
+ raise TypeError("gaction expects an expression for the second argument; use .grad.")
773
+ elif op == "outer":
774
+ if len(args) != 2 or not all(isinstance(arg, FieldRef) for arg in args):
775
+ raise TypeError("outer expects two FieldRef operands.")
776
+ if args[0].role != "test" or args[1].role != "trial":
777
+ raise TypeError("outer expects outer(test, trial).")
778
+
779
+
780
+ def make_eval_plan(expr: Expr) -> EvalPlan:
781
+ nodes = tuple(expr.postorder_expr())
782
+ _validate_eval_plan(nodes)
783
+ index: dict[int, int] = {}
784
+ for i, node in enumerate(nodes):
785
+ index.setdefault(id(node), i)
786
+ return EvalPlan(expr=expr, nodes=nodes, index=index)
787
+
788
+
789
+ def eval_with_plan(
790
+ plan: EvalPlan,
791
+ ctx: VolumeContext | SurfaceContext,
792
+ params: ParamsLike,
793
+ u_elem: UElement | None = None,
794
+ ):
795
+ nodes = plan.nodes
796
+ index = plan.index
797
+ vals: list[Any] = [None] * len(nodes)
798
+
799
+ def get(obj):
800
+ if isinstance(obj, Expr):
801
+ return vals[index[id(obj)]]
802
+ if isinstance(obj, FieldRef):
803
+ raise TypeError(
804
+ "FieldRef must be wrapped with .val/.grad/.sym_grad or used as the first arg of dot/action."
805
+ )
806
+ if isinstance(obj, ParamRef):
807
+ return params
808
+ return obj
809
+
810
+ for i, node in enumerate(nodes):
811
+ op = node.op
812
+ args = node.args
813
+
814
+ if op == "lit":
815
+ vals[i] = args[0]
816
+ continue
817
+ if op == "getattr":
818
+ base = get(args[0])
819
+ name = args[1]
820
+ if isinstance(base, dict):
821
+ vals[i] = base[name]
822
+ else:
823
+ vals[i] = getattr(base, name)
824
+ continue
825
+ if op == "value":
826
+ ref = args[0]
827
+ assert isinstance(ref, FieldRef)
828
+ field = _eval_field(ref, ctx, params)
829
+ if ref.role == "unknown":
830
+ vals[i] = _eval_unknown_value(ref, field, u_elem)
831
+ else:
832
+ vals[i] = field.N
833
+ continue
834
+ if op == "grad":
835
+ ref = args[0]
836
+ assert isinstance(ref, FieldRef)
837
+ field = _eval_field(ref, ctx, params)
838
+ if ref.role == "unknown":
839
+ vals[i] = _eval_unknown_grad(ref, field, u_elem)
840
+ else:
841
+ vals[i] = field.gradN
842
+ continue
843
+ if op == "pow":
844
+ base = get(args[0])
845
+ exp = get(args[1])
846
+ vals[i] = base**exp
847
+ continue
848
+ if op == "eye":
849
+ vals[i] = jnp.eye(int(args[0]))
850
+ continue
851
+ if op == "det":
852
+ vals[i] = jnp.linalg.det(get(args[0]))
853
+ continue
854
+ if op == "inv":
855
+ vals[i] = jnp.linalg.inv(get(args[0]))
856
+ continue
857
+ if op == "transpose":
858
+ vals[i] = jnp.swapaxes(get(args[0]), -1, -2)
859
+ continue
860
+ if op == "log":
861
+ vals[i] = jnp.log(get(args[0]))
862
+ continue
863
+ if op == "surface_normal":
864
+ normal = getattr(ctx, "normal", None)
865
+ if normal is None:
866
+ raise ValueError("surface normal is not available in context")
867
+ vals[i] = normal
868
+ continue
869
+ if op == "surface_measure":
870
+ if not hasattr(ctx, "w") or not hasattr(ctx, "detJ"):
871
+ raise TypeError("surface measure requires SurfaceContext.")
872
+ vals[i] = ctx.w * ctx.detJ
873
+ continue
874
+ if op == "volume_measure":
875
+ if not hasattr(ctx, "w") or not hasattr(ctx, "test"):
876
+ raise TypeError("volume measure requires VolumeContext.")
877
+ vals[i] = ctx.w * ctx.test.detJ
878
+ continue
879
+ if op == "sym_grad":
880
+ ref = args[0]
881
+ assert isinstance(ref, FieldRef)
882
+ field = _eval_field(ref, ctx, params)
883
+ if ref.role == "unknown":
884
+ if u_elem is None:
885
+ raise ValueError("u_elem is required to evaluate unknown sym_grad.")
886
+ u_local = _extract_unknown_elem(ref, u_elem)
887
+ vals[i] = _ops.sym_grad_u(field, u_local)
888
+ else:
889
+ vals[i] = _ops.sym_grad(field)
890
+ continue
891
+ if op == "outer":
892
+ a, b = args
893
+ if not isinstance(a, FieldRef) or not isinstance(b, FieldRef):
894
+ raise TypeError("outer expects FieldRef operands.")
895
+ test, trial = a, b
896
+ vals[i] = _basis_outer(test, trial, ctx, params)
897
+ continue
898
+ if op == "add":
899
+ vals[i] = get(args[0]) + get(args[1])
900
+ continue
901
+ if op == "sub":
902
+ vals[i] = get(args[0]) - get(args[1])
903
+ continue
904
+ if op == "mul":
905
+ a = get(args[0])
906
+ b = get(args[1])
907
+ if hasattr(a, "ndim") and hasattr(b, "ndim"):
908
+ if a.ndim == 1 and b.ndim == 2 and a.shape[0] == b.shape[0]:
909
+ a = a[:, None]
910
+ elif b.ndim == 1 and a.ndim == 2 and b.shape[0] == a.shape[0]:
911
+ b = b[:, None]
912
+ elif a.ndim >= 2 and b.ndim == 1 and a.shape[0] == b.shape[0]:
913
+ b = b.reshape((b.shape[0],) + (1,) * (a.ndim - 1))
914
+ elif b.ndim >= 2 and a.ndim == 1 and b.shape[0] == a.shape[0]:
915
+ a = a.reshape((a.shape[0],) + (1,) * (b.ndim - 1))
916
+ vals[i] = a * b
917
+ continue
918
+ if op == "matmul":
919
+ a = get(args[0])
920
+ b = get(args[1])
921
+ if (
922
+ hasattr(a, "ndim")
923
+ and hasattr(b, "ndim")
924
+ and a.ndim == 3
925
+ and b.ndim == 3
926
+ and a.shape[0] == b.shape[0]
927
+ and a.shape[-1] == b.shape[-1]
928
+ ):
929
+ vals[i] = jnp.einsum("qia,qja->qij", a, b)
930
+ else:
931
+ raise TypeError(
932
+ "Expr '@' (matmul) is FEM-specific; use matmul_std(a, b) for standard matmul."
933
+ )
934
+ continue
935
+ if op == "matmul_std":
936
+ a = get(args[0])
937
+ b = get(args[1])
938
+ vals[i] = jnp.matmul(a, b)
939
+ continue
940
+ if op == "neg":
941
+ vals[i] = -get(args[0])
942
+ continue
943
+ if op == "dot":
944
+ ref = args[0]
945
+ if isinstance(ref, FieldRef):
946
+ vals[i] = _ops.dot(_eval_field(ref, ctx, params), get(args[1]))
947
+ else:
948
+ a = get(args[0])
949
+ b = get(args[1])
950
+ if (
951
+ hasattr(a, "ndim")
952
+ and hasattr(b, "ndim")
953
+ and a.ndim == 3
954
+ and b.ndim == 3
955
+ and a.shape[-1] == b.shape[-1]
956
+ ):
957
+ vals[i] = jnp.einsum("qia,qja->qij", a, b)
958
+ else:
959
+ vals[i] = jnp.matmul(a, b)
960
+ continue
961
+ if op == "sdot":
962
+ ref = args[0]
963
+ if isinstance(ref, FieldRef):
964
+ vals[i] = _ops.dot(_eval_field(ref, ctx, params), get(args[1]))
965
+ else:
966
+ a = get(args[0])
967
+ b = get(args[1])
968
+ if (
969
+ hasattr(a, "ndim")
970
+ and hasattr(b, "ndim")
971
+ and a.ndim == 3
972
+ and b.ndim == 3
973
+ and a.shape[-1] == b.shape[-1]
974
+ ):
975
+ vals[i] = jnp.einsum("qia,qja->qij", a, b)
976
+ else:
977
+ vals[i] = jnp.matmul(a, b)
978
+ continue
979
+ if op == "ddot":
980
+ if len(args) == 2:
981
+ a = get(args[0])
982
+ b = get(args[1])
983
+ if (
984
+ hasattr(a, "ndim")
985
+ and hasattr(b, "ndim")
986
+ and a.ndim == 3
987
+ and b.ndim == 3
988
+ and a.shape[0] == b.shape[0]
989
+ and a.shape[1] == b.shape[1]
990
+ ):
991
+ vals[i] = jnp.einsum("qik,qim->qkm", a, b)
992
+ else:
993
+ vals[i] = _ops.ddot(a, b)
994
+ else:
995
+ vals[i] = _ops.ddot(get(args[0]), get(args[1]), get(args[2]))
996
+ continue
997
+ if op == "inner":
998
+ a = get(args[0])
999
+ b = get(args[1])
1000
+ vals[i] = jnp.einsum("...i,...i->...", a, b)
1001
+ continue
1002
+ if op == "action":
1003
+ ref = args[0]
1004
+ assert isinstance(ref, FieldRef)
1005
+ if isinstance(args[1], FieldRef):
1006
+ raise ValueError("action expects a scalar expression; use u.val for unknowns.")
1007
+ v_field = _eval_field(ref, ctx, params)
1008
+ s = get(args[1])
1009
+ value_dim = int(getattr(v_field, "value_dim", 1))
1010
+ # action maps a test field with a scalar/vector expression into nodal space.
1011
+ if value_dim == 1:
1012
+ if v_field.N.ndim != 2:
1013
+ raise ValueError("action expects scalar test field with N shape (q, ndofs).")
1014
+ if hasattr(s, "ndim") and s.ndim not in (0, 1):
1015
+ raise ValueError("action expects scalar s with shape (q,) or scalar.")
1016
+ vals[i] = v_field.N * s
1017
+ else:
1018
+ if hasattr(s, "ndim") and s.ndim not in (1, 2):
1019
+ raise ValueError("action expects vector s with shape (q, dim) or (dim,).")
1020
+ vals[i] = _ops.dot(v_field, s)
1021
+ continue
1022
+ if op == "gaction":
1023
+ ref = args[0]
1024
+ assert isinstance(ref, FieldRef)
1025
+ v_field = _eval_field(ref, ctx, params)
1026
+ q = get(args[1])
1027
+ # gaction maps a flux-like expression to nodal space via test gradients.
1028
+ if v_field.gradN.ndim != 3:
1029
+ raise ValueError("gaction expects test gradient with shape (q, ndofs, dim).")
1030
+ if not hasattr(q, "ndim"):
1031
+ raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
1032
+ if q.ndim == 2:
1033
+ vals[i] = jnp.einsum("qaj,qj->qa", v_field.gradN, q)
1034
+ elif q.ndim == 3:
1035
+ if int(getattr(v_field, "value_dim", 1)) == 1:
1036
+ raise ValueError("gaction tensor flux requires vector test field.")
1037
+ vals[i] = jnp.einsum("qij,qaj->qai", q, v_field.gradN).reshape(q.shape[0], -1)
1038
+ else:
1039
+ raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
1040
+ continue
1041
+ if op == "transpose_last2":
1042
+ vals[i] = _ops.transpose_last2(get(args[0]))
1043
+ continue
1044
+ if op == "einsum":
1045
+ subscripts = args[0]
1046
+ operands = [get(arg) for arg in args[1:]]
1047
+ vals[i] = jnp.einsum(subscripts, *operands)
1048
+ continue
1049
+
1050
+ raise ValueError(f"Unknown Expr op: {op}")
1051
+
1052
+ return vals[index[id(plan.expr)]]
1053
+
1054
+
437
1055
  def compile_surface_linear(fn):
438
1056
  """get_compiled a surface linear form into a kernel (ctx, params) -> ndarray."""
439
1057
  if isinstance(fn, Expr):
@@ -441,26 +1059,27 @@ def compile_surface_linear(fn):
441
1059
  else:
442
1060
  v = test_ref()
443
1061
  p = param_ref()
444
- expr = None
445
- try:
446
- expr = fn(v, p)
447
- except TypeError:
448
- try:
449
- expr = fn(v)
450
- except TypeError:
451
- expr = None
1062
+ expr = _call_user(fn, v, params=p)
452
1063
 
1064
+ expr = _as_expr(expr)
453
1065
  if not isinstance(expr, Expr):
454
1066
  raise ValueError("Surface linear form must return an Expr; use ds() in the expression.")
455
1067
 
456
- includes_measure = _expr_contains(expr, "surface_measure")
457
- if not includes_measure:
1068
+ surface_count = _count_op(expr, "surface_measure")
1069
+ volume_count = _count_op(expr, "volume_measure")
1070
+ if surface_count == 0:
458
1071
  raise ValueError("Surface linear form must include ds().")
1072
+ if surface_count > 1:
1073
+ raise ValueError("Surface linear form must include ds() exactly once.")
1074
+ if volume_count > 0:
1075
+ raise ValueError("Surface linear form must not include dOmega().")
1076
+
1077
+ plan = make_eval_plan(expr)
459
1078
 
460
1079
  def _form(ctx, params):
461
- return _as_expr(expr).eval(ctx, params)
1080
+ return eval_with_plan(plan, ctx, params)
462
1081
 
463
- _form._includes_measure = includes_measure # type: ignore[attr-defined]
1082
+ _form._includes_measure = True # type: ignore[attr-defined]
464
1083
  return _form
465
1084
 
466
1085
 
@@ -524,25 +1143,33 @@ def compile_residual(fn):
524
1143
  v = test_ref()
525
1144
  u = unknown_ref()
526
1145
  p = param_ref()
527
- try:
528
- expr = fn(v, u, p)
529
- except TypeError:
530
- expr = fn(v, u)
1146
+ expr = _call_user(fn, v, u, params=p)
1147
+ expr = _as_expr(expr)
1148
+ if not isinstance(expr, Expr):
1149
+ raise TypeError("Residual form must return an Expr.")
531
1150
 
532
- includes_measure = _expr_contains(expr, "volume_measure")
533
- if not includes_measure:
1151
+ volume_count = _count_op(expr, "volume_measure")
1152
+ surface_count = _count_op(expr, "surface_measure")
1153
+ if volume_count == 0:
534
1154
  raise ValueError("Volume residual form must include dOmega().")
1155
+ if volume_count > 1:
1156
+ raise ValueError("Volume residual form must include dOmega() exactly once.")
1157
+ if surface_count > 0:
1158
+ raise ValueError("Volume residual form must not include ds().")
1159
+
1160
+ plan = make_eval_plan(expr)
535
1161
 
536
1162
  def _form(ctx, u_elem, params):
537
- return _as_expr(expr).eval(ctx, params, u_elem=u_elem)
1163
+ return eval_with_plan(plan, ctx, params, u_elem=u_elem)
538
1164
 
539
- _form._includes_measure = includes_measure
1165
+ _form._includes_measure = True
540
1166
  return _form
541
1167
 
542
1168
 
543
1169
  def compile_mixed_residual(residuals: dict[str, Callable]):
544
1170
  """Compile mixed residuals keyed by field name."""
545
1171
  compiled = {}
1172
+ plans = {}
546
1173
  includes_measure = {}
547
1174
  for name, fn in residuals.items():
548
1175
  if isinstance(fn, Expr):
@@ -551,17 +1178,24 @@ def compile_mixed_residual(residuals: dict[str, Callable]):
551
1178
  v = test_ref(name)
552
1179
  u = unknown_ref(name)
553
1180
  p = param_ref()
554
- try:
555
- expr = fn(v, u, p)
556
- except TypeError:
557
- expr = fn(v, u)
558
- compiled[name] = _as_expr(expr)
559
- includes_measure[name] = _expr_contains(compiled[name], "volume_measure")
560
- if not includes_measure[name]:
1181
+ expr = _call_user(fn, v, u, params=p)
1182
+ expr = _as_expr(expr)
1183
+ if not isinstance(expr, Expr):
1184
+ raise TypeError(f"Mixed residual '{name}' must return an Expr.")
1185
+ compiled[name] = expr
1186
+ plans[name] = make_eval_plan(expr)
1187
+ volume_count = _count_op(compiled[name], "volume_measure")
1188
+ surface_count = _count_op(compiled[name], "surface_measure")
1189
+ includes_measure[name] = volume_count == 1
1190
+ if volume_count == 0:
561
1191
  raise ValueError(f"Mixed residual '{name}' must include dOmega().")
1192
+ if volume_count > 1:
1193
+ raise ValueError(f"Mixed residual '{name}' must include dOmega() exactly once.")
1194
+ if surface_count > 0:
1195
+ raise ValueError(f"Mixed residual '{name}' must not include ds().")
562
1196
 
563
1197
  def _form(ctx, u_elem, params):
564
- return {name: expr.eval(ctx, params, u_elem=u_elem) for name, expr in compiled.items()}
1198
+ return {name: eval_with_plan(plan, ctx, params, u_elem=u_elem) for name, plan in plans.items()}
565
1199
 
566
1200
  _form._includes_measure = includes_measure
567
1201
  return _form
@@ -579,220 +1213,14 @@ class MixedWeakForm:
579
1213
  return compile_mixed_residual(self.residuals)
580
1214
 
581
1215
 
582
- def _eval_expr(expr: Expr, ctx, params, u_elem=None):
583
- op = expr.op
584
- args = expr.args
585
-
586
- if op == "lit":
587
- return args[0]
588
- if op == "param":
589
- return params
590
- if op == "getattr":
591
- base = _eval_value(args[0], ctx, params, u_elem=u_elem)
592
- name = args[1]
593
- if isinstance(base, dict):
594
- return base[name]
595
- return getattr(base, name)
596
- if op == "field":
597
- role, name = args
598
- if name is not None:
599
- if role == "trial" and getattr(ctx, "trial_fields", None) is not None:
600
- if name in ctx.trial_fields:
601
- return ctx.trial_fields[name]
602
- if role == "test" and getattr(ctx, "test_fields", None) is not None:
603
- if name in ctx.test_fields:
604
- return ctx.test_fields[name]
605
- if role == "unknown" and getattr(ctx, "unknown_fields", None) is not None:
606
- if name in ctx.unknown_fields:
607
- return ctx.unknown_fields[name]
608
- fields = getattr(ctx, "fields", None)
609
- if fields is not None and name in fields:
610
- group = fields[name]
611
- if isinstance(group, dict):
612
- if role in group:
613
- return group[role]
614
- if "field" in group:
615
- return group["field"]
616
- return group
617
- if role == "trial":
618
- return ctx.trial
619
- if role == "test":
620
- return ctx.test
621
- if role == "unknown":
622
- return getattr(ctx, "unknown", ctx.trial)
623
- raise ValueError(f"Unknown field role: {role}")
624
- if op == "value":
625
- field = _eval_field(args[0], ctx, params)
626
- if isinstance(args[0], FieldRef) and args[0].role == "unknown":
627
- return _eval_unknown_value(args[0], field, u_elem)
628
- return field.N
629
- if op == "grad":
630
- field = _eval_field(args[0], ctx, params)
631
- if isinstance(args[0], FieldRef) and args[0].role == "unknown":
632
- return _eval_unknown_grad(args[0], field, u_elem)
633
- return field.gradN
634
- if op == "pow":
635
- base = _eval_value(args[0], ctx, params, u_elem=u_elem)
636
- exp = _eval_value(args[1], ctx, params, u_elem=u_elem)
637
- return base**exp
638
- if op == "eye":
639
- return jnp.eye(int(args[0]))
640
- if op == "det":
641
- return jnp.linalg.det(_eval_value(args[0], ctx, params, u_elem=u_elem))
642
- if op == "inv":
643
- return jnp.linalg.inv(_eval_value(args[0], ctx, params, u_elem=u_elem))
644
- if op == "transpose":
645
- return jnp.swapaxes(_eval_value(args[0], ctx, params, u_elem=u_elem), -1, -2)
646
- if op == "log":
647
- return jnp.log(_eval_value(args[0], ctx, params, u_elem=u_elem))
648
- if op == "surface_normal":
649
- normal = getattr(ctx, "normal", None)
650
- if normal is None:
651
- raise ValueError("surface normal is not available in context")
652
- return normal
653
- if op == "surface_measure":
654
- if not hasattr(ctx, "w") or not hasattr(ctx, "detJ"):
655
- raise ValueError("surface measure requires surface context with w and detJ.")
656
- return ctx.w * ctx.detJ
657
- if op == "volume_measure":
658
- if not hasattr(ctx, "w") or not hasattr(ctx, "test"):
659
- raise ValueError("volume measure requires FormContext with w and test.detJ.")
660
- return ctx.w * ctx.test.detJ
661
- if op == "sym_grad":
662
- field = _eval_field(args[0], ctx, params)
663
- if isinstance(args[0], FieldRef) and args[0].role == "unknown":
664
- if u_elem is None:
665
- raise ValueError("u_elem is required to evaluate unknown sym_grad.")
666
- u_local = _extract_unknown_elem(args[0], u_elem)
667
- return _ops.sym_grad_u(field, u_local)
668
- return _ops.sym_grad(field)
669
- if op == "outer":
670
- a, b = args
671
- if not isinstance(a, FieldRef) or not isinstance(b, FieldRef):
672
- raise TypeError("outer expects FieldRef operands.")
673
- if a.role == b.role:
674
- raise ValueError("outer requires one trial and one test field.")
675
- test = a if a.role == "test" else b
676
- trial = b if a.role == "test" else a
677
- v_field = _eval_field(test, ctx, params)
678
- u_field = _eval_field(trial, ctx, params)
679
- if getattr(v_field, "value_dim", 1) != 1 or getattr(u_field, "value_dim", 1) != 1:
680
- raise ValueError("u*v is only defined for scalar fields; use dot/inner for vectors.")
681
- vN = v_field.N
682
- uN = u_field.N
683
- return jnp.einsum("qi,qj->qij", vN, uN)
684
- if op == "add":
685
- return _eval_value(args[0], ctx, params, u_elem=u_elem) + _eval_value(args[1], ctx, params, u_elem=u_elem)
686
- if op == "sub":
687
- return _eval_value(args[0], ctx, params, u_elem=u_elem) - _eval_value(args[1], ctx, params, u_elem=u_elem)
688
- if op == "mul":
689
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
690
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
691
- if hasattr(a, "ndim") and hasattr(b, "ndim"):
692
- if a.ndim == 1 and b.ndim == 2 and a.shape[0] == b.shape[0]:
693
- a = a[:, None]
694
- elif b.ndim == 1 and a.ndim == 2 and b.shape[0] == a.shape[0]:
695
- b = b[:, None]
696
- elif a.ndim >= 2 and b.ndim == 1 and a.shape[0] == b.shape[0]:
697
- b = b.reshape((b.shape[0],) + (1,) * (a.ndim - 1))
698
- elif b.ndim >= 2 and a.ndim == 1 and b.shape[0] == a.shape[0]:
699
- a = a.reshape((a.shape[0],) + (1,) * (b.ndim - 1))
700
- return a * b
701
- if op == "matmul":
702
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
703
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
704
- if (
705
- hasattr(a, "ndim")
706
- and hasattr(b, "ndim")
707
- and a.ndim == 3
708
- and b.ndim == 3
709
- and a.shape[0] == b.shape[0]
710
- and a.shape[-1] == b.shape[-1]
711
- ):
712
- return jnp.einsum("qia,qja->qij", a, b)
713
- return a @ b
714
- if op == "matmul_std":
715
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
716
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
717
- return jnp.matmul(a, b)
718
- if op == "neg":
719
- return -_eval_value(args[0], ctx, params, u_elem=u_elem)
720
- if op == "dot":
721
- if isinstance(args[0], FieldRef):
722
- return _ops.dot(_eval_field(args[0], ctx, params), _eval_value(args[1], ctx, params, u_elem=u_elem))
723
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
724
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
725
- if hasattr(a, "ndim") and hasattr(b, "ndim") and a.ndim == 3 and b.ndim == 3 and a.shape[-1] == b.shape[-1]:
726
- return jnp.einsum("qia,qja->qij", a, b)
727
- return jnp.matmul(a, b)
728
- if op == "sdot":
729
- if isinstance(args[0], FieldRef):
730
- return _ops.dot(_eval_field(args[0], ctx, params), _eval_value(args[1], ctx, params, u_elem=u_elem))
731
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
732
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
733
- if hasattr(a, "ndim") and hasattr(b, "ndim") and a.ndim == 3 and b.ndim == 3 and a.shape[-1] == b.shape[-1]:
734
- return jnp.einsum("qia,qja->qij", a, b)
735
- return jnp.matmul(a, b)
736
- if op == "ddot":
737
- if len(args) == 2:
738
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
739
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
740
- if (
741
- hasattr(a, "ndim")
742
- and hasattr(b, "ndim")
743
- and a.ndim == 3
744
- and b.ndim == 3
745
- and a.shape[0] == b.shape[0]
746
- and a.shape[1] == b.shape[1]
747
- ):
748
- return jnp.einsum("qik,qim->qkm", a, b)
749
- return _ops.ddot(a, b)
750
- return _ops.ddot(
751
- _eval_value(args[0], ctx, params, u_elem=u_elem),
752
- _eval_value(args[1], ctx, params, u_elem=u_elem),
753
- _eval_value(args[2], ctx, params, u_elem=u_elem),
754
- )
755
- if op == "inner":
756
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
757
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
758
- return jnp.einsum("...i,...i->...", a, b)
759
- if op == "action":
760
- if isinstance(args[1], FieldRef):
761
- raise ValueError("action expects a scalar expression; use u.val for unknowns.")
762
- v_field = _eval_field(args[0], ctx, params)
763
- s = _eval_value(args[1], ctx, params, u_elem=u_elem)
764
- value_dim = int(getattr(v_field, "value_dim", 1))
765
- if value_dim == 1:
766
- if v_field.N.ndim != 2:
767
- raise ValueError("action expects scalar test field with N shape (q, ndofs).")
768
- if hasattr(s, "ndim") and s.ndim not in (0, 1):
769
- raise ValueError("action expects scalar s with shape (q,) or scalar.")
770
- return v_field.N * s
771
- if hasattr(s, "ndim") and s.ndim not in (1, 2):
772
- raise ValueError("action expects vector s with shape (q, dim) or (dim,).")
773
- return _ops.dot(v_field, s)
774
- if op == "gaction":
775
- v_field = _eval_field(args[0], ctx, params)
776
- q = _eval_value(args[1], ctx, params, u_elem=u_elem)
777
- if v_field.gradN.ndim != 3:
778
- raise ValueError("gaction expects test gradient with shape (q, ndofs, dim).")
779
- if not hasattr(q, "ndim"):
780
- raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
781
- if q.ndim == 2:
782
- return jnp.einsum("qaj,qj->qa", v_field.gradN, q)
783
- if q.ndim == 3:
784
- if int(getattr(v_field, "value_dim", 1)) == 1:
785
- raise ValueError("gaction tensor flux requires vector test field.")
786
- return jnp.einsum("qij,qaj->qai", q, v_field.gradN).reshape(q.shape[0], -1)
787
- raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
788
- if op == "transpose_last2":
789
- return _ops.transpose_last2(_eval_value(args[0], ctx, params, u_elem=u_elem))
790
- if op == "einsum":
791
- subscripts = args[0]
792
- operands = [_eval_value(arg, ctx, params, u_elem=u_elem) for arg in args[1:]]
793
- return jnp.einsum(subscripts, *operands)
794
-
795
- raise ValueError(f"Unknown Expr op: {op}")
1216
def _eval_expr(
    expr: Expr,
    ctx: VolumeContext | SurfaceContext,
    params: ParamsLike,
    u_elem: UElement | None = None,
):
    """Evaluate ``expr`` against ``ctx`` in a single call.

    An evaluation plan is built for ``expr`` with ``make_eval_plan`` and is
    run immediately through ``eval_with_plan``. A fresh plan is built on
    every call.

    Args:
        expr: Expression to evaluate.
        ctx: Volume or surface context forwarded to the plan evaluation.
        params: Parameters forwarded to the plan evaluation.
        u_elem: Optional value forwarded as the ``u_elem`` keyword argument
            (presumably element-local unknown data; confirm against
            ``eval_with_plan``).

    Returns:
        Whatever ``eval_with_plan`` returns for the plan built from ``expr``.
    """
    return eval_with_plan(make_eval_plan(expr), ctx, params, u_elem=u_elem)
796
1224
 
797
1225
 
798
1226
  __all__ = [
@@ -812,6 +1240,7 @@ __all__ = [
812
1240
  "compile_mixed_residual",
813
1241
  "grad",
814
1242
  "sym_grad",
1243
+ "outer",
815
1244
  "dot",
816
1245
  "ddot",
817
1246
  "inner",
@@ -824,5 +1253,6 @@ __all__ = [
824
1253
  "log",
825
1254
  "transpose_last2",
826
1255
  "matmul",
1256
+ "matmul_std",
827
1257
  "einsum",
828
1258
  ]