fluxfem 0.1.1a0__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fluxfem/core/weakform.py CHANGED
@@ -1,82 +1,312 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Any, Callable
4
+ from typing import Any, Callable, Iterator, Literal, get_args
5
+ import inspect
6
+
7
+ import numpy as np
5
8
 
6
9
  import jax.numpy as jnp
7
10
  import jax
8
11
 
9
12
  from ..physics import operators as _ops
13
+ from .context_types import FormFieldLike, ParamsLike, SurfaceContext, UElement, VolumeContext
14
+
15
+
16
+ OpName = Literal[
17
+ "lit",
18
+ "getattr",
19
+ "value",
20
+ "grad",
21
+ "pow",
22
+ "eye",
23
+ "det",
24
+ "inv",
25
+ "transpose",
26
+ "log",
27
+ "surface_normal",
28
+ "surface_measure",
29
+ "volume_measure",
30
+ "sym_grad",
31
+ "outer",
32
+ "add",
33
+ "sub",
34
+ "mul",
35
+ "matmul",
36
+ "matmul_std",
37
+ "neg",
38
+ "dot",
39
+ "sdot",
40
+ "ddot",
41
+ "inner",
42
+ "action",
43
+ "gaction",
44
+ "transpose_last2",
45
+ "einsum",
46
+ ]
10
47
 
48
+ # Use OpName as the single source of truth for valid ops.
49
+ _OP_NAMES: frozenset[str] = frozenset(get_args(OpName))
50
+
51
+
52
+ _PRECEDENCE: dict[str, int] = {
53
+ "add": 10,
54
+ "sub": 10,
55
+ "mul": 20,
56
+ "matmul": 20,
57
+ "matmul_std": 20,
58
+ "inner": 20,
59
+ "dot": 20,
60
+ "sdot": 20,
61
+ "ddot": 20,
62
+ "pow": 30,
63
+ "neg": 40,
64
+ "transpose": 50,
65
+ }
66
+
67
+
68
+ def _pretty_render_arg(arg, prec: int | None = None) -> str:
69
+ if isinstance(arg, Expr):
70
+ return _pretty_expr(arg, prec or 0)
71
+ if isinstance(arg, FieldRef):
72
+ if arg.name is None:
73
+ return f"{arg.role}"
74
+ return f"{arg.role}:{arg.name}"
75
+ if isinstance(arg, ParamRef):
76
+ return "param"
77
+ return repr(arg)
78
+
79
+
80
+ def _pretty_wrap(text: str, prec: int, parent_prec: int) -> str:
81
+ if prec < parent_prec:
82
+ return f"({text})"
83
+ return text
84
+
85
+
86
+ def _pretty_expr(expr: Expr, parent_prec: int = 0) -> str:
87
+ op = expr.op
88
+ args = expr.args
11
89
 
90
+ if op == "lit":
91
+ return repr(args[0])
92
+ if op == "getattr":
93
+ base = _pretty_render_arg(args[0], _PRECEDENCE.get("transpose", 50))
94
+ return f"{base}.{args[1]}"
95
+ if op == "value":
96
+ return f"val({_pretty_render_arg(args[0])})"
97
+ if op == "grad":
98
+ return f"grad({_pretty_render_arg(args[0])})"
99
+ if op == "sym_grad":
100
+ return f"sym_grad({_pretty_render_arg(args[0])})"
101
+ if op == "neg":
102
+ inner = _pretty_render_arg(args[0], _PRECEDENCE["neg"])
103
+ return _pretty_wrap(f"-{inner}", _PRECEDENCE["neg"], parent_prec)
104
+ if op == "transpose":
105
+ inner = _pretty_render_arg(args[0], _PRECEDENCE["transpose"])
106
+ return _pretty_wrap(f"{inner}.T", _PRECEDENCE["transpose"], parent_prec)
107
+ if op == "pow":
108
+ base = _pretty_render_arg(args[0], _PRECEDENCE["pow"])
109
+ exp = _pretty_render_arg(args[1], _PRECEDENCE["pow"] + 1)
110
+ return _pretty_wrap(f"{base}**{exp}", _PRECEDENCE["pow"], parent_prec)
111
+ if op in {"add", "sub", "mul", "matmul", "dot", "sdot", "ddot"}:
112
+ left = _pretty_render_arg(args[0], _PRECEDENCE[op])
113
+ right = _pretty_render_arg(args[1], _PRECEDENCE[op] + 1)
114
+ symbol = {
115
+ "add": "+",
116
+ "sub": "-",
117
+ "mul": "*",
118
+ "matmul": "@",
119
+ "inner": "|",
120
+ "dot": "dot",
121
+ "sdot": "sdot",
122
+ "ddot": "ddot",
123
+ }[op]
124
+ if symbol in {"dot", "sdot", "ddot"}:
125
+ text = f"{symbol}({left}, {right})"
126
+ else:
127
+ text = f"{left} {symbol} {right}"
128
+ return _pretty_wrap(text, _PRECEDENCE[op], parent_prec)
129
+ if op == "inner":
130
+ return f"inner({_pretty_render_arg(args[0])}, {_pretty_render_arg(args[1])})"
131
+ if op in {"action", "gaction"}:
132
+ return f"{op}({_pretty_render_arg(args[0])}, {_pretty_render_arg(args[1])})"
133
+ if op == "matmul_std":
134
+ return f"matmul_std({_pretty_render_arg(args[0])}, {_pretty_render_arg(args[1])})"
135
+ if op == "outer":
136
+ return f"outer({_pretty_render_arg(args[0])}, {_pretty_render_arg(args[1])})"
137
+ if op in {
138
+ "eye",
139
+ "det",
140
+ "inv",
141
+ "log",
142
+ "surface_normal",
143
+ "surface_measure",
144
+ "volume_measure",
145
+ "transpose_last2",
146
+ "einsum",
147
+ }:
148
+ rendered = ", ".join(_pretty_render_arg(arg) for arg in args)
149
+ return f"{op}({rendered})"
150
+ rendered = ", ".join(_pretty_render_arg(arg) for arg in args)
151
+ return f"{op}({rendered})"
152
+
153
+
154
+ def _as_expr(obj) -> Expr | FieldRef | ParamRef:
155
+ """Normalize inputs into Expr/FieldRef/ParamRef nodes."""
156
+ if isinstance(obj, Expr):
157
+ return obj
158
+ if isinstance(obj, FieldRef):
159
+ return obj
160
+ if isinstance(obj, ParamRef):
161
+ return obj
162
+ if isinstance(obj, (int, float, bool, str)):
163
+ return Expr("lit", obj)
164
+ if isinstance(obj, np.generic):
165
+ return Expr("lit", obj.item())
166
+ if isinstance(obj, tuple):
167
+ try:
168
+ hash(obj)
169
+ except TypeError as exc:
170
+ raise TypeError(
171
+ "Expr tuple literal must be hashable; use only immutable items."
172
+ ) from exc
173
+ return Expr("lit", obj)
174
+ raise TypeError(
175
+ "Expr literal must be a scalar or hashable tuple. "
176
+ "Arrays are not allowed; pass them via params (ParamRef/params.xxx)."
177
+ )
178
+
179
+
180
+ @dataclass(frozen=True, slots=True, init=False)
12
181
  class Expr:
13
- """Expression tree node evaluated against a FormContext."""
182
+ """Expression tree node evaluated against a FormContext.
183
+
184
+ Compile flow (recommended):
185
+ - build an Expr via operators/refs
186
+ - compile_* builds an EvalPlan (postorder nodes + index)
187
+ - eval_with_plan(plan, ctx, params, u_elem) evaluates per element
188
+
189
+ Expr.eval is a debug/single-shot path that creates a plan on demand.
190
+ """
14
191
 
15
- def __init__(self, op: str, *args):
192
+ op: OpName
193
+ args: tuple[Any, ...]
194
+
195
+ def __init__(self, op: OpName, *args):
196
+ if op not in _OP_NAMES:
197
+ raise ValueError(f"Unknown Expr op: {op!r}")
16
198
  object.__setattr__(self, "op", op)
17
199
  object.__setattr__(self, "args", args)
18
200
 
19
201
  def eval(self, ctx, params=None, u_elem=None):
202
+ """Evaluate the expression against a context (debug/single-shot path)."""
20
203
  return _eval_expr(self, ctx, params, u_elem=u_elem)
21
204
 
205
+ def children(self) -> tuple[Any, ...]:
206
+ """Return direct child nodes (Expr/FieldRef/ParamRef) for traversal."""
207
+ return tuple(arg for arg in self.args if isinstance(arg, (Expr, FieldRef, ParamRef)))
208
+
209
+ def walk(self) -> Iterator[Any]:
210
+ """Depth-first walk over nodes, including leaf FieldRef/ParamRef."""
211
+ yield self
212
+ for child in self.children():
213
+ if isinstance(child, Expr):
214
+ yield from child.walk()
215
+ else:
216
+ yield child
217
+
218
+ def postorder(self) -> Iterator[Any]:
219
+ """Postorder walk over nodes, including leaf FieldRef/ParamRef."""
220
+ for child in self.children():
221
+ if isinstance(child, Expr):
222
+ yield from child.postorder()
223
+ else:
224
+ yield child
225
+ yield self
226
+
227
+ def postorder_expr(self) -> Iterator["Expr"]:
228
+ """Postorder walk over Expr nodes only (for eval planning)."""
229
+ for arg in self.args:
230
+ if isinstance(arg, Expr):
231
+ yield from arg.postorder_expr()
232
+ yield self
233
+
22
234
  def _binop(self, other, op):
23
235
  return Expr(op, self, _as_expr(other))
24
236
 
25
237
  def __add__(self, other):
238
+ """Add expressions: `a + b`."""
26
239
  return self._binop(other, "add")
27
240
 
28
241
  def __radd__(self, other):
29
- return _as_expr(other)._binop(self, "add")
242
+ """Right-add expressions: `1 + expr`."""
243
+ return Expr("add", _as_expr(other), self)
30
244
 
31
245
  def __sub__(self, other):
246
+ """Subtract expressions: `a - b`."""
32
247
  return self._binop(other, "sub")
33
248
 
34
249
  def __rsub__(self, other):
35
- return _as_expr(other)._binop(self, "sub")
250
+ """Right-subtract expressions: `1 - expr`."""
251
+ return Expr("sub", _as_expr(other), self)
36
252
 
37
253
  def __mul__(self, other):
254
+ """Multiply expressions: `a * b`."""
38
255
  return self._binop(other, "mul")
39
256
 
40
257
  def __rmul__(self, other):
41
- return _as_expr(other)._binop(self, "mul")
258
+ """Right-multiply expressions: `2 * expr`."""
259
+ return Expr("mul", _as_expr(other), self)
42
260
 
43
261
  def __matmul__(self, other):
262
+ """Matrix product: `a @ b` (FEM-specific contraction semantics)."""
44
263
  return self._binop(other, "matmul")
45
264
 
46
265
  def __rmatmul__(self, other):
47
- return _as_expr(other)._binop(self, "matmul")
266
+ """Right-matmul: `A @ expr`."""
267
+ return Expr("matmul", _as_expr(other), self)
48
268
 
49
269
  def __or__(self, other):
50
- return self._binop(other, "inner")
270
+ """Tensor inner product: `a | b` (use .val/.grad for FieldRef)."""
271
+ if isinstance(other, FieldRef):
272
+ raise TypeError("FieldRef | FieldRef is not supported; use outer(test, trial).")
273
+ return Expr("inner", self, _as_expr(other))
51
274
 
52
275
  def __ror__(self, other):
53
- return _as_expr(other)._binop(self, "inner")
276
+ """Tensor inner product: `a | b` (use .val/.grad for FieldRef)."""
277
+ if isinstance(other, FieldRef):
278
+ raise TypeError("FieldRef | FieldRef is not supported; use outer(test, trial).")
279
+ return Expr("inner", _as_expr(other), self)
54
280
 
55
281
  def __pow__(self, power, modulo=None):
282
+ """Power: `a ** p` (no modulo support)."""
56
283
  if modulo is not None:
57
284
  raise ValueError("modulo is not supported for Expr exponentiation.")
58
285
  return Expr("pow", self, _as_expr(power))
59
286
 
60
287
  def __neg__(self):
288
+ """Unary negation: `-expr`."""
61
289
  return Expr("neg", self)
62
290
 
63
291
  @property
64
292
  def T(self):
293
+ """Transpose view: `expr.T`."""
65
294
  return Expr("transpose", self)
66
295
 
296
+ def __repr__(self) -> str:
297
+ return self.pretty()
67
298
 
68
- @dataclass(frozen=True)
69
- class FieldRef(Expr):
299
+ def pretty(self) -> str:
300
+ return _pretty_expr(self)
301
+
302
+
303
+ @dataclass(frozen=True, slots=True)
304
+ class FieldRef:
70
305
  """Symbolic reference to trial/test/unknown field, optionally by name."""
71
306
 
72
307
  role: str
73
308
  name: str | None = None
74
309
 
75
- def __init__(self, role: str, name: str | None = None):
76
- object.__setattr__(self, "role", role)
77
- object.__setattr__(self, "name", name)
78
- super().__init__("field", role, name)
79
-
80
310
  @property
81
311
  def val(self):
82
312
  return Expr("value", self)
@@ -91,12 +321,18 @@ class FieldRef(Expr):
91
321
 
92
322
  def __mul__(self, other):
93
323
  if isinstance(other, FieldRef):
94
- return Expr("outer", self, other)
324
+ raise TypeError(
325
+ "FieldRef * FieldRef is ambiguous; use outer(v, u) (test, trial), "
326
+ "action(v, s), or dot(v, q)."
327
+ )
95
328
  return Expr("mul", Expr("value", self), _as_expr(other))
96
329
 
97
330
  def __rmul__(self, other):
98
331
  if isinstance(other, FieldRef):
99
- return Expr("outer", other, self)
332
+ raise TypeError(
333
+ "FieldRef * FieldRef is ambiguous; use outer(v, u) (test, trial), "
334
+ "action(v, s), or dot(v, q)."
335
+ )
100
336
  return Expr("mul", _as_expr(other), Expr("value", self))
101
337
 
102
338
  def __add__(self, other):
@@ -113,22 +349,23 @@ class FieldRef(Expr):
113
349
 
114
350
  def __or__(self, other):
115
351
  if isinstance(other, FieldRef):
116
- return Expr("inner", self, other)
117
- return Expr("sdot", self, _as_expr(other))
352
+ raise TypeError(
353
+ "FieldRef | FieldRef is not supported; use outer(test, trial) for basis kernels."
354
+ )
355
+ return Expr("dot", self, _as_expr(other))
118
356
 
119
357
  def __ror__(self, other):
120
358
  if isinstance(other, FieldRef):
121
- return Expr("inner", other, self)
122
- return Expr("sdot", _as_expr(other), self)
359
+ raise TypeError(
360
+ "FieldRef | FieldRef is not supported; use outer(test, trial) for basis kernels."
361
+ )
362
+ return Expr("dot", _as_expr(other), self)
123
363
 
124
364
 
125
- @dataclass(frozen=True)
126
- class ParamRef(Expr):
365
+ @dataclass(frozen=True, slots=True)
366
+ class ParamRef:
127
367
  """Symbolic reference to params passed into the kernel."""
128
368
 
129
- def __init__(self):
130
- super().__init__("param")
131
-
132
369
  def __getattr__(self, name: str):
133
370
  return Expr("getattr", self, name)
134
371
 
@@ -179,13 +416,11 @@ def param_ref() -> ParamRef:
179
416
  return ParamRef()
180
417
 
181
418
 
182
- def _as_expr(obj) -> Expr:
183
- if isinstance(obj, Expr):
184
- return obj
185
- return Expr("lit", obj)
186
-
187
-
188
- def _eval_field(obj: Any, ctx, params):
419
+ def _eval_field(
420
+ obj: Any,
421
+ ctx: VolumeContext | SurfaceContext,
422
+ params: ParamsLike,
423
+ ) -> FormFieldLike:
189
424
  if isinstance(obj, FieldRef):
190
425
  if obj.name is not None:
191
426
  mixed_fields = getattr(ctx, "fields", None)
@@ -226,25 +461,23 @@ def _eval_field(obj: Any, ctx, params):
226
461
  if obj.role == "unknown":
227
462
  return getattr(ctx, "unknown", ctx.trial)
228
463
  raise ValueError(f"Unknown field role: {obj.role}")
229
- if isinstance(obj, Expr):
230
- val = obj.eval(ctx, params)
231
- if hasattr(val, "N"):
232
- return val
233
464
  raise TypeError("Expected a field reference for this operator.")
234
465
 
235
466
 
236
- def _eval_value(obj: Any, ctx, params, u_elem=None):
237
- if isinstance(obj, FieldRef):
238
- field = _eval_field(obj, ctx, params)
239
- if obj.role == "unknown":
240
- return _eval_unknown_value(obj, field, u_elem)
241
- return field.N
242
- if isinstance(obj, Expr):
243
- return obj.eval(ctx, params, u_elem=u_elem)
244
- return obj
467
+ # def _eval_value(obj: Any, ctx, params, u_elem=None):
468
+ # if isinstance(obj, FieldRef):
469
+ # field = _eval_field(obj, ctx, params)
470
+ # if obj.role == "unknown":
471
+ # return _eval_unknown_value(obj, field, u_elem)
472
+ # return field.N
473
+ # if isinstance(obj, ParamRef):
474
+ # return params
475
+ # if isinstance(obj, Expr):
476
+ # return obj.eval(ctx, params, u_elem=u_elem)
477
+ # return obj
245
478
 
246
479
 
247
- def _extract_unknown_elem(field_ref: FieldRef, u_elem):
480
+ def _extract_unknown_elem(field_ref: FieldRef, u_elem: UElement):
248
481
  if u_elem is None:
249
482
  raise ValueError("u_elem is required to evaluate unknown field value.")
250
483
  if isinstance(u_elem, dict):
@@ -255,7 +488,17 @@ def _extract_unknown_elem(field_ref: FieldRef, u_elem):
255
488
  return u_elem
256
489
 
257
490
 
258
- def _eval_unknown_value(field_ref: FieldRef, field, u_elem):
491
+ def _basis_outer(test: FieldRef, trial: FieldRef, ctx, params):
492
+ v_field = _eval_field(test, ctx, params)
493
+ u_field = _eval_field(trial, ctx, params)
494
+ if getattr(v_field, "value_dim", 1) != 1 or getattr(u_field, "value_dim", 1) != 1:
495
+ raise ValueError(
496
+ "inner/outer is only defined for scalar fields; use dot/action/einsum for vector/tensor cases."
497
+ )
498
+ return jnp.einsum("qi,qj->qij", v_field.N, u_field.N)
499
+
500
+
501
+ def _eval_unknown_value(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
259
502
  u_local = _extract_unknown_elem(field_ref, u_elem)
260
503
  value_dim = int(getattr(field, "value_dim", 1))
261
504
  if value_dim == 1:
@@ -264,7 +507,7 @@ def _eval_unknown_value(field_ref: FieldRef, field, u_elem):
264
507
  return jnp.einsum("qa,ai->qi", field.N, u_nodes)
265
508
 
266
509
 
267
- def _eval_unknown_grad(field_ref: FieldRef, field, u_elem):
510
+ def _eval_unknown_grad(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
268
511
  u_local = _extract_unknown_elem(field_ref, u_elem)
269
512
  if u_local is None:
270
513
  raise ValueError("u_elem is required to evaluate unknown field gradient.")
@@ -285,6 +528,15 @@ def sym_grad(field) -> Expr:
285
528
  return Expr("sym_grad", _as_expr(field))
286
529
 
287
530
 
531
+ def outer(a, b) -> Expr:
532
+ """Outer product of scalar fields: `outer(v, u)` (test, trial)."""
533
+ if not isinstance(a, FieldRef) or not isinstance(b, FieldRef):
534
+ raise TypeError("outer expects FieldRef operands.")
535
+ if a.role != "test" or b.role != "trial":
536
+ raise TypeError("outer expects outer(test, trial).")
537
+ return Expr("outer", a, b)
538
+
539
+
288
540
  def dot(a, b) -> Expr:
289
541
  """Dot product or vector load helper."""
290
542
  return Expr("dot", _as_expr(a), _as_expr(b))
@@ -303,7 +555,7 @@ def ddot(a, b, c=None) -> Expr:
303
555
 
304
556
 
305
557
  def inner(a, b) -> Expr:
306
- """Inner product over the last axis."""
558
+ """Inner product over the last axis (tensor-level)."""
307
559
  return Expr("inner", _as_expr(a), _as_expr(b))
308
560
 
309
561
 
@@ -362,6 +614,16 @@ def transpose_last2(a) -> Expr:
362
614
  return Expr("transpose_last2", _as_expr(a))
363
615
 
364
616
 
617
+ def matmul(a, b) -> Expr:
618
+ """FEM-specific batched contraction (same semantics as `@`)."""
619
+ return Expr("matmul", _as_expr(a), _as_expr(b))
620
+
621
+
622
+ def matmul_std(a, b) -> Expr:
623
+ """Standard matrix product (`jnp.matmul` semantics)."""
624
+ return Expr("matmul_std", _as_expr(a), _as_expr(b))
625
+
626
+
365
627
  def einsum(subscripts: str, *args) -> Expr:
366
628
  """Einsum wrapper that supports Expr inputs."""
367
629
  return Expr("einsum", subscripts, *[_as_expr(arg) for arg in args])
@@ -369,9 +631,22 @@ def einsum(subscripts: str, *args) -> Expr:
369
631
 
370
632
  def _call_user(fn, *args, params):
371
633
  try:
634
+ sig = inspect.signature(fn)
635
+ except (TypeError, ValueError):
636
+ return fn(*args, params)
637
+
638
+ params_list = list(sig.parameters.values())
639
+ if any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in params_list):
640
+ return fn(*args, params)
641
+ positional = [
642
+ p
643
+ for p in params_list
644
+ if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
645
+ ]
646
+ max_positional = len(positional)
647
+ if len(args) + 1 <= max_positional:
372
648
  return fn(*args, params)
373
- except TypeError:
374
- return fn(*args)
649
+ return fn(*args)
375
650
 
376
651
 
377
652
  def compile_bilinear(fn):
@@ -382,19 +657,26 @@ def compile_bilinear(fn):
382
657
  u = trial_ref()
383
658
  v = test_ref()
384
659
  p = param_ref()
385
- try:
386
- expr = fn(u, v, p)
387
- except TypeError:
388
- expr = fn(u, v)
660
+ expr = _call_user(fn, u, v, params=p)
661
+ expr = _as_expr(expr)
662
+ if not isinstance(expr, Expr):
663
+ raise TypeError("Bilinear form must return an Expr.")
389
664
 
390
- includes_measure = _expr_contains(expr, "volume_measure")
391
- if not includes_measure:
665
+ volume_count = _count_op(expr, "volume_measure")
666
+ surface_count = _count_op(expr, "surface_measure")
667
+ if volume_count == 0:
392
668
  raise ValueError("Volume bilinear form must include dOmega().")
669
+ if volume_count > 1:
670
+ raise ValueError("Volume bilinear form must include dOmega() exactly once.")
671
+ if surface_count > 0:
672
+ raise ValueError("Volume bilinear form must not include ds().")
673
+
674
+ plan = make_eval_plan(expr)
393
675
 
394
676
  def _form(ctx, params):
395
- return _as_expr(expr).eval(ctx, params)
677
+ return eval_with_plan(plan, ctx, params)
396
678
 
397
- _form._includes_measure = includes_measure
679
+ _form._includes_measure = True
398
680
  return _form
399
681
 
400
682
 
@@ -405,19 +687,26 @@ def compile_linear(fn):
405
687
  else:
406
688
  v = test_ref()
407
689
  p = param_ref()
408
- try:
409
- expr = fn(v, p)
410
- except TypeError:
411
- expr = fn(v)
690
+ expr = _call_user(fn, v, params=p)
691
+ expr = _as_expr(expr)
692
+ if not isinstance(expr, Expr):
693
+ raise TypeError("Linear form must return an Expr.")
412
694
 
413
- includes_measure = _expr_contains(expr, "volume_measure")
414
- if not includes_measure:
695
+ volume_count = _count_op(expr, "volume_measure")
696
+ surface_count = _count_op(expr, "surface_measure")
697
+ if volume_count == 0:
415
698
  raise ValueError("Volume linear form must include dOmega().")
699
+ if volume_count > 1:
700
+ raise ValueError("Volume linear form must include dOmega() exactly once.")
701
+ if surface_count > 0:
702
+ raise ValueError("Volume linear form must not include ds().")
703
+
704
+ plan = make_eval_plan(expr)
416
705
 
417
706
  def _form(ctx, params):
418
- return _as_expr(expr).eval(ctx, params)
707
+ return eval_with_plan(plan, ctx, params)
419
708
 
420
- _form._includes_measure = includes_measure
709
+ _form._includes_measure = True
421
710
  return _form
422
711
 
423
712
 
@@ -429,6 +718,340 @@ def _expr_contains(expr: Expr, op: str) -> bool:
429
718
  return any(_expr_contains(arg, op) for arg in expr.args if isinstance(arg, Expr))
430
719
 
431
720
 
721
+ def _count_op(expr: Expr, op: str) -> int:
722
+ if not isinstance(expr, Expr):
723
+ return 0
724
+ count = 1 if expr.op == op else 0
725
+ for arg in expr.args:
726
+ if isinstance(arg, Expr):
727
+ count += _count_op(arg, op)
728
+ return count
729
+
730
+
731
+ @dataclass(frozen=True, slots=True)
732
+ class EvalPlan:
733
+ expr: Expr
734
+ nodes: tuple[Expr, ...]
735
+ index: dict[int, int]
736
+
737
+
738
+ def _validate_eval_plan(nodes: tuple[Expr, ...]) -> None:
739
+ fieldref_arg_ops = {
740
+ "value",
741
+ "grad",
742
+ "sym_grad",
743
+ "dot",
744
+ "sdot",
745
+ "action",
746
+ "gaction",
747
+ "outer",
748
+ }
749
+ for node in nodes:
750
+ op = node.op
751
+ args = node.args
752
+ if op not in fieldref_arg_ops:
753
+ if any(isinstance(arg, FieldRef) for arg in args):
754
+ raise TypeError(f"{op} cannot take FieldRef directly; wrap with .val/.grad/.sym_grad.")
755
+ if op in {"value", "grad", "sym_grad"}:
756
+ if len(args) != 1 or not isinstance(args[0], FieldRef):
757
+ raise TypeError(f"{op} expects FieldRef.")
758
+ elif op in {"dot", "sdot"}:
759
+ if len(args) != 2:
760
+ raise TypeError(f"{op} expects two arguments.")
761
+ if any(isinstance(arg, FieldRef) for arg in args):
762
+ if not isinstance(args[0], FieldRef):
763
+ raise TypeError(f"{op} expects FieldRef as the first argument.")
764
+ if isinstance(args[1], FieldRef):
765
+ raise TypeError(f"{op} expects an expression for the second argument; use .val/.grad.")
766
+ elif op in {"action", "gaction"}:
767
+ if len(args) != 2 or not isinstance(args[0], FieldRef):
768
+ raise TypeError(f"{op} expects FieldRef as the first argument.")
769
+ if op == "action" and isinstance(args[1], FieldRef):
770
+ raise ValueError("action expects a scalar expression; use u.val for unknowns.")
771
+ if op == "gaction" and isinstance(args[1], FieldRef):
772
+ raise TypeError("gaction expects an expression for the second argument; use .grad.")
773
+ elif op == "outer":
774
+ if len(args) != 2 or not all(isinstance(arg, FieldRef) for arg in args):
775
+ raise TypeError("outer expects two FieldRef operands.")
776
+ if args[0].role != "test" or args[1].role != "trial":
777
+ raise TypeError("outer expects outer(test, trial).")
778
+
779
+
780
+ def make_eval_plan(expr: Expr) -> EvalPlan:
781
+ nodes = tuple(expr.postorder_expr())
782
+ _validate_eval_plan(nodes)
783
+ index: dict[int, int] = {}
784
+ for i, node in enumerate(nodes):
785
+ index.setdefault(id(node), i)
786
+ return EvalPlan(expr=expr, nodes=nodes, index=index)
787
+
788
+
789
+ def eval_with_plan(
790
+ plan: EvalPlan,
791
+ ctx: VolumeContext | SurfaceContext,
792
+ params: ParamsLike,
793
+ u_elem: UElement | None = None,
794
+ ):
795
+ nodes = plan.nodes
796
+ index = plan.index
797
+ vals: list[Any] = [None] * len(nodes)
798
+
799
+ def get(obj):
800
+ if isinstance(obj, Expr):
801
+ return vals[index[id(obj)]]
802
+ if isinstance(obj, FieldRef):
803
+ raise TypeError(
804
+ "FieldRef must be wrapped with .val/.grad/.sym_grad or used as the first arg of dot/action."
805
+ )
806
+ if isinstance(obj, ParamRef):
807
+ return params
808
+ return obj
809
+
810
+ for i, node in enumerate(nodes):
811
+ op = node.op
812
+ args = node.args
813
+
814
+ if op == "lit":
815
+ vals[i] = args[0]
816
+ continue
817
+ if op == "getattr":
818
+ base = get(args[0])
819
+ name = args[1]
820
+ if isinstance(base, dict):
821
+ vals[i] = base[name]
822
+ else:
823
+ vals[i] = getattr(base, name)
824
+ continue
825
+ if op == "value":
826
+ ref = args[0]
827
+ assert isinstance(ref, FieldRef)
828
+ field = _eval_field(ref, ctx, params)
829
+ if ref.role == "unknown":
830
+ vals[i] = _eval_unknown_value(ref, field, u_elem)
831
+ else:
832
+ vals[i] = field.N
833
+ continue
834
+ if op == "grad":
835
+ ref = args[0]
836
+ assert isinstance(ref, FieldRef)
837
+ field = _eval_field(ref, ctx, params)
838
+ if ref.role == "unknown":
839
+ vals[i] = _eval_unknown_grad(ref, field, u_elem)
840
+ else:
841
+ vals[i] = field.gradN
842
+ continue
843
+ if op == "pow":
844
+ base = get(args[0])
845
+ exp = get(args[1])
846
+ vals[i] = base**exp
847
+ continue
848
+ if op == "eye":
849
+ vals[i] = jnp.eye(int(args[0]))
850
+ continue
851
+ if op == "det":
852
+ vals[i] = jnp.linalg.det(get(args[0]))
853
+ continue
854
+ if op == "inv":
855
+ vals[i] = jnp.linalg.inv(get(args[0]))
856
+ continue
857
+ if op == "transpose":
858
+ vals[i] = jnp.swapaxes(get(args[0]), -1, -2)
859
+ continue
860
+ if op == "log":
861
+ vals[i] = jnp.log(get(args[0]))
862
+ continue
863
+ if op == "surface_normal":
864
+ normal = getattr(ctx, "normal", None)
865
+ if normal is None:
866
+ raise ValueError("surface normal is not available in context")
867
+ vals[i] = normal
868
+ continue
869
+ if op == "surface_measure":
870
+ if not hasattr(ctx, "w") or not hasattr(ctx, "detJ"):
871
+ raise TypeError("surface measure requires SurfaceContext.")
872
+ vals[i] = ctx.w * ctx.detJ
873
+ continue
874
+ if op == "volume_measure":
875
+ if not hasattr(ctx, "w") or not hasattr(ctx, "test"):
876
+ raise TypeError("volume measure requires VolumeContext.")
877
+ vals[i] = ctx.w * ctx.test.detJ
878
+ continue
879
+ if op == "sym_grad":
880
+ ref = args[0]
881
+ assert isinstance(ref, FieldRef)
882
+ field = _eval_field(ref, ctx, params)
883
+ if ref.role == "unknown":
884
+ if u_elem is None:
885
+ raise ValueError("u_elem is required to evaluate unknown sym_grad.")
886
+ u_local = _extract_unknown_elem(ref, u_elem)
887
+ vals[i] = _ops.sym_grad_u(field, u_local)
888
+ else:
889
+ vals[i] = _ops.sym_grad(field)
890
+ continue
891
+ if op == "outer":
892
+ a, b = args
893
+ if not isinstance(a, FieldRef) or not isinstance(b, FieldRef):
894
+ raise TypeError("outer expects FieldRef operands.")
895
+ test, trial = a, b
896
+ vals[i] = _basis_outer(test, trial, ctx, params)
897
+ continue
898
+ if op == "add":
899
+ vals[i] = get(args[0]) + get(args[1])
900
+ continue
901
+ if op == "sub":
902
+ vals[i] = get(args[0]) - get(args[1])
903
+ continue
904
+ if op == "mul":
905
+ a = get(args[0])
906
+ b = get(args[1])
907
+ if hasattr(a, "ndim") and hasattr(b, "ndim"):
908
+ if a.ndim == 1 and b.ndim == 2 and a.shape[0] == b.shape[0]:
909
+ a = a[:, None]
910
+ elif b.ndim == 1 and a.ndim == 2 and b.shape[0] == a.shape[0]:
911
+ b = b[:, None]
912
+ elif a.ndim >= 2 and b.ndim == 1 and a.shape[0] == b.shape[0]:
913
+ b = b.reshape((b.shape[0],) + (1,) * (a.ndim - 1))
914
+ elif b.ndim >= 2 and a.ndim == 1 and b.shape[0] == a.shape[0]:
915
+ a = a.reshape((a.shape[0],) + (1,) * (b.ndim - 1))
916
+ vals[i] = a * b
917
+ continue
918
+ if op == "matmul":
919
+ a = get(args[0])
920
+ b = get(args[1])
921
+ if (
922
+ hasattr(a, "ndim")
923
+ and hasattr(b, "ndim")
924
+ and a.ndim == 3
925
+ and b.ndim == 3
926
+ and a.shape[0] == b.shape[0]
927
+ and a.shape[-1] == b.shape[-1]
928
+ ):
929
+ vals[i] = jnp.einsum("qia,qja->qij", a, b)
930
+ else:
931
+ raise TypeError(
932
+ "Expr '@' (matmul) is FEM-specific; use matmul_std(a, b) for standard matmul."
933
+ )
934
+ continue
935
+ if op == "matmul_std":
936
+ a = get(args[0])
937
+ b = get(args[1])
938
+ vals[i] = jnp.matmul(a, b)
939
+ continue
940
+ if op == "neg":
941
+ vals[i] = -get(args[0])
942
+ continue
943
+ if op == "dot":
944
+ ref = args[0]
945
+ if isinstance(ref, FieldRef):
946
+ vals[i] = _ops.dot(_eval_field(ref, ctx, params), get(args[1]))
947
+ else:
948
+ a = get(args[0])
949
+ b = get(args[1])
950
+ if (
951
+ hasattr(a, "ndim")
952
+ and hasattr(b, "ndim")
953
+ and a.ndim == 3
954
+ and b.ndim == 3
955
+ and a.shape[-1] == b.shape[-1]
956
+ ):
957
+ vals[i] = jnp.einsum("qia,qja->qij", a, b)
958
+ else:
959
+ vals[i] = jnp.matmul(a, b)
960
+ continue
961
+ if op == "sdot":
962
+ ref = args[0]
963
+ if isinstance(ref, FieldRef):
964
+ vals[i] = _ops.dot(_eval_field(ref, ctx, params), get(args[1]))
965
+ else:
966
+ a = get(args[0])
967
+ b = get(args[1])
968
+ if (
969
+ hasattr(a, "ndim")
970
+ and hasattr(b, "ndim")
971
+ and a.ndim == 3
972
+ and b.ndim == 3
973
+ and a.shape[-1] == b.shape[-1]
974
+ ):
975
+ vals[i] = jnp.einsum("qia,qja->qij", a, b)
976
+ else:
977
+ vals[i] = jnp.matmul(a, b)
978
+ continue
979
+ if op == "ddot":
980
+ if len(args) == 2:
981
+ a = get(args[0])
982
+ b = get(args[1])
983
+ if (
984
+ hasattr(a, "ndim")
985
+ and hasattr(b, "ndim")
986
+ and a.ndim == 3
987
+ and b.ndim == 3
988
+ and a.shape[0] == b.shape[0]
989
+ and a.shape[1] == b.shape[1]
990
+ ):
991
+ vals[i] = jnp.einsum("qik,qim->qkm", a, b)
992
+ else:
993
+ vals[i] = _ops.ddot(a, b)
994
+ else:
995
+ vals[i] = _ops.ddot(get(args[0]), get(args[1]), get(args[2]))
996
+ continue
997
+ if op == "inner":
998
+ a = get(args[0])
999
+ b = get(args[1])
1000
+ vals[i] = jnp.einsum("...i,...i->...", a, b)
1001
+ continue
1002
+ if op == "action":
1003
+ ref = args[0]
1004
+ assert isinstance(ref, FieldRef)
1005
+ if isinstance(args[1], FieldRef):
1006
+ raise ValueError("action expects a scalar expression; use u.val for unknowns.")
1007
+ v_field = _eval_field(ref, ctx, params)
1008
+ s = get(args[1])
1009
+ value_dim = int(getattr(v_field, "value_dim", 1))
1010
+ # action maps a test field with a scalar/vector expression into nodal space.
1011
+ if value_dim == 1:
1012
+ if v_field.N.ndim != 2:
1013
+ raise ValueError("action expects scalar test field with N shape (q, ndofs).")
1014
+ if hasattr(s, "ndim") and s.ndim not in (0, 1):
1015
+ raise ValueError("action expects scalar s with shape (q,) or scalar.")
1016
+ vals[i] = v_field.N * s
1017
+ else:
1018
+ if hasattr(s, "ndim") and s.ndim not in (1, 2):
1019
+ raise ValueError("action expects vector s with shape (q, dim) or (dim,).")
1020
+ vals[i] = _ops.dot(v_field, s)
1021
+ continue
1022
+ if op == "gaction":
1023
+ ref = args[0]
1024
+ assert isinstance(ref, FieldRef)
1025
+ v_field = _eval_field(ref, ctx, params)
1026
+ q = get(args[1])
1027
+ # gaction maps a flux-like expression to nodal space via test gradients.
1028
+ if v_field.gradN.ndim != 3:
1029
+ raise ValueError("gaction expects test gradient with shape (q, ndofs, dim).")
1030
+ if not hasattr(q, "ndim"):
1031
+ raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
1032
+ if q.ndim == 2:
1033
+ vals[i] = jnp.einsum("qaj,qj->qa", v_field.gradN, q)
1034
+ elif q.ndim == 3:
1035
+ if int(getattr(v_field, "value_dim", 1)) == 1:
1036
+ raise ValueError("gaction tensor flux requires vector test field.")
1037
+ vals[i] = jnp.einsum("qij,qaj->qai", q, v_field.gradN).reshape(q.shape[0], -1)
1038
+ else:
1039
+ raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
1040
+ continue
1041
+ if op == "transpose_last2":
1042
+ vals[i] = _ops.transpose_last2(get(args[0]))
1043
+ continue
1044
+ if op == "einsum":
1045
+ subscripts = args[0]
1046
+ operands = [get(arg) for arg in args[1:]]
1047
+ vals[i] = jnp.einsum(subscripts, *operands)
1048
+ continue
1049
+
1050
+ raise ValueError(f"Unknown Expr op: {op}")
1051
+
1052
+ return vals[index[id(plan.expr)]]
1053
+
1054
+
432
1055
  def compile_surface_linear(fn):
433
1056
  """get_compiled a surface linear form into a kernel (ctx, params) -> ndarray."""
434
1057
  if isinstance(fn, Expr):
@@ -436,26 +1059,27 @@ def compile_surface_linear(fn):
436
1059
  else:
437
1060
  v = test_ref()
438
1061
  p = param_ref()
439
- expr = None
440
- try:
441
- expr = fn(v, p)
442
- except TypeError:
443
- try:
444
- expr = fn(v)
445
- except TypeError:
446
- expr = None
1062
+ expr = _call_user(fn, v, params=p)
447
1063
 
1064
+ expr = _as_expr(expr)
448
1065
  if not isinstance(expr, Expr):
449
1066
  raise ValueError("Surface linear form must return an Expr; use ds() in the expression.")
450
1067
 
451
- includes_measure = _expr_contains(expr, "surface_measure")
452
- if not includes_measure:
1068
+ surface_count = _count_op(expr, "surface_measure")
1069
+ volume_count = _count_op(expr, "volume_measure")
1070
+ if surface_count == 0:
453
1071
  raise ValueError("Surface linear form must include ds().")
1072
+ if surface_count > 1:
1073
+ raise ValueError("Surface linear form must include ds() exactly once.")
1074
+ if volume_count > 0:
1075
+ raise ValueError("Surface linear form must not include dOmega().")
1076
+
1077
+ plan = make_eval_plan(expr)
454
1078
 
455
1079
  def _form(ctx, params):
456
- return _as_expr(expr).eval(ctx, params)
1080
+ return eval_with_plan(plan, ctx, params)
457
1081
 
458
- _form._includes_measure = includes_measure # type: ignore[attr-defined]
1082
+ _form._includes_measure = True # type: ignore[attr-defined]
459
1083
  return _form
460
1084
 
461
1085
 
@@ -519,25 +1143,33 @@ def compile_residual(fn):
519
1143
  v = test_ref()
520
1144
  u = unknown_ref()
521
1145
  p = param_ref()
522
- try:
523
- expr = fn(v, u, p)
524
- except TypeError:
525
- expr = fn(v, u)
1146
+ expr = _call_user(fn, v, u, params=p)
1147
+ expr = _as_expr(expr)
1148
+ if not isinstance(expr, Expr):
1149
+ raise TypeError("Residual form must return an Expr.")
526
1150
 
527
- includes_measure = _expr_contains(expr, "volume_measure")
528
- if not includes_measure:
1151
+ volume_count = _count_op(expr, "volume_measure")
1152
+ surface_count = _count_op(expr, "surface_measure")
1153
+ if volume_count == 0:
529
1154
  raise ValueError("Volume residual form must include dOmega().")
1155
+ if volume_count > 1:
1156
+ raise ValueError("Volume residual form must include dOmega() exactly once.")
1157
+ if surface_count > 0:
1158
+ raise ValueError("Volume residual form must not include ds().")
1159
+
1160
+ plan = make_eval_plan(expr)
530
1161
 
531
1162
  def _form(ctx, u_elem, params):
532
- return _as_expr(expr).eval(ctx, params, u_elem=u_elem)
1163
+ return eval_with_plan(plan, ctx, params, u_elem=u_elem)
533
1164
 
534
- _form._includes_measure = includes_measure
1165
+ _form._includes_measure = True
535
1166
  return _form
536
1167
 
537
1168
 
538
1169
  def compile_mixed_residual(residuals: dict[str, Callable]):
539
1170
  """get_compiled mixed residuals keyed by field name."""
540
1171
  compiled = {}
1172
+ plans = {}
541
1173
  includes_measure = {}
542
1174
  for name, fn in residuals.items():
543
1175
  if isinstance(fn, Expr):
@@ -546,17 +1178,24 @@ def compile_mixed_residual(residuals: dict[str, Callable]):
546
1178
  v = test_ref(name)
547
1179
  u = unknown_ref(name)
548
1180
  p = param_ref()
549
- try:
550
- expr = fn(v, u, p)
551
- except TypeError:
552
- expr = fn(v, u)
553
- compiled[name] = _as_expr(expr)
554
- includes_measure[name] = _expr_contains(compiled[name], "volume_measure")
555
- if not includes_measure[name]:
1181
+ expr = _call_user(fn, v, u, params=p)
1182
+ expr = _as_expr(expr)
1183
+ if not isinstance(expr, Expr):
1184
+ raise TypeError(f"Mixed residual '{name}' must return an Expr.")
1185
+ compiled[name] = expr
1186
+ plans[name] = make_eval_plan(expr)
1187
+ volume_count = _count_op(compiled[name], "volume_measure")
1188
+ surface_count = _count_op(compiled[name], "surface_measure")
1189
+ includes_measure[name] = volume_count == 1
1190
+ if volume_count == 0:
556
1191
  raise ValueError(f"Mixed residual '{name}' must include dOmega().")
1192
+ if volume_count > 1:
1193
+ raise ValueError(f"Mixed residual '{name}' must include dOmega() exactly once.")
1194
+ if surface_count > 0:
1195
+ raise ValueError(f"Mixed residual '{name}' must not include ds().")
557
1196
 
558
1197
  def _form(ctx, u_elem, params):
559
- return {name: expr.eval(ctx, params, u_elem=u_elem) for name, expr in compiled.items()}
1198
+ return {name: eval_with_plan(plan, ctx, params, u_elem=u_elem) for name, plan in plans.items()}
560
1199
 
561
1200
  _form._includes_measure = includes_measure
562
1201
  return _form
@@ -574,216 +1213,14 @@ class MixedWeakForm:
574
1213
  return compile_mixed_residual(self.residuals)
575
1214
 
576
1215
 
577
- def _eval_expr(expr: Expr, ctx, params, u_elem=None):
578
- op = expr.op
579
- args = expr.args
580
-
581
- if op == "lit":
582
- return args[0]
583
- if op == "param":
584
- return params
585
- if op == "getattr":
586
- base = _eval_value(args[0], ctx, params, u_elem=u_elem)
587
- name = args[1]
588
- if isinstance(base, dict):
589
- return base[name]
590
- return getattr(base, name)
591
- if op == "field":
592
- role, name = args
593
- if name is not None:
594
- if role == "trial" and getattr(ctx, "trial_fields", None) is not None:
595
- if name in ctx.trial_fields:
596
- return ctx.trial_fields[name]
597
- if role == "test" and getattr(ctx, "test_fields", None) is not None:
598
- if name in ctx.test_fields:
599
- return ctx.test_fields[name]
600
- if role == "unknown" and getattr(ctx, "unknown_fields", None) is not None:
601
- if name in ctx.unknown_fields:
602
- return ctx.unknown_fields[name]
603
- fields = getattr(ctx, "fields", None)
604
- if fields is not None and name in fields:
605
- group = fields[name]
606
- if isinstance(group, dict):
607
- if role in group:
608
- return group[role]
609
- if "field" in group:
610
- return group["field"]
611
- return group
612
- if role == "trial":
613
- return ctx.trial
614
- if role == "test":
615
- return ctx.test
616
- if role == "unknown":
617
- return getattr(ctx, "unknown", ctx.trial)
618
- raise ValueError(f"Unknown field role: {role}")
619
- if op == "value":
620
- field = _eval_field(args[0], ctx, params)
621
- if isinstance(args[0], FieldRef) and args[0].role == "unknown":
622
- return _eval_unknown_value(args[0], field, u_elem)
623
- return field.N
624
- if op == "grad":
625
- field = _eval_field(args[0], ctx, params)
626
- if isinstance(args[0], FieldRef) and args[0].role == "unknown":
627
- return _eval_unknown_grad(args[0], field, u_elem)
628
- return field.gradN
629
- if op == "pow":
630
- base = _eval_value(args[0], ctx, params, u_elem=u_elem)
631
- exp = _eval_value(args[1], ctx, params, u_elem=u_elem)
632
- return base**exp
633
- if op == "eye":
634
- return jnp.eye(int(args[0]))
635
- if op == "det":
636
- return jnp.linalg.det(_eval_value(args[0], ctx, params, u_elem=u_elem))
637
- if op == "inv":
638
- return jnp.linalg.inv(_eval_value(args[0], ctx, params, u_elem=u_elem))
639
- if op == "transpose":
640
- return jnp.swapaxes(_eval_value(args[0], ctx, params, u_elem=u_elem), -1, -2)
641
- if op == "log":
642
- return jnp.log(_eval_value(args[0], ctx, params, u_elem=u_elem))
643
- if op == "surface_normal":
644
- normal = getattr(ctx, "normal", None)
645
- if normal is None:
646
- raise ValueError("surface normal is not available in context")
647
- return normal
648
- if op == "surface_measure":
649
- if not hasattr(ctx, "w") or not hasattr(ctx, "detJ"):
650
- raise ValueError("surface measure requires surface context with w and detJ.")
651
- return ctx.w * ctx.detJ
652
- if op == "volume_measure":
653
- if not hasattr(ctx, "w") or not hasattr(ctx, "test"):
654
- raise ValueError("volume measure requires FormContext with w and test.detJ.")
655
- return ctx.w * ctx.test.detJ
656
- if op == "sym_grad":
657
- field = _eval_field(args[0], ctx, params)
658
- if isinstance(args[0], FieldRef) and args[0].role == "unknown":
659
- if u_elem is None:
660
- raise ValueError("u_elem is required to evaluate unknown sym_grad.")
661
- u_local = _extract_unknown_elem(args[0], u_elem)
662
- return _ops.sym_grad_u(field, u_local)
663
- return _ops.sym_grad(field)
664
- if op == "outer":
665
- a, b = args
666
- if not isinstance(a, FieldRef) or not isinstance(b, FieldRef):
667
- raise TypeError("outer expects FieldRef operands.")
668
- if a.role == b.role:
669
- raise ValueError("outer requires one trial and one test field.")
670
- test = a if a.role == "test" else b
671
- trial = b if a.role == "test" else a
672
- v_field = _eval_field(test, ctx, params)
673
- u_field = _eval_field(trial, ctx, params)
674
- if getattr(v_field, "value_dim", 1) != 1 or getattr(u_field, "value_dim", 1) != 1:
675
- raise ValueError("u*v is only defined for scalar fields; use dot/inner for vectors.")
676
- vN = v_field.N
677
- uN = u_field.N
678
- return jnp.einsum("qi,qj->qij", vN, uN)
679
- if op == "add":
680
- return _eval_value(args[0], ctx, params, u_elem=u_elem) + _eval_value(args[1], ctx, params, u_elem=u_elem)
681
- if op == "sub":
682
- return _eval_value(args[0], ctx, params, u_elem=u_elem) - _eval_value(args[1], ctx, params, u_elem=u_elem)
683
- if op == "mul":
684
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
685
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
686
- if hasattr(a, "ndim") and hasattr(b, "ndim"):
687
- if a.ndim == 1 and b.ndim == 2 and a.shape[0] == b.shape[0]:
688
- a = a[:, None]
689
- elif b.ndim == 1 and a.ndim == 2 and b.shape[0] == a.shape[0]:
690
- b = b[:, None]
691
- elif a.ndim >= 2 and b.ndim == 1 and a.shape[0] == b.shape[0]:
692
- b = b.reshape((b.shape[0],) + (1,) * (a.ndim - 1))
693
- elif b.ndim >= 2 and a.ndim == 1 and b.shape[0] == a.shape[0]:
694
- a = a.reshape((a.shape[0],) + (1,) * (b.ndim - 1))
695
- return a * b
696
- if op == "matmul":
697
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
698
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
699
- if (
700
- hasattr(a, "ndim")
701
- and hasattr(b, "ndim")
702
- and a.ndim == 3
703
- and b.ndim == 3
704
- and a.shape[0] == b.shape[0]
705
- and a.shape[-1] == b.shape[-1]
706
- ):
707
- return jnp.einsum("qia,qja->qij", a, b)
708
- return a @ b
709
- if op == "neg":
710
- return -_eval_value(args[0], ctx, params, u_elem=u_elem)
711
- if op == "dot":
712
- if isinstance(args[0], FieldRef):
713
- return _ops.dot(_eval_field(args[0], ctx, params), _eval_value(args[1], ctx, params, u_elem=u_elem))
714
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
715
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
716
- if hasattr(a, "ndim") and hasattr(b, "ndim") and a.ndim == 3 and b.ndim == 3 and a.shape[-1] == b.shape[-1]:
717
- return jnp.einsum("qia,qja->qij", a, b)
718
- return jnp.matmul(a, b)
719
- if op == "sdot":
720
- if isinstance(args[0], FieldRef):
721
- return _ops.dot(_eval_field(args[0], ctx, params), _eval_value(args[1], ctx, params, u_elem=u_elem))
722
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
723
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
724
- if hasattr(a, "ndim") and hasattr(b, "ndim") and a.ndim == 3 and b.ndim == 3 and a.shape[-1] == b.shape[-1]:
725
- return jnp.einsum("qia,qja->qij", a, b)
726
- return jnp.matmul(a, b)
727
- if op == "ddot":
728
- if len(args) == 2:
729
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
730
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
731
- if (
732
- hasattr(a, "ndim")
733
- and hasattr(b, "ndim")
734
- and a.ndim == 3
735
- and b.ndim == 3
736
- and a.shape[0] == b.shape[0]
737
- and a.shape[1] == b.shape[1]
738
- ):
739
- return jnp.einsum("qik,qim->qkm", a, b)
740
- return _ops.ddot(a, b)
741
- return _ops.ddot(
742
- _eval_value(args[0], ctx, params, u_elem=u_elem),
743
- _eval_value(args[1], ctx, params, u_elem=u_elem),
744
- _eval_value(args[2], ctx, params, u_elem=u_elem),
745
- )
746
- if op == "inner":
747
- a = _eval_value(args[0], ctx, params, u_elem=u_elem)
748
- b = _eval_value(args[1], ctx, params, u_elem=u_elem)
749
- return jnp.einsum("...i,...i->...", a, b)
750
- if op == "action":
751
- if isinstance(args[1], FieldRef):
752
- raise ValueError("action expects a scalar expression; use u.val for unknowns.")
753
- v_field = _eval_field(args[0], ctx, params)
754
- s = _eval_value(args[1], ctx, params, u_elem=u_elem)
755
- value_dim = int(getattr(v_field, "value_dim", 1))
756
- if value_dim == 1:
757
- if v_field.N.ndim != 2:
758
- raise ValueError("action expects scalar test field with N shape (q, ndofs).")
759
- if hasattr(s, "ndim") and s.ndim not in (0, 1):
760
- raise ValueError("action expects scalar s with shape (q,) or scalar.")
761
- return v_field.N * s
762
- if hasattr(s, "ndim") and s.ndim not in (1, 2):
763
- raise ValueError("action expects vector s with shape (q, dim) or (dim,).")
764
- return _ops.dot(v_field, s)
765
- if op == "gaction":
766
- v_field = _eval_field(args[0], ctx, params)
767
- q = _eval_value(args[1], ctx, params, u_elem=u_elem)
768
- if v_field.gradN.ndim != 3:
769
- raise ValueError("gaction expects test gradient with shape (q, ndofs, dim).")
770
- if not hasattr(q, "ndim"):
771
- raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
772
- if q.ndim == 2:
773
- return jnp.einsum("qaj,qj->qa", v_field.gradN, q)
774
- if q.ndim == 3:
775
- if int(getattr(v_field, "value_dim", 1)) == 1:
776
- raise ValueError("gaction tensor flux requires vector test field.")
777
- return jnp.einsum("qij,qaj->qai", q, v_field.gradN).reshape(q.shape[0], -1)
778
- raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
779
- if op == "transpose_last2":
780
- return _ops.transpose_last2(_eval_value(args[0], ctx, params, u_elem=u_elem))
781
- if op == "einsum":
782
- subscripts = args[0]
783
- operands = [_eval_value(arg, ctx, params, u_elem=u_elem) for arg in args[1:]]
784
- return jnp.einsum(subscripts, *operands)
785
-
786
- raise ValueError(f"Unknown Expr op: {op}")
1216
+ def _eval_expr(
1217
+ expr: Expr,
1218
+ ctx: VolumeContext | SurfaceContext,
1219
+ params: ParamsLike,
1220
+ u_elem: UElement | None = None,
1221
+ ):
1222
+ plan = make_eval_plan(expr)
1223
+ return eval_with_plan(plan, ctx, params, u_elem=u_elem)
787
1224
 
788
1225
 
789
1226
  __all__ = [
@@ -803,6 +1240,7 @@ __all__ = [
803
1240
  "compile_mixed_residual",
804
1241
  "grad",
805
1242
  "sym_grad",
1243
+ "outer",
806
1244
  "dot",
807
1245
  "ddot",
808
1246
  "inner",
@@ -814,5 +1252,7 @@ __all__ = [
814
1252
  "transpose",
815
1253
  "log",
816
1254
  "transpose_last2",
1255
+ "matmul",
1256
+ "matmul_std",
817
1257
  "einsum",
818
1258
  ]