fluxfem 0.1.1a0__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +68 -161
- fluxfem/core/__init__.py +59 -31
- fluxfem/core/context_types.py +36 -0
- fluxfem/core/forms.py +5 -1
- fluxfem/core/weakform.py +747 -307
- fluxfem/{helpers_num.py → helpers_ts.py} +1 -1
- fluxfem/helpers_wf.py +6 -0
- fluxfem/mesh/base.py +6 -1
- fluxfem/mesh/hex.py +1 -0
- fluxfem/mesh/io.py +2 -0
- {fluxfem-0.1.1a0.dist-info → fluxfem-0.1.4.dist-info}/METADATA +39 -23
- {fluxfem-0.1.1a0.dist-info → fluxfem-0.1.4.dist-info}/RECORD +14 -13
- {fluxfem-0.1.1a0.dist-info → fluxfem-0.1.4.dist-info}/WHEEL +1 -1
- {fluxfem-0.1.1a0.dist-info/licenses → fluxfem-0.1.4.dist-info}/LICENSE +0 -0
fluxfem/core/weakform.py
CHANGED
|
@@ -1,82 +1,312 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
|
-
from typing import Any, Callable
|
|
4
|
+
from typing import Any, Callable, Iterator, Literal, get_args
|
|
5
|
+
import inspect
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
5
8
|
|
|
6
9
|
import jax.numpy as jnp
|
|
7
10
|
import jax
|
|
8
11
|
|
|
9
12
|
from ..physics import operators as _ops
|
|
13
|
+
from .context_types import FormFieldLike, ParamsLike, SurfaceContext, UElement, VolumeContext
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
OpName = Literal[
    "lit",
    "getattr",
    "value",
    "grad",
    "pow",
    "eye",
    "det",
    "inv",
    "transpose",
    "log",
    "surface_normal",
    "surface_measure",
    "volume_measure",
    "sym_grad",
    "outer",
    "add",
    "sub",
    "mul",
    "matmul",
    "matmul_std",
    "neg",
    "dot",
    "sdot",
    "ddot",
    "inner",
    "action",
    "gaction",
    "transpose_last2",
    "einsum",
]

# Use OpName as the single source of truth for valid ops.
_OP_NAMES: frozenset[str] = frozenset(get_args(OpName))

# Binding strength used by the pretty-printer (higher binds tighter).
# The ordering mirrors Python's own grammar: ``**`` binds tighter than unary
# minus, which binds tighter than ``*``/``@``. With this ordering, ``(-x)**2``
# keeps its parentheses and ``-(x**2)`` renders as ``-x**2``, so the printed
# text reads the same way Python would parse it.
_PRECEDENCE: dict[str, int] = {
    "add": 10,
    "sub": 10,
    "mul": 20,
    "matmul": 20,
    "matmul_std": 20,
    "inner": 20,
    "dot": 20,
    "sdot": 20,
    "ddot": 20,
    "neg": 30,
    "pow": 40,
    "transpose": 50,
}
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _pretty_render_arg(arg, prec: int | None = None) -> str:
    """Render a single operand of an Expr node for ``Expr.pretty()``.

    * ``Expr``: rendered recursively, with ``prec`` as the parent precedence
      (``None`` behaves like 0).
    * ``FieldRef``: ``role``, or ``role:name`` when the ref is named.
    * ``ParamRef``: the literal text ``param``.
    * anything else: its ``repr``.
    """
    if isinstance(arg, Expr):
        parent = 0 if not prec else prec
        return _pretty_expr(arg, parent)
    if isinstance(arg, FieldRef):
        return f"{arg.role}" if arg.name is None else f"{arg.role}:{arg.name}"
    if isinstance(arg, ParamRef):
        return "param"
    return repr(arg)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _pretty_wrap(text: str, prec: int, parent_prec: int) -> str:
    """Parenthesize ``text`` when its operator binds looser than its parent's.

    Args:
        text: Already-rendered sub-expression.
        prec: Precedence of the operator that produced ``text``.
        parent_prec: Precedence required by the enclosing context.
    """
    needs_parens = prec < parent_prec
    return f"({text})" if needs_parens else text
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _pretty_expr(expr: Expr, parent_prec: int = 0) -> str:
    """Render ``expr`` as a compact human-readable string (used by ``Expr.pretty``).

    The following ops are printed infix/prefix/postfix, with parentheses
    inserted according to ``_PRECEDENCE``: ``+``, ``-``, ``*``, ``@``, ``**``,
    unary ``-`` and ``.T``. A few ops have special forms:

    * ``value`` prints as ``val(...)``.
    * ``getattr`` prints as ``base.name``.
    * ``lit`` prints as the ``repr`` of its payload.

    Every other op prints in function-call form ``op(arg0, arg1, ...)`` and
    lists *all* of its arguments. Variable-arity ops such as
    ``ddot(a, b, c)`` and ``einsum`` are therefore shown completely.

    Args:
        expr: Node to render.
        parent_prec: Precedence demanded by the enclosing operator; the result
            is parenthesized when this node binds more loosely.
    """
    op = expr.op
    args = expr.args

    if op == "lit":
        text = repr(args[0])
        # A negative literal behaves like unary minus: keep it readable as a
        # single operand where unary minus would need parentheses (e.g. as the
        # base of ``**``).
        if text.startswith("-") and parent_prec > _PRECEDENCE["neg"]:
            return f"({text})"
        return text
    if op == "getattr":
        base = _pretty_render_arg(args[0], _PRECEDENCE.get("transpose", 50))
        return f"{base}.{args[1]}"
    if op == "value":
        return f"val({_pretty_render_arg(args[0])})"
    if op == "neg":
        prec = _PRECEDENCE["neg"]
        inner = _pretty_render_arg(args[0], prec)
        return _pretty_wrap(f"-{inner}", prec, parent_prec)
    if op == "transpose":
        prec = _PRECEDENCE["transpose"]
        inner = _pretty_render_arg(args[0], prec)
        return _pretty_wrap(f"{inner}.T", prec, parent_prec)
    if op == "pow":
        prec = _PRECEDENCE["pow"]
        # The base of ``**`` must bind tighter than ``**`` itself, because
        # ``**`` is right-associative and ``(a**b)**c`` needs parentheses.
        # It must also bind tighter than unary minus, since Python reads
        # ``-x**2`` as ``-(x**2)``.
        base = _pretty_render_arg(args[0], max(prec, _PRECEDENCE["neg"]) + 1)
        # An equal-precedence exponent needs no parentheses:
        # ``a**b**c`` already means ``a**(b**c)``.
        exp = _pretty_render_arg(args[1], prec)
        return _pretty_wrap(f"{base}**{exp}", prec, parent_prec)

    infix_symbols = {"add": "+", "sub": "-", "mul": "*", "matmul": "@"}
    if op in infix_symbols:
        prec = _PRECEDENCE[op]
        left = _pretty_render_arg(args[0], prec)
        # Left-associative: an equal-precedence right operand gets parentheses.
        right = _pretty_render_arg(args[1], prec + 1)
        return _pretty_wrap(f"{left} {infix_symbols[op]} {right}", prec, parent_prec)

    # Function-call form for everything else (grad, sym_grad, dot, sdot, ddot,
    # inner, action, gaction, outer, matmul_std, einsum, det, ...). Arguments
    # sit inside the call's own parentheses and need no extra wrapping. The
    # call itself binds as an atom.
    rendered = ", ".join(_pretty_render_arg(arg) for arg in args)
    return f"{op}({rendered})"
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _as_expr(obj) -> Expr | FieldRef | ParamRef:
    """Normalize an operand into an Expr/FieldRef/ParamRef node.

    ``Expr``, ``FieldRef`` and ``ParamRef`` instances are returned unchanged.
    The following become ``"lit"`` nodes:

    * Python scalars (``bool``/``int``/``float``/``complex``/``str``).
    * NumPy scalars, unwrapped with ``.item()``.
    * Hashable tuples.

    Raises:
        TypeError: for tuples containing unhashable items, and for any other
            type (in particular arrays, which must be passed via params).
    """
    if isinstance(obj, (Expr, FieldRef, ParamRef)):
        return obj
    # ``complex`` is accepted so that plain Python complex literals behave like
    # NumPy complex scalars, which already reach "lit" via ``np.generic`` below.
    if isinstance(obj, (int, float, complex, bool, str)):
        return Expr("lit", obj)
    if isinstance(obj, np.generic):
        return Expr("lit", obj.item())
    if isinstance(obj, tuple):
        # Expr nodes are hashable (frozen dataclass), so literal payloads must be too.
        try:
            hash(obj)
        except TypeError as exc:
            raise TypeError(
                "Expr tuple literal must be hashable; use only immutable items."
            ) from exc
        return Expr("lit", obj)
    raise TypeError(
        "Expr literal must be a scalar or hashable tuple. "
        "Arrays are not allowed; pass them via params (ParamRef/params.xxx)."
    )
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
@dataclass(frozen=True, slots=True, init=False)
class Expr:
    """Expression tree node evaluated against a FormContext.

    Compile flow (recommended):
    - build an Expr via operators/refs
    - compile_* builds an EvalPlan (postorder nodes + index)
    - eval_with_plan(plan, ctx, params, u_elem) evaluates per element

    Expr.eval is a debug/single-shot path that creates a plan on demand.

    Nodes are immutable; every operator returns a new node. ``op`` must be
    one of the ``OpName`` literals. ``args`` holds the operands: child
    Expr/FieldRef/ParamRef nodes, or raw payloads such as attribute names
    and einsum subscripts.
    """

    op: OpName
    args: tuple[Any, ...]

    def __init__(self, op: OpName, *args):
        if op not in _OP_NAMES:
            raise ValueError(f"Unknown Expr op: {op!r}")
        # The dataclass is frozen, so assign through object.__setattr__.
        object.__setattr__(self, "op", op)
        object.__setattr__(self, "args", args)

    def eval(self, ctx, params=None, u_elem=None):
        """Evaluate the expression against a context (debug/single-shot path)."""
        return _eval_expr(self, ctx, params, u_elem=u_elem)

    def children(self) -> tuple[Any, ...]:
        """Return direct child nodes (Expr/FieldRef/ParamRef) for traversal."""
        return tuple(arg for arg in self.args if isinstance(arg, (Expr, FieldRef, ParamRef)))

    def walk(self) -> Iterator[Any]:
        """Depth-first walk over nodes, including leaf FieldRef/ParamRef."""
        yield self
        for child in self.children():
            if isinstance(child, Expr):
                yield from child.walk()
            else:
                yield child

    def postorder(self) -> Iterator[Any]:
        """Postorder walk over nodes, including leaf FieldRef/ParamRef."""
        for child in self.children():
            if isinstance(child, Expr):
                yield from child.postorder()
            else:
                yield child
        yield self

    def postorder_expr(self) -> Iterator["Expr"]:
        """Postorder walk over Expr nodes only (for eval planning)."""
        for arg in self.args:
            if isinstance(arg, Expr):
                yield from arg.postorder_expr()
        yield self

    def _binop(self, other, op):
        """Build ``Expr(op, self, other)`` after normalizing ``other``."""
        return Expr(op, self, _as_expr(other))

    def __add__(self, other):
        """Add expressions: `a + b`."""
        return self._binop(other, "add")

    def __radd__(self, other):
        """Right-add expressions: `1 + expr`."""
        return Expr("add", _as_expr(other), self)

    def __sub__(self, other):
        """Subtract expressions: `a - b`."""
        return self._binop(other, "sub")

    def __rsub__(self, other):
        """Right-subtract expressions: `1 - expr`."""
        return Expr("sub", _as_expr(other), self)

    def __mul__(self, other):
        """Multiply expressions: `a * b`."""
        return self._binop(other, "mul")

    def __rmul__(self, other):
        """Right-multiply expressions: `2 * expr`."""
        return Expr("mul", _as_expr(other), self)

    def __matmul__(self, other):
        """Matrix product: `a @ b` (FEM-specific contraction semantics)."""
        return self._binop(other, "matmul")

    def __rmatmul__(self, other):
        """Right-matmul: `A @ expr`."""
        return Expr("matmul", _as_expr(other), self)

    def __or__(self, other):
        """Tensor inner product: `a | b` (use .val/.grad for FieldRef)."""
        if isinstance(other, FieldRef):
            raise TypeError(
                "Expr | FieldRef is not supported; use .val/.grad on the FieldRef, "
                "or outer(test, trial) for basis kernels."
            )
        return Expr("inner", self, _as_expr(other))

    def __ror__(self, other):
        """Tensor inner product: `a | b` (use .val/.grad for FieldRef)."""
        if isinstance(other, FieldRef):
            raise TypeError(
                "FieldRef | Expr is not supported; use .val/.grad on the FieldRef, "
                "or outer(test, trial) for basis kernels."
            )
        return Expr("inner", _as_expr(other), self)

    def __pow__(self, power, modulo=None):
        """Power: `a ** p` (no modulo support)."""
        if modulo is not None:
            raise ValueError("modulo is not supported for Expr exponentiation.")
        return Expr("pow", self, _as_expr(power))

    def __rpow__(self, other, modulo=None):
        """Right-power: `2 ** expr` (no modulo support)."""
        # Without this, `scalar ** expr` raised TypeError while every other
        # arithmetic operator already had a reflected form.
        if modulo is not None:
            raise ValueError("modulo is not supported for Expr exponentiation.")
        return Expr("pow", _as_expr(other), self)

    def __neg__(self):
        """Unary negation: `-expr`."""
        return Expr("neg", self)

    @property
    def T(self):
        """Transpose view: `expr.T`."""
        return Expr("transpose", self)

    def __repr__(self) -> str:
        return self.pretty()

    def pretty(self) -> str:
        """Return a human-readable rendering of the expression."""
        return _pretty_expr(self)
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
@dataclass(frozen=True, slots=True)
|
|
304
|
+
class FieldRef:
|
|
70
305
|
"""Symbolic reference to trial/test/unknown field, optionally by name."""
|
|
71
306
|
|
|
72
307
|
role: str
|
|
73
308
|
name: str | None = None
|
|
74
309
|
|
|
75
|
-
def __init__(self, role: str, name: str | None = None):
|
|
76
|
-
object.__setattr__(self, "role", role)
|
|
77
|
-
object.__setattr__(self, "name", name)
|
|
78
|
-
super().__init__("field", role, name)
|
|
79
|
-
|
|
80
310
|
@property
|
|
81
311
|
def val(self):
|
|
82
312
|
return Expr("value", self)
|
|
@@ -91,12 +321,18 @@ class FieldRef(Expr):
|
|
|
91
321
|
|
|
92
322
|
def __mul__(self, other):
|
|
93
323
|
if isinstance(other, FieldRef):
|
|
94
|
-
|
|
324
|
+
raise TypeError(
|
|
325
|
+
"FieldRef * FieldRef is ambiguous; use outer(v, u) (test, trial), "
|
|
326
|
+
"action(v, s), or dot(v, q)."
|
|
327
|
+
)
|
|
95
328
|
return Expr("mul", Expr("value", self), _as_expr(other))
|
|
96
329
|
|
|
97
330
|
def __rmul__(self, other):
|
|
98
331
|
if isinstance(other, FieldRef):
|
|
99
|
-
|
|
332
|
+
raise TypeError(
|
|
333
|
+
"FieldRef * FieldRef is ambiguous; use outer(v, u) (test, trial), "
|
|
334
|
+
"action(v, s), or dot(v, q)."
|
|
335
|
+
)
|
|
100
336
|
return Expr("mul", _as_expr(other), Expr("value", self))
|
|
101
337
|
|
|
102
338
|
def __add__(self, other):
|
|
@@ -113,22 +349,23 @@ class FieldRef(Expr):
|
|
|
113
349
|
|
|
114
350
|
def __or__(self, other):
|
|
115
351
|
if isinstance(other, FieldRef):
|
|
116
|
-
|
|
117
|
-
|
|
352
|
+
raise TypeError(
|
|
353
|
+
"FieldRef | FieldRef is not supported; use outer(test, trial) for basis kernels."
|
|
354
|
+
)
|
|
355
|
+
return Expr("dot", self, _as_expr(other))
|
|
118
356
|
|
|
119
357
|
def __ror__(self, other):
|
|
120
358
|
if isinstance(other, FieldRef):
|
|
121
|
-
|
|
122
|
-
|
|
359
|
+
raise TypeError(
|
|
360
|
+
"FieldRef | FieldRef is not supported; use outer(test, trial) for basis kernels."
|
|
361
|
+
)
|
|
362
|
+
return Expr("dot", _as_expr(other), self)
|
|
123
363
|
|
|
124
364
|
|
|
125
|
-
@dataclass(frozen=True)
|
|
126
|
-
class ParamRef
|
|
365
|
+
@dataclass(frozen=True, slots=True)
|
|
366
|
+
class ParamRef:
|
|
127
367
|
"""Symbolic reference to params passed into the kernel."""
|
|
128
368
|
|
|
129
|
-
def __init__(self):
|
|
130
|
-
super().__init__("param")
|
|
131
|
-
|
|
132
369
|
def __getattr__(self, name: str):
|
|
133
370
|
return Expr("getattr", self, name)
|
|
134
371
|
|
|
@@ -179,13 +416,11 @@ def param_ref() -> ParamRef:
|
|
|
179
416
|
return ParamRef()
|
|
180
417
|
|
|
181
418
|
|
|
182
|
-
def
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
def _eval_field(obj: Any, ctx, params):
|
|
419
|
+
def _eval_field(
|
|
420
|
+
obj: Any,
|
|
421
|
+
ctx: VolumeContext | SurfaceContext,
|
|
422
|
+
params: ParamsLike,
|
|
423
|
+
) -> FormFieldLike:
|
|
189
424
|
if isinstance(obj, FieldRef):
|
|
190
425
|
if obj.name is not None:
|
|
191
426
|
mixed_fields = getattr(ctx, "fields", None)
|
|
@@ -226,25 +461,23 @@ def _eval_field(obj: Any, ctx, params):
|
|
|
226
461
|
if obj.role == "unknown":
|
|
227
462
|
return getattr(ctx, "unknown", ctx.trial)
|
|
228
463
|
raise ValueError(f"Unknown field role: {obj.role}")
|
|
229
|
-
if isinstance(obj, Expr):
|
|
230
|
-
val = obj.eval(ctx, params)
|
|
231
|
-
if hasattr(val, "N"):
|
|
232
|
-
return val
|
|
233
464
|
raise TypeError("Expected a field reference for this operator.")
|
|
234
465
|
|
|
235
466
|
|
|
236
|
-
def _eval_value(obj: Any, ctx, params, u_elem=None):
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
467
|
+
# def _eval_value(obj: Any, ctx, params, u_elem=None):
|
|
468
|
+
# if isinstance(obj, FieldRef):
|
|
469
|
+
# field = _eval_field(obj, ctx, params)
|
|
470
|
+
# if obj.role == "unknown":
|
|
471
|
+
# return _eval_unknown_value(obj, field, u_elem)
|
|
472
|
+
# return field.N
|
|
473
|
+
# if isinstance(obj, ParamRef):
|
|
474
|
+
# return params
|
|
475
|
+
# if isinstance(obj, Expr):
|
|
476
|
+
# return obj.eval(ctx, params, u_elem=u_elem)
|
|
477
|
+
# return obj
|
|
245
478
|
|
|
246
479
|
|
|
247
|
-
def _extract_unknown_elem(field_ref: FieldRef, u_elem):
|
|
480
|
+
def _extract_unknown_elem(field_ref: FieldRef, u_elem: UElement):
|
|
248
481
|
if u_elem is None:
|
|
249
482
|
raise ValueError("u_elem is required to evaluate unknown field value.")
|
|
250
483
|
if isinstance(u_elem, dict):
|
|
@@ -255,7 +488,17 @@ def _extract_unknown_elem(field_ref: FieldRef, u_elem):
|
|
|
255
488
|
return u_elem
|
|
256
489
|
|
|
257
490
|
|
|
258
|
-
def
|
|
491
|
+
def _basis_outer(test: FieldRef, trial: FieldRef, ctx, params):
    """Outer product of test and trial shape-function values.

    Both refs are resolved with ``_eval_field``. The result is
    ``einsum("qi,qj->qij", N_test, N_trial)``: for each leading index ``q``
    (presumably the quadrature point), the outer product of the two ``N``
    rows.

    Raises:
        ValueError: if either resolved field reports a ``value_dim`` other than 1.
    """
    test_field = _eval_field(test, ctx, params)
    trial_field = _eval_field(trial, ctx, params)
    both_scalar = all(
        getattr(field, "value_dim", 1) == 1 for field in (test_field, trial_field)
    )
    if not both_scalar:
        raise ValueError(
            "inner/outer is only defined for scalar fields; use dot/action/einsum for vector/tensor cases."
        )
    return jnp.einsum("qi,qj->qij", test_field.N, trial_field.N)
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
def _eval_unknown_value(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
|
|
259
502
|
u_local = _extract_unknown_elem(field_ref, u_elem)
|
|
260
503
|
value_dim = int(getattr(field, "value_dim", 1))
|
|
261
504
|
if value_dim == 1:
|
|
@@ -264,7 +507,7 @@ def _eval_unknown_value(field_ref: FieldRef, field, u_elem):
|
|
|
264
507
|
return jnp.einsum("qa,ai->qi", field.N, u_nodes)
|
|
265
508
|
|
|
266
509
|
|
|
267
|
-
def _eval_unknown_grad(field_ref: FieldRef, field, u_elem):
|
|
510
|
+
def _eval_unknown_grad(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
|
|
268
511
|
u_local = _extract_unknown_elem(field_ref, u_elem)
|
|
269
512
|
if u_local is None:
|
|
270
513
|
raise ValueError("u_elem is required to evaluate unknown field gradient.")
|
|
@@ -285,6 +528,15 @@ def sym_grad(field) -> Expr:
|
|
|
285
528
|
return Expr("sym_grad", _as_expr(field))
|
|
286
529
|
|
|
287
530
|
|
|
531
|
+
def outer(a, b) -> Expr:
    """Build an ``outer`` node for scalar basis kernels: ``outer(v, u)``.

    The first operand must be the test FieldRef and the second the trial
    FieldRef.

    Raises:
        TypeError: if either operand is not a FieldRef, or the roles are not
            (test, trial) in that order.
    """
    both_refs = isinstance(a, FieldRef) and isinstance(b, FieldRef)
    if not both_refs:
        raise TypeError("outer expects FieldRef operands.")
    if (a.role, b.role) != ("test", "trial"):
        raise TypeError("outer expects outer(test, trial).")
    return Expr("outer", a, b)
|
|
538
|
+
|
|
539
|
+
|
|
288
540
|
def dot(a, b) -> Expr:
    """Build a ``dot`` node (dot product or vector load helper).

    Both operands are normalized with ``_as_expr``; a FieldRef passes through
    unchanged and may be used as the first operand.
    """
    lhs = _as_expr(a)
    rhs = _as_expr(b)
    return Expr("dot", lhs, rhs)
|
|
@@ -303,7 +555,7 @@ def ddot(a, b, c=None) -> Expr:
|
|
|
303
555
|
|
|
304
556
|
|
|
305
557
|
def inner(a, b) -> Expr:
    """Build an ``inner`` node: tensor-level inner product over the last axis.

    This is the functional spelling of ``a | b``.
    """
    lhs = _as_expr(a)
    rhs = _as_expr(b)
    return Expr("inner", lhs, rhs)
|
|
308
560
|
|
|
309
561
|
|
|
@@ -362,6 +614,16 @@ def transpose_last2(a) -> Expr:
|
|
|
362
614
|
return Expr("transpose_last2", _as_expr(a))
|
|
363
615
|
|
|
364
616
|
|
|
617
|
+
def matmul(a, b) -> Expr:
    """Functional spelling of ``a @ b``: the FEM-specific batched contraction.

    For standard ``jnp.matmul`` semantics use ``matmul_std`` instead.
    """
    lhs = _as_expr(a)
    rhs = _as_expr(b)
    return Expr("matmul", lhs, rhs)
|
|
620
|
+
|
|
621
|
+
|
|
622
|
+
def matmul_std(a, b) -> Expr:
    """Build a ``matmul_std`` node: ordinary matrix product (``jnp.matmul``)."""
    lhs = _as_expr(a)
    rhs = _as_expr(b)
    return Expr("matmul_std", lhs, rhs)
|
|
625
|
+
|
|
626
|
+
|
|
365
627
|
def einsum(subscripts: str, *args) -> Expr:
    """Build an ``einsum`` node from Expr (or literal) operands.

    ``subscripts`` is stored verbatim as the first argument. Each operand is
    normalized with ``_as_expr``, in order.
    """
    return Expr("einsum", subscripts, *map(_as_expr, args))
|
|
@@ -369,9 +631,22 @@ def einsum(subscripts: str, *args) -> Expr:
|
|
|
369
631
|
|
|
370
632
|
def _call_user(fn, *args, params):
    """Call a user form function, passing ``params`` only if it can accept it.

    How ``params`` is delivered, checked in order:

    1. If ``fn``'s signature cannot be inspected (some builtins/C callables),
       ``params`` is appended positionally.
    2. If ``fn`` takes ``*args``, ``params`` is appended positionally.
    3. If ``fn`` has at least ``len(args) + 1`` positional parameters,
       ``params`` is appended positionally.
    4. Otherwise, if ``fn`` declares a keyword-only parameter named
       ``params``, it is passed as ``params=params``.
    5. Otherwise ``fn`` is called with ``args`` alone.
    """
    try:
        sig = inspect.signature(fn)
    except (TypeError, ValueError):
        return fn(*args, params)

    params_list = list(sig.parameters.values())
    if any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in params_list):
        return fn(*args, params)

    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    max_positional = sum(1 for p in params_list if p.kind in positional_kinds)
    if len(args) + 1 <= max_positional:
        return fn(*args, params)

    # A keyword-only ``params`` (e.g. ``def form(u, v, *, params)``) used to be
    # skipped, which raised for required ones and ignored params for defaulted ones.
    kw_param = sig.parameters.get("params")
    if kw_param is not None and kw_param.kind == inspect.Parameter.KEYWORD_ONLY:
        return fn(*args, params=params)
    return fn(*args)
|
|
375
650
|
|
|
376
651
|
|
|
377
652
|
def compile_bilinear(fn):
|
|
@@ -382,19 +657,26 @@ def compile_bilinear(fn):
|
|
|
382
657
|
u = trial_ref()
|
|
383
658
|
v = test_ref()
|
|
384
659
|
p = param_ref()
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
660
|
+
expr = _call_user(fn, u, v, params=p)
|
|
661
|
+
expr = _as_expr(expr)
|
|
662
|
+
if not isinstance(expr, Expr):
|
|
663
|
+
raise TypeError("Bilinear form must return an Expr.")
|
|
389
664
|
|
|
390
|
-
|
|
391
|
-
|
|
665
|
+
volume_count = _count_op(expr, "volume_measure")
|
|
666
|
+
surface_count = _count_op(expr, "surface_measure")
|
|
667
|
+
if volume_count == 0:
|
|
392
668
|
raise ValueError("Volume bilinear form must include dOmega().")
|
|
669
|
+
if volume_count > 1:
|
|
670
|
+
raise ValueError("Volume bilinear form must include dOmega() exactly once.")
|
|
671
|
+
if surface_count > 0:
|
|
672
|
+
raise ValueError("Volume bilinear form must not include ds().")
|
|
673
|
+
|
|
674
|
+
plan = make_eval_plan(expr)
|
|
393
675
|
|
|
394
676
|
def _form(ctx, params):
|
|
395
|
-
return
|
|
677
|
+
return eval_with_plan(plan, ctx, params)
|
|
396
678
|
|
|
397
|
-
_form._includes_measure =
|
|
679
|
+
_form._includes_measure = True
|
|
398
680
|
return _form
|
|
399
681
|
|
|
400
682
|
|
|
@@ -405,19 +687,26 @@ def compile_linear(fn):
|
|
|
405
687
|
else:
|
|
406
688
|
v = test_ref()
|
|
407
689
|
p = param_ref()
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
690
|
+
expr = _call_user(fn, v, params=p)
|
|
691
|
+
expr = _as_expr(expr)
|
|
692
|
+
if not isinstance(expr, Expr):
|
|
693
|
+
raise TypeError("Linear form must return an Expr.")
|
|
412
694
|
|
|
413
|
-
|
|
414
|
-
|
|
695
|
+
volume_count = _count_op(expr, "volume_measure")
|
|
696
|
+
surface_count = _count_op(expr, "surface_measure")
|
|
697
|
+
if volume_count == 0:
|
|
415
698
|
raise ValueError("Volume linear form must include dOmega().")
|
|
699
|
+
if volume_count > 1:
|
|
700
|
+
raise ValueError("Volume linear form must include dOmega() exactly once.")
|
|
701
|
+
if surface_count > 0:
|
|
702
|
+
raise ValueError("Volume linear form must not include ds().")
|
|
703
|
+
|
|
704
|
+
plan = make_eval_plan(expr)
|
|
416
705
|
|
|
417
706
|
def _form(ctx, params):
|
|
418
|
-
return
|
|
707
|
+
return eval_with_plan(plan, ctx, params)
|
|
419
708
|
|
|
420
|
-
_form._includes_measure =
|
|
709
|
+
_form._includes_measure = True
|
|
421
710
|
return _form
|
|
422
711
|
|
|
423
712
|
|
|
@@ -429,6 +718,340 @@ def _expr_contains(expr: Expr, op: str) -> bool:
|
|
|
429
718
|
return any(_expr_contains(arg, op) for arg in expr.args if isinstance(arg, Expr))
|
|
430
719
|
|
|
431
720
|
|
|
721
|
+
def _count_op(expr: Expr, op: str) -> int:
    """Count nodes whose ``op`` equals ``op`` in the tree rooted at ``expr``.

    A non-Expr root counts as zero. Only Expr children are followed;
    FieldRef, ParamRef and literal payloads are leaves. A node object
    reachable along several paths is counted once per path.
    """
    if not isinstance(expr, Expr):
        return 0
    total = 0
    pending = [expr]
    while pending:
        node = pending.pop()
        if node.op == op:
            total += 1
        pending.extend(child for child in node.args if isinstance(child, Expr))
    return total
|
|
729
|
+
|
|
730
|
+
|
|
731
|
+
@dataclass(frozen=True, slots=True)
|
|
732
|
+
class EvalPlan:
|
|
733
|
+
expr: Expr
|
|
734
|
+
nodes: tuple[Expr, ...]
|
|
735
|
+
index: dict[int, int]
|
|
736
|
+
|
|
737
|
+
|
|
738
|
+
def _validate_eval_plan(nodes: tuple[Expr, ...]) -> None:
    """Check that raw FieldRef operands only appear where an op consumes them.

    Only these ops accept raw FieldRef arguments:

    * ``value``/``grad``/``sym_grad``: exactly one FieldRef.
    * ``dot``/``sdot``: optional FieldRef as the first of two arguments.
    * ``action``/``gaction``: FieldRef as the first of two arguments.
    * ``outer``: exactly ``outer(test, trial)``.

    Every other op must receive FieldRefs wrapped via ``.val``/``.grad``/``.sym_grad``.

    Raises:
        TypeError: on any misuse. The one exception is ``action`` with a
            FieldRef second argument, which raises ValueError instead.
    """
    takes_fieldref = frozenset(
        ("value", "grad", "sym_grad", "dot", "sdot", "action", "gaction", "outer")
    )
    for node in nodes:
        op, args = node.op, node.args

        if op not in takes_fieldref:
            if any(isinstance(arg, FieldRef) for arg in args):
                raise TypeError(f"{op} cannot take FieldRef directly; wrap with .val/.grad/.sym_grad.")
            continue

        if op in ("value", "grad", "sym_grad"):
            if len(args) != 1 or not isinstance(args[0], FieldRef):
                raise TypeError(f"{op} expects FieldRef.")
        elif op in ("dot", "sdot"):
            if len(args) != 2:
                raise TypeError(f"{op} expects two arguments.")
            first_is_ref = isinstance(args[0], FieldRef)
            second_is_ref = isinstance(args[1], FieldRef)
            if first_is_ref or second_is_ref:
                if not first_is_ref:
                    raise TypeError(f"{op} expects FieldRef as the first argument.")
                if second_is_ref:
                    raise TypeError(f"{op} expects an expression for the second argument; use .val/.grad.")
        elif op in ("action", "gaction"):
            if len(args) != 2 or not isinstance(args[0], FieldRef):
                raise TypeError(f"{op} expects FieldRef as the first argument.")
            if isinstance(args[1], FieldRef):
                if op == "action":
                    raise ValueError("action expects a scalar expression; use u.val for unknowns.")
                raise TypeError("gaction expects an expression for the second argument; use .grad.")
        else:  # the only remaining member of takes_fieldref is "outer"
            if len(args) != 2 or not all(isinstance(arg, FieldRef) for arg in args):
                raise TypeError("outer expects two FieldRef operands.")
            if args[0].role != "test" or args[1].role != "trial":
                raise TypeError("outer expects outer(test, trial).")
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
def make_eval_plan(expr: Expr) -> EvalPlan:
    """Build an EvalPlan: Expr nodes of ``expr`` in postorder, plus an id index.

    Children come before their parents, and arguments are visited left to
    right. A node object shared by several parents (e.g. ``F`` reused in
    ``F.T @ F`` and ``det(F)``) is emitted exactly once. It is therefore
    evaluated once by ``eval_with_plan``, and ``index[id(node)]`` is its
    only position.

    The traversal is iterative. This avoids both the exponential re-walk of
    shared sub-expressions done by ``Expr.postorder_expr`` and Python's
    recursion limit on deeply nested expressions.

    Raises:
        TypeError / ValueError: propagated from ``_validate_eval_plan``.
    """
    ordered: list[Expr] = []
    index: dict[int, int] = {}
    # Stack entries are (node, children_already_pushed).
    stack: list[tuple[Expr, bool]] = [(expr, False)]
    while stack:
        node, expanded = stack.pop()
        if id(node) in index:
            continue  # already emitted via another path
        if expanded:
            index[id(node)] = len(ordered)
            ordered.append(node)
            continue
        stack.append((node, True))
        # Push in reverse so the leftmost argument is processed first.
        for arg in reversed(node.args):
            if isinstance(arg, Expr) and id(arg) not in index:
                stack.append((arg, False))

    nodes = tuple(ordered)
    _validate_eval_plan(nodes)
    # ``nodes`` keeps every node alive, so the id() keys stay valid for the plan's lifetime.
    return EvalPlan(expr=expr, nodes=nodes, index=index)
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
def eval_with_plan(
|
|
790
|
+
plan: EvalPlan,
|
|
791
|
+
ctx: VolumeContext | SurfaceContext,
|
|
792
|
+
params: ParamsLike,
|
|
793
|
+
u_elem: UElement | None = None,
|
|
794
|
+
):
|
|
795
|
+
nodes = plan.nodes
|
|
796
|
+
index = plan.index
|
|
797
|
+
vals: list[Any] = [None] * len(nodes)
|
|
798
|
+
|
|
799
|
+
def get(obj):
|
|
800
|
+
if isinstance(obj, Expr):
|
|
801
|
+
return vals[index[id(obj)]]
|
|
802
|
+
if isinstance(obj, FieldRef):
|
|
803
|
+
raise TypeError(
|
|
804
|
+
"FieldRef must be wrapped with .val/.grad/.sym_grad or used as the first arg of dot/action."
|
|
805
|
+
)
|
|
806
|
+
if isinstance(obj, ParamRef):
|
|
807
|
+
return params
|
|
808
|
+
return obj
|
|
809
|
+
|
|
810
|
+
for i, node in enumerate(nodes):
|
|
811
|
+
op = node.op
|
|
812
|
+
args = node.args
|
|
813
|
+
|
|
814
|
+
if op == "lit":
|
|
815
|
+
vals[i] = args[0]
|
|
816
|
+
continue
|
|
817
|
+
if op == "getattr":
|
|
818
|
+
base = get(args[0])
|
|
819
|
+
name = args[1]
|
|
820
|
+
if isinstance(base, dict):
|
|
821
|
+
vals[i] = base[name]
|
|
822
|
+
else:
|
|
823
|
+
vals[i] = getattr(base, name)
|
|
824
|
+
continue
|
|
825
|
+
if op == "value":
|
|
826
|
+
ref = args[0]
|
|
827
|
+
assert isinstance(ref, FieldRef)
|
|
828
|
+
field = _eval_field(ref, ctx, params)
|
|
829
|
+
if ref.role == "unknown":
|
|
830
|
+
vals[i] = _eval_unknown_value(ref, field, u_elem)
|
|
831
|
+
else:
|
|
832
|
+
vals[i] = field.N
|
|
833
|
+
continue
|
|
834
|
+
if op == "grad":
|
|
835
|
+
ref = args[0]
|
|
836
|
+
assert isinstance(ref, FieldRef)
|
|
837
|
+
field = _eval_field(ref, ctx, params)
|
|
838
|
+
if ref.role == "unknown":
|
|
839
|
+
vals[i] = _eval_unknown_grad(ref, field, u_elem)
|
|
840
|
+
else:
|
|
841
|
+
vals[i] = field.gradN
|
|
842
|
+
continue
|
|
843
|
+
if op == "pow":
|
|
844
|
+
base = get(args[0])
|
|
845
|
+
exp = get(args[1])
|
|
846
|
+
vals[i] = base**exp
|
|
847
|
+
continue
|
|
848
|
+
if op == "eye":
|
|
849
|
+
vals[i] = jnp.eye(int(args[0]))
|
|
850
|
+
continue
|
|
851
|
+
if op == "det":
|
|
852
|
+
vals[i] = jnp.linalg.det(get(args[0]))
|
|
853
|
+
continue
|
|
854
|
+
if op == "inv":
|
|
855
|
+
vals[i] = jnp.linalg.inv(get(args[0]))
|
|
856
|
+
continue
|
|
857
|
+
if op == "transpose":
|
|
858
|
+
vals[i] = jnp.swapaxes(get(args[0]), -1, -2)
|
|
859
|
+
continue
|
|
860
|
+
if op == "log":
|
|
861
|
+
vals[i] = jnp.log(get(args[0]))
|
|
862
|
+
continue
|
|
863
|
+
if op == "surface_normal":
|
|
864
|
+
normal = getattr(ctx, "normal", None)
|
|
865
|
+
if normal is None:
|
|
866
|
+
raise ValueError("surface normal is not available in context")
|
|
867
|
+
vals[i] = normal
|
|
868
|
+
continue
|
|
869
|
+
if op == "surface_measure":
|
|
870
|
+
if not hasattr(ctx, "w") or not hasattr(ctx, "detJ"):
|
|
871
|
+
raise TypeError("surface measure requires SurfaceContext.")
|
|
872
|
+
vals[i] = ctx.w * ctx.detJ
|
|
873
|
+
continue
|
|
874
|
+
if op == "volume_measure":
|
|
875
|
+
if not hasattr(ctx, "w") or not hasattr(ctx, "test"):
|
|
876
|
+
raise TypeError("volume measure requires VolumeContext.")
|
|
877
|
+
vals[i] = ctx.w * ctx.test.detJ
|
|
878
|
+
continue
|
|
879
|
+
if op == "sym_grad":
|
|
880
|
+
ref = args[0]
|
|
881
|
+
assert isinstance(ref, FieldRef)
|
|
882
|
+
field = _eval_field(ref, ctx, params)
|
|
883
|
+
if ref.role == "unknown":
|
|
884
|
+
if u_elem is None:
|
|
885
|
+
raise ValueError("u_elem is required to evaluate unknown sym_grad.")
|
|
886
|
+
u_local = _extract_unknown_elem(ref, u_elem)
|
|
887
|
+
vals[i] = _ops.sym_grad_u(field, u_local)
|
|
888
|
+
else:
|
|
889
|
+
vals[i] = _ops.sym_grad(field)
|
|
890
|
+
continue
|
|
891
|
+
if op == "outer":
|
|
892
|
+
a, b = args
|
|
893
|
+
if not isinstance(a, FieldRef) or not isinstance(b, FieldRef):
|
|
894
|
+
raise TypeError("outer expects FieldRef operands.")
|
|
895
|
+
test, trial = a, b
|
|
896
|
+
vals[i] = _basis_outer(test, trial, ctx, params)
|
|
897
|
+
continue
|
|
898
|
+
if op == "add":
|
|
899
|
+
vals[i] = get(args[0]) + get(args[1])
|
|
900
|
+
continue
|
|
901
|
+
if op == "sub":
|
|
902
|
+
vals[i] = get(args[0]) - get(args[1])
|
|
903
|
+
continue
|
|
904
|
+
if op == "mul":
|
|
905
|
+
a = get(args[0])
|
|
906
|
+
b = get(args[1])
|
|
907
|
+
if hasattr(a, "ndim") and hasattr(b, "ndim"):
|
|
908
|
+
if a.ndim == 1 and b.ndim == 2 and a.shape[0] == b.shape[0]:
|
|
909
|
+
a = a[:, None]
|
|
910
|
+
elif b.ndim == 1 and a.ndim == 2 and b.shape[0] == a.shape[0]:
|
|
911
|
+
b = b[:, None]
|
|
912
|
+
elif a.ndim >= 2 and b.ndim == 1 and a.shape[0] == b.shape[0]:
|
|
913
|
+
b = b.reshape((b.shape[0],) + (1,) * (a.ndim - 1))
|
|
914
|
+
elif b.ndim >= 2 and a.ndim == 1 and b.shape[0] == a.shape[0]:
|
|
915
|
+
a = a.reshape((a.shape[0],) + (1,) * (b.ndim - 1))
|
|
916
|
+
vals[i] = a * b
|
|
917
|
+
continue
|
|
918
|
+
if op == "matmul":
|
|
919
|
+
a = get(args[0])
|
|
920
|
+
b = get(args[1])
|
|
921
|
+
if (
|
|
922
|
+
hasattr(a, "ndim")
|
|
923
|
+
and hasattr(b, "ndim")
|
|
924
|
+
and a.ndim == 3
|
|
925
|
+
and b.ndim == 3
|
|
926
|
+
and a.shape[0] == b.shape[0]
|
|
927
|
+
and a.shape[-1] == b.shape[-1]
|
|
928
|
+
):
|
|
929
|
+
vals[i] = jnp.einsum("qia,qja->qij", a, b)
|
|
930
|
+
else:
|
|
931
|
+
raise TypeError(
|
|
932
|
+
"Expr '@' (matmul) is FEM-specific; use matmul_std(a, b) for standard matmul."
|
|
933
|
+
)
|
|
934
|
+
continue
|
|
935
|
+
if op == "matmul_std":
|
|
936
|
+
a = get(args[0])
|
|
937
|
+
b = get(args[1])
|
|
938
|
+
vals[i] = jnp.matmul(a, b)
|
|
939
|
+
continue
|
|
940
|
+
if op == "neg":
|
|
941
|
+
vals[i] = -get(args[0])
|
|
942
|
+
continue
|
|
943
|
+
if op == "dot":
|
|
944
|
+
ref = args[0]
|
|
945
|
+
if isinstance(ref, FieldRef):
|
|
946
|
+
vals[i] = _ops.dot(_eval_field(ref, ctx, params), get(args[1]))
|
|
947
|
+
else:
|
|
948
|
+
a = get(args[0])
|
|
949
|
+
b = get(args[1])
|
|
950
|
+
if (
|
|
951
|
+
hasattr(a, "ndim")
|
|
952
|
+
and hasattr(b, "ndim")
|
|
953
|
+
and a.ndim == 3
|
|
954
|
+
and b.ndim == 3
|
|
955
|
+
and a.shape[-1] == b.shape[-1]
|
|
956
|
+
):
|
|
957
|
+
vals[i] = jnp.einsum("qia,qja->qij", a, b)
|
|
958
|
+
else:
|
|
959
|
+
vals[i] = jnp.matmul(a, b)
|
|
960
|
+
continue
|
|
961
|
+
if op == "sdot":
|
|
962
|
+
ref = args[0]
|
|
963
|
+
if isinstance(ref, FieldRef):
|
|
964
|
+
vals[i] = _ops.dot(_eval_field(ref, ctx, params), get(args[1]))
|
|
965
|
+
else:
|
|
966
|
+
a = get(args[0])
|
|
967
|
+
b = get(args[1])
|
|
968
|
+
if (
|
|
969
|
+
hasattr(a, "ndim")
|
|
970
|
+
and hasattr(b, "ndim")
|
|
971
|
+
and a.ndim == 3
|
|
972
|
+
and b.ndim == 3
|
|
973
|
+
and a.shape[-1] == b.shape[-1]
|
|
974
|
+
):
|
|
975
|
+
vals[i] = jnp.einsum("qia,qja->qij", a, b)
|
|
976
|
+
else:
|
|
977
|
+
vals[i] = jnp.matmul(a, b)
|
|
978
|
+
continue
|
|
979
|
+
if op == "ddot":
|
|
980
|
+
if len(args) == 2:
|
|
981
|
+
a = get(args[0])
|
|
982
|
+
b = get(args[1])
|
|
983
|
+
if (
|
|
984
|
+
hasattr(a, "ndim")
|
|
985
|
+
and hasattr(b, "ndim")
|
|
986
|
+
and a.ndim == 3
|
|
987
|
+
and b.ndim == 3
|
|
988
|
+
and a.shape[0] == b.shape[0]
|
|
989
|
+
and a.shape[1] == b.shape[1]
|
|
990
|
+
):
|
|
991
|
+
vals[i] = jnp.einsum("qik,qim->qkm", a, b)
|
|
992
|
+
else:
|
|
993
|
+
vals[i] = _ops.ddot(a, b)
|
|
994
|
+
else:
|
|
995
|
+
vals[i] = _ops.ddot(get(args[0]), get(args[1]), get(args[2]))
|
|
996
|
+
continue
|
|
997
|
+
if op == "inner":
|
|
998
|
+
a = get(args[0])
|
|
999
|
+
b = get(args[1])
|
|
1000
|
+
vals[i] = jnp.einsum("...i,...i->...", a, b)
|
|
1001
|
+
continue
|
|
1002
|
+
if op == "action":
|
|
1003
|
+
ref = args[0]
|
|
1004
|
+
assert isinstance(ref, FieldRef)
|
|
1005
|
+
if isinstance(args[1], FieldRef):
|
|
1006
|
+
raise ValueError("action expects a scalar expression; use u.val for unknowns.")
|
|
1007
|
+
v_field = _eval_field(ref, ctx, params)
|
|
1008
|
+
s = get(args[1])
|
|
1009
|
+
value_dim = int(getattr(v_field, "value_dim", 1))
|
|
1010
|
+
# action maps a test field with a scalar/vector expression into nodal space.
|
|
1011
|
+
if value_dim == 1:
|
|
1012
|
+
if v_field.N.ndim != 2:
|
|
1013
|
+
raise ValueError("action expects scalar test field with N shape (q, ndofs).")
|
|
1014
|
+
if hasattr(s, "ndim") and s.ndim not in (0, 1):
|
|
1015
|
+
raise ValueError("action expects scalar s with shape (q,) or scalar.")
|
|
1016
|
+
vals[i] = v_field.N * s
|
|
1017
|
+
else:
|
|
1018
|
+
if hasattr(s, "ndim") and s.ndim not in (1, 2):
|
|
1019
|
+
raise ValueError("action expects vector s with shape (q, dim) or (dim,).")
|
|
1020
|
+
vals[i] = _ops.dot(v_field, s)
|
|
1021
|
+
continue
|
|
1022
|
+
if op == "gaction":
|
|
1023
|
+
ref = args[0]
|
|
1024
|
+
assert isinstance(ref, FieldRef)
|
|
1025
|
+
v_field = _eval_field(ref, ctx, params)
|
|
1026
|
+
q = get(args[1])
|
|
1027
|
+
# gaction maps a flux-like expression to nodal space via test gradients.
|
|
1028
|
+
if v_field.gradN.ndim != 3:
|
|
1029
|
+
raise ValueError("gaction expects test gradient with shape (q, ndofs, dim).")
|
|
1030
|
+
if not hasattr(q, "ndim"):
|
|
1031
|
+
raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
|
|
1032
|
+
if q.ndim == 2:
|
|
1033
|
+
vals[i] = jnp.einsum("qaj,qj->qa", v_field.gradN, q)
|
|
1034
|
+
elif q.ndim == 3:
|
|
1035
|
+
if int(getattr(v_field, "value_dim", 1)) == 1:
|
|
1036
|
+
raise ValueError("gaction tensor flux requires vector test field.")
|
|
1037
|
+
vals[i] = jnp.einsum("qij,qaj->qai", q, v_field.gradN).reshape(q.shape[0], -1)
|
|
1038
|
+
else:
|
|
1039
|
+
raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
|
|
1040
|
+
continue
|
|
1041
|
+
if op == "transpose_last2":
|
|
1042
|
+
vals[i] = _ops.transpose_last2(get(args[0]))
|
|
1043
|
+
continue
|
|
1044
|
+
if op == "einsum":
|
|
1045
|
+
subscripts = args[0]
|
|
1046
|
+
operands = [get(arg) for arg in args[1:]]
|
|
1047
|
+
vals[i] = jnp.einsum(subscripts, *operands)
|
|
1048
|
+
continue
|
|
1049
|
+
|
|
1050
|
+
raise ValueError(f"Unknown Expr op: {op}")
|
|
1051
|
+
|
|
1052
|
+
return vals[index[id(plan.expr)]]
|
|
1053
|
+
|
|
1054
|
+
|
|
432
1055
|
def compile_surface_linear(fn):
|
|
433
1056
|
"""get_compiled a surface linear form into a kernel (ctx, params) -> ndarray."""
|
|
434
1057
|
if isinstance(fn, Expr):
|
|
@@ -436,26 +1059,27 @@ def compile_surface_linear(fn):
|
|
|
436
1059
|
else:
|
|
437
1060
|
v = test_ref()
|
|
438
1061
|
p = param_ref()
|
|
439
|
-
expr =
|
|
440
|
-
try:
|
|
441
|
-
expr = fn(v, p)
|
|
442
|
-
except TypeError:
|
|
443
|
-
try:
|
|
444
|
-
expr = fn(v)
|
|
445
|
-
except TypeError:
|
|
446
|
-
expr = None
|
|
1062
|
+
expr = _call_user(fn, v, params=p)
|
|
447
1063
|
|
|
1064
|
+
expr = _as_expr(expr)
|
|
448
1065
|
if not isinstance(expr, Expr):
|
|
449
1066
|
raise ValueError("Surface linear form must return an Expr; use ds() in the expression.")
|
|
450
1067
|
|
|
451
|
-
|
|
452
|
-
|
|
1068
|
+
surface_count = _count_op(expr, "surface_measure")
|
|
1069
|
+
volume_count = _count_op(expr, "volume_measure")
|
|
1070
|
+
if surface_count == 0:
|
|
453
1071
|
raise ValueError("Surface linear form must include ds().")
|
|
1072
|
+
if surface_count > 1:
|
|
1073
|
+
raise ValueError("Surface linear form must include ds() exactly once.")
|
|
1074
|
+
if volume_count > 0:
|
|
1075
|
+
raise ValueError("Surface linear form must not include dOmega().")
|
|
1076
|
+
|
|
1077
|
+
plan = make_eval_plan(expr)
|
|
454
1078
|
|
|
455
1079
|
def _form(ctx, params):
|
|
456
|
-
return
|
|
1080
|
+
return eval_with_plan(plan, ctx, params)
|
|
457
1081
|
|
|
458
|
-
_form._includes_measure =
|
|
1082
|
+
_form._includes_measure = True # type: ignore[attr-defined]
|
|
459
1083
|
return _form
|
|
460
1084
|
|
|
461
1085
|
|
|
@@ -519,25 +1143,33 @@ def compile_residual(fn):
|
|
|
519
1143
|
v = test_ref()
|
|
520
1144
|
u = unknown_ref()
|
|
521
1145
|
p = param_ref()
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
1146
|
+
expr = _call_user(fn, v, u, params=p)
|
|
1147
|
+
expr = _as_expr(expr)
|
|
1148
|
+
if not isinstance(expr, Expr):
|
|
1149
|
+
raise TypeError("Residual form must return an Expr.")
|
|
526
1150
|
|
|
527
|
-
|
|
528
|
-
|
|
1151
|
+
volume_count = _count_op(expr, "volume_measure")
|
|
1152
|
+
surface_count = _count_op(expr, "surface_measure")
|
|
1153
|
+
if volume_count == 0:
|
|
529
1154
|
raise ValueError("Volume residual form must include dOmega().")
|
|
1155
|
+
if volume_count > 1:
|
|
1156
|
+
raise ValueError("Volume residual form must include dOmega() exactly once.")
|
|
1157
|
+
if surface_count > 0:
|
|
1158
|
+
raise ValueError("Volume residual form must not include ds().")
|
|
1159
|
+
|
|
1160
|
+
plan = make_eval_plan(expr)
|
|
530
1161
|
|
|
531
1162
|
def _form(ctx, u_elem, params):
|
|
532
|
-
return
|
|
1163
|
+
return eval_with_plan(plan, ctx, params, u_elem=u_elem)
|
|
533
1164
|
|
|
534
|
-
_form._includes_measure =
|
|
1165
|
+
_form._includes_measure = True
|
|
535
1166
|
return _form
|
|
536
1167
|
|
|
537
1168
|
|
|
538
1169
|
def compile_mixed_residual(residuals: dict[str, Callable]):
|
|
539
1170
|
"""get_compiled mixed residuals keyed by field name."""
|
|
540
1171
|
compiled = {}
|
|
1172
|
+
plans = {}
|
|
541
1173
|
includes_measure = {}
|
|
542
1174
|
for name, fn in residuals.items():
|
|
543
1175
|
if isinstance(fn, Expr):
|
|
@@ -546,17 +1178,24 @@ def compile_mixed_residual(residuals: dict[str, Callable]):
|
|
|
546
1178
|
v = test_ref(name)
|
|
547
1179
|
u = unknown_ref(name)
|
|
548
1180
|
p = param_ref()
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
compiled[name] =
|
|
554
|
-
|
|
555
|
-
|
|
1181
|
+
expr = _call_user(fn, v, u, params=p)
|
|
1182
|
+
expr = _as_expr(expr)
|
|
1183
|
+
if not isinstance(expr, Expr):
|
|
1184
|
+
raise TypeError(f"Mixed residual '{name}' must return an Expr.")
|
|
1185
|
+
compiled[name] = expr
|
|
1186
|
+
plans[name] = make_eval_plan(expr)
|
|
1187
|
+
volume_count = _count_op(compiled[name], "volume_measure")
|
|
1188
|
+
surface_count = _count_op(compiled[name], "surface_measure")
|
|
1189
|
+
includes_measure[name] = volume_count == 1
|
|
1190
|
+
if volume_count == 0:
|
|
556
1191
|
raise ValueError(f"Mixed residual '{name}' must include dOmega().")
|
|
1192
|
+
if volume_count > 1:
|
|
1193
|
+
raise ValueError(f"Mixed residual '{name}' must include dOmega() exactly once.")
|
|
1194
|
+
if surface_count > 0:
|
|
1195
|
+
raise ValueError(f"Mixed residual '{name}' must not include ds().")
|
|
557
1196
|
|
|
558
1197
|
def _form(ctx, u_elem, params):
|
|
559
|
-
return {name:
|
|
1198
|
+
return {name: eval_with_plan(plan, ctx, params, u_elem=u_elem) for name, plan in plans.items()}
|
|
560
1199
|
|
|
561
1200
|
_form._includes_measure = includes_measure
|
|
562
1201
|
return _form
|
|
@@ -574,216 +1213,14 @@ class MixedWeakForm:
|
|
|
574
1213
|
return compile_mixed_residual(self.residuals)
|
|
575
1214
|
|
|
576
1215
|
|
|
577
|
-
def _eval_expr(
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
if op == "getattr":
|
|
586
|
-
base = _eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
587
|
-
name = args[1]
|
|
588
|
-
if isinstance(base, dict):
|
|
589
|
-
return base[name]
|
|
590
|
-
return getattr(base, name)
|
|
591
|
-
if op == "field":
|
|
592
|
-
role, name = args
|
|
593
|
-
if name is not None:
|
|
594
|
-
if role == "trial" and getattr(ctx, "trial_fields", None) is not None:
|
|
595
|
-
if name in ctx.trial_fields:
|
|
596
|
-
return ctx.trial_fields[name]
|
|
597
|
-
if role == "test" and getattr(ctx, "test_fields", None) is not None:
|
|
598
|
-
if name in ctx.test_fields:
|
|
599
|
-
return ctx.test_fields[name]
|
|
600
|
-
if role == "unknown" and getattr(ctx, "unknown_fields", None) is not None:
|
|
601
|
-
if name in ctx.unknown_fields:
|
|
602
|
-
return ctx.unknown_fields[name]
|
|
603
|
-
fields = getattr(ctx, "fields", None)
|
|
604
|
-
if fields is not None and name in fields:
|
|
605
|
-
group = fields[name]
|
|
606
|
-
if isinstance(group, dict):
|
|
607
|
-
if role in group:
|
|
608
|
-
return group[role]
|
|
609
|
-
if "field" in group:
|
|
610
|
-
return group["field"]
|
|
611
|
-
return group
|
|
612
|
-
if role == "trial":
|
|
613
|
-
return ctx.trial
|
|
614
|
-
if role == "test":
|
|
615
|
-
return ctx.test
|
|
616
|
-
if role == "unknown":
|
|
617
|
-
return getattr(ctx, "unknown", ctx.trial)
|
|
618
|
-
raise ValueError(f"Unknown field role: {role}")
|
|
619
|
-
if op == "value":
|
|
620
|
-
field = _eval_field(args[0], ctx, params)
|
|
621
|
-
if isinstance(args[0], FieldRef) and args[0].role == "unknown":
|
|
622
|
-
return _eval_unknown_value(args[0], field, u_elem)
|
|
623
|
-
return field.N
|
|
624
|
-
if op == "grad":
|
|
625
|
-
field = _eval_field(args[0], ctx, params)
|
|
626
|
-
if isinstance(args[0], FieldRef) and args[0].role == "unknown":
|
|
627
|
-
return _eval_unknown_grad(args[0], field, u_elem)
|
|
628
|
-
return field.gradN
|
|
629
|
-
if op == "pow":
|
|
630
|
-
base = _eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
631
|
-
exp = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
632
|
-
return base**exp
|
|
633
|
-
if op == "eye":
|
|
634
|
-
return jnp.eye(int(args[0]))
|
|
635
|
-
if op == "det":
|
|
636
|
-
return jnp.linalg.det(_eval_value(args[0], ctx, params, u_elem=u_elem))
|
|
637
|
-
if op == "inv":
|
|
638
|
-
return jnp.linalg.inv(_eval_value(args[0], ctx, params, u_elem=u_elem))
|
|
639
|
-
if op == "transpose":
|
|
640
|
-
return jnp.swapaxes(_eval_value(args[0], ctx, params, u_elem=u_elem), -1, -2)
|
|
641
|
-
if op == "log":
|
|
642
|
-
return jnp.log(_eval_value(args[0], ctx, params, u_elem=u_elem))
|
|
643
|
-
if op == "surface_normal":
|
|
644
|
-
normal = getattr(ctx, "normal", None)
|
|
645
|
-
if normal is None:
|
|
646
|
-
raise ValueError("surface normal is not available in context")
|
|
647
|
-
return normal
|
|
648
|
-
if op == "surface_measure":
|
|
649
|
-
if not hasattr(ctx, "w") or not hasattr(ctx, "detJ"):
|
|
650
|
-
raise ValueError("surface measure requires surface context with w and detJ.")
|
|
651
|
-
return ctx.w * ctx.detJ
|
|
652
|
-
if op == "volume_measure":
|
|
653
|
-
if not hasattr(ctx, "w") or not hasattr(ctx, "test"):
|
|
654
|
-
raise ValueError("volume measure requires FormContext with w and test.detJ.")
|
|
655
|
-
return ctx.w * ctx.test.detJ
|
|
656
|
-
if op == "sym_grad":
|
|
657
|
-
field = _eval_field(args[0], ctx, params)
|
|
658
|
-
if isinstance(args[0], FieldRef) and args[0].role == "unknown":
|
|
659
|
-
if u_elem is None:
|
|
660
|
-
raise ValueError("u_elem is required to evaluate unknown sym_grad.")
|
|
661
|
-
u_local = _extract_unknown_elem(args[0], u_elem)
|
|
662
|
-
return _ops.sym_grad_u(field, u_local)
|
|
663
|
-
return _ops.sym_grad(field)
|
|
664
|
-
if op == "outer":
|
|
665
|
-
a, b = args
|
|
666
|
-
if not isinstance(a, FieldRef) or not isinstance(b, FieldRef):
|
|
667
|
-
raise TypeError("outer expects FieldRef operands.")
|
|
668
|
-
if a.role == b.role:
|
|
669
|
-
raise ValueError("outer requires one trial and one test field.")
|
|
670
|
-
test = a if a.role == "test" else b
|
|
671
|
-
trial = b if a.role == "test" else a
|
|
672
|
-
v_field = _eval_field(test, ctx, params)
|
|
673
|
-
u_field = _eval_field(trial, ctx, params)
|
|
674
|
-
if getattr(v_field, "value_dim", 1) != 1 or getattr(u_field, "value_dim", 1) != 1:
|
|
675
|
-
raise ValueError("u*v is only defined for scalar fields; use dot/inner for vectors.")
|
|
676
|
-
vN = v_field.N
|
|
677
|
-
uN = u_field.N
|
|
678
|
-
return jnp.einsum("qi,qj->qij", vN, uN)
|
|
679
|
-
if op == "add":
|
|
680
|
-
return _eval_value(args[0], ctx, params, u_elem=u_elem) + _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
681
|
-
if op == "sub":
|
|
682
|
-
return _eval_value(args[0], ctx, params, u_elem=u_elem) - _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
683
|
-
if op == "mul":
|
|
684
|
-
a = _eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
685
|
-
b = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
686
|
-
if hasattr(a, "ndim") and hasattr(b, "ndim"):
|
|
687
|
-
if a.ndim == 1 and b.ndim == 2 and a.shape[0] == b.shape[0]:
|
|
688
|
-
a = a[:, None]
|
|
689
|
-
elif b.ndim == 1 and a.ndim == 2 and b.shape[0] == a.shape[0]:
|
|
690
|
-
b = b[:, None]
|
|
691
|
-
elif a.ndim >= 2 and b.ndim == 1 and a.shape[0] == b.shape[0]:
|
|
692
|
-
b = b.reshape((b.shape[0],) + (1,) * (a.ndim - 1))
|
|
693
|
-
elif b.ndim >= 2 and a.ndim == 1 and b.shape[0] == a.shape[0]:
|
|
694
|
-
a = a.reshape((a.shape[0],) + (1,) * (b.ndim - 1))
|
|
695
|
-
return a * b
|
|
696
|
-
if op == "matmul":
|
|
697
|
-
a = _eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
698
|
-
b = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
699
|
-
if (
|
|
700
|
-
hasattr(a, "ndim")
|
|
701
|
-
and hasattr(b, "ndim")
|
|
702
|
-
and a.ndim == 3
|
|
703
|
-
and b.ndim == 3
|
|
704
|
-
and a.shape[0] == b.shape[0]
|
|
705
|
-
and a.shape[-1] == b.shape[-1]
|
|
706
|
-
):
|
|
707
|
-
return jnp.einsum("qia,qja->qij", a, b)
|
|
708
|
-
return a @ b
|
|
709
|
-
if op == "neg":
|
|
710
|
-
return -_eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
711
|
-
if op == "dot":
|
|
712
|
-
if isinstance(args[0], FieldRef):
|
|
713
|
-
return _ops.dot(_eval_field(args[0], ctx, params), _eval_value(args[1], ctx, params, u_elem=u_elem))
|
|
714
|
-
a = _eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
715
|
-
b = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
716
|
-
if hasattr(a, "ndim") and hasattr(b, "ndim") and a.ndim == 3 and b.ndim == 3 and a.shape[-1] == b.shape[-1]:
|
|
717
|
-
return jnp.einsum("qia,qja->qij", a, b)
|
|
718
|
-
return jnp.matmul(a, b)
|
|
719
|
-
if op == "sdot":
|
|
720
|
-
if isinstance(args[0], FieldRef):
|
|
721
|
-
return _ops.dot(_eval_field(args[0], ctx, params), _eval_value(args[1], ctx, params, u_elem=u_elem))
|
|
722
|
-
a = _eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
723
|
-
b = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
724
|
-
if hasattr(a, "ndim") and hasattr(b, "ndim") and a.ndim == 3 and b.ndim == 3 and a.shape[-1] == b.shape[-1]:
|
|
725
|
-
return jnp.einsum("qia,qja->qij", a, b)
|
|
726
|
-
return jnp.matmul(a, b)
|
|
727
|
-
if op == "ddot":
|
|
728
|
-
if len(args) == 2:
|
|
729
|
-
a = _eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
730
|
-
b = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
731
|
-
if (
|
|
732
|
-
hasattr(a, "ndim")
|
|
733
|
-
and hasattr(b, "ndim")
|
|
734
|
-
and a.ndim == 3
|
|
735
|
-
and b.ndim == 3
|
|
736
|
-
and a.shape[0] == b.shape[0]
|
|
737
|
-
and a.shape[1] == b.shape[1]
|
|
738
|
-
):
|
|
739
|
-
return jnp.einsum("qik,qim->qkm", a, b)
|
|
740
|
-
return _ops.ddot(a, b)
|
|
741
|
-
return _ops.ddot(
|
|
742
|
-
_eval_value(args[0], ctx, params, u_elem=u_elem),
|
|
743
|
-
_eval_value(args[1], ctx, params, u_elem=u_elem),
|
|
744
|
-
_eval_value(args[2], ctx, params, u_elem=u_elem),
|
|
745
|
-
)
|
|
746
|
-
if op == "inner":
|
|
747
|
-
a = _eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
748
|
-
b = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
749
|
-
return jnp.einsum("...i,...i->...", a, b)
|
|
750
|
-
if op == "action":
|
|
751
|
-
if isinstance(args[1], FieldRef):
|
|
752
|
-
raise ValueError("action expects a scalar expression; use u.val for unknowns.")
|
|
753
|
-
v_field = _eval_field(args[0], ctx, params)
|
|
754
|
-
s = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
755
|
-
value_dim = int(getattr(v_field, "value_dim", 1))
|
|
756
|
-
if value_dim == 1:
|
|
757
|
-
if v_field.N.ndim != 2:
|
|
758
|
-
raise ValueError("action expects scalar test field with N shape (q, ndofs).")
|
|
759
|
-
if hasattr(s, "ndim") and s.ndim not in (0, 1):
|
|
760
|
-
raise ValueError("action expects scalar s with shape (q,) or scalar.")
|
|
761
|
-
return v_field.N * s
|
|
762
|
-
if hasattr(s, "ndim") and s.ndim not in (1, 2):
|
|
763
|
-
raise ValueError("action expects vector s with shape (q, dim) or (dim,).")
|
|
764
|
-
return _ops.dot(v_field, s)
|
|
765
|
-
if op == "gaction":
|
|
766
|
-
v_field = _eval_field(args[0], ctx, params)
|
|
767
|
-
q = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
768
|
-
if v_field.gradN.ndim != 3:
|
|
769
|
-
raise ValueError("gaction expects test gradient with shape (q, ndofs, dim).")
|
|
770
|
-
if not hasattr(q, "ndim"):
|
|
771
|
-
raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
|
|
772
|
-
if q.ndim == 2:
|
|
773
|
-
return jnp.einsum("qaj,qj->qa", v_field.gradN, q)
|
|
774
|
-
if q.ndim == 3:
|
|
775
|
-
if int(getattr(v_field, "value_dim", 1)) == 1:
|
|
776
|
-
raise ValueError("gaction tensor flux requires vector test field.")
|
|
777
|
-
return jnp.einsum("qij,qaj->qai", q, v_field.gradN).reshape(q.shape[0], -1)
|
|
778
|
-
raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
|
|
779
|
-
if op == "transpose_last2":
|
|
780
|
-
return _ops.transpose_last2(_eval_value(args[0], ctx, params, u_elem=u_elem))
|
|
781
|
-
if op == "einsum":
|
|
782
|
-
subscripts = args[0]
|
|
783
|
-
operands = [_eval_value(arg, ctx, params, u_elem=u_elem) for arg in args[1:]]
|
|
784
|
-
return jnp.einsum(subscripts, *operands)
|
|
785
|
-
|
|
786
|
-
raise ValueError(f"Unknown Expr op: {op}")
|
|
1216
|
+
def _eval_expr(
|
|
1217
|
+
expr: Expr,
|
|
1218
|
+
ctx: VolumeContext | SurfaceContext,
|
|
1219
|
+
params: ParamsLike,
|
|
1220
|
+
u_elem: UElement | None = None,
|
|
1221
|
+
):
|
|
1222
|
+
plan = make_eval_plan(expr)
|
|
1223
|
+
return eval_with_plan(plan, ctx, params, u_elem=u_elem)
|
|
787
1224
|
|
|
788
1225
|
|
|
789
1226
|
__all__ = [
|
|
@@ -803,6 +1240,7 @@ __all__ = [
|
|
|
803
1240
|
"compile_mixed_residual",
|
|
804
1241
|
"grad",
|
|
805
1242
|
"sym_grad",
|
|
1243
|
+
"outer",
|
|
806
1244
|
"dot",
|
|
807
1245
|
"ddot",
|
|
808
1246
|
"inner",
|
|
@@ -814,5 +1252,7 @@ __all__ = [
|
|
|
814
1252
|
"transpose",
|
|
815
1253
|
"log",
|
|
816
1254
|
"transpose_last2",
|
|
1255
|
+
"matmul",
|
|
1256
|
+
"matmul_std",
|
|
817
1257
|
"einsum",
|
|
818
1258
|
]
|