fluxfem 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +68 -161
- fluxfem/core/__init__.py +57 -31
- fluxfem/core/context_types.py +36 -0
- fluxfem/core/forms.py +5 -1
- fluxfem/core/weakform.py +742 -312
- fluxfem/helpers_wf.py +4 -0
- fluxfem/mesh/base.py +6 -1
- fluxfem/mesh/hex.py +1 -0
- fluxfem/mesh/io.py +2 -0
- {fluxfem-0.1.3.dist-info → fluxfem-0.1.4.dist-info}/METADATA +13 -11
- {fluxfem-0.1.3.dist-info → fluxfem-0.1.4.dist-info}/RECORD +13 -12
- {fluxfem-0.1.3.dist-info → fluxfem-0.1.4.dist-info}/LICENSE +0 -0
- {fluxfem-0.1.3.dist-info → fluxfem-0.1.4.dist-info}/WHEEL +0 -0
fluxfem/core/weakform.py
CHANGED
|
@@ -1,82 +1,312 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
|
-
from typing import Any, Callable
|
|
4
|
+
from typing import Any, Callable, Iterator, Literal, get_args
|
|
5
|
+
import inspect
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
5
8
|
|
|
6
9
|
import jax.numpy as jnp
|
|
7
10
|
import jax
|
|
8
11
|
|
|
9
12
|
from ..physics import operators as _ops
|
|
13
|
+
from .context_types import FormFieldLike, ParamsLike, SurfaceContext, UElement, VolumeContext
|
|
10
14
|
|
|
11
15
|
|
|
16
|
+
OpName = Literal[
|
|
17
|
+
"lit",
|
|
18
|
+
"getattr",
|
|
19
|
+
"value",
|
|
20
|
+
"grad",
|
|
21
|
+
"pow",
|
|
22
|
+
"eye",
|
|
23
|
+
"det",
|
|
24
|
+
"inv",
|
|
25
|
+
"transpose",
|
|
26
|
+
"log",
|
|
27
|
+
"surface_normal",
|
|
28
|
+
"surface_measure",
|
|
29
|
+
"volume_measure",
|
|
30
|
+
"sym_grad",
|
|
31
|
+
"outer",
|
|
32
|
+
"add",
|
|
33
|
+
"sub",
|
|
34
|
+
"mul",
|
|
35
|
+
"matmul",
|
|
36
|
+
"matmul_std",
|
|
37
|
+
"neg",
|
|
38
|
+
"dot",
|
|
39
|
+
"sdot",
|
|
40
|
+
"ddot",
|
|
41
|
+
"inner",
|
|
42
|
+
"action",
|
|
43
|
+
"gaction",
|
|
44
|
+
"transpose_last2",
|
|
45
|
+
"einsum",
|
|
46
|
+
]
|
|
47
|
+
|
|
48
|
+
# Use OpName as the single source of truth for valid ops.
|
|
49
|
+
_OP_NAMES: frozenset[str] = frozenset(get_args(OpName))
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
_PRECEDENCE: dict[str, int] = {
|
|
53
|
+
"add": 10,
|
|
54
|
+
"sub": 10,
|
|
55
|
+
"mul": 20,
|
|
56
|
+
"matmul": 20,
|
|
57
|
+
"matmul_std": 20,
|
|
58
|
+
"inner": 20,
|
|
59
|
+
"dot": 20,
|
|
60
|
+
"sdot": 20,
|
|
61
|
+
"ddot": 20,
|
|
62
|
+
"pow": 30,
|
|
63
|
+
"neg": 40,
|
|
64
|
+
"transpose": 50,
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _pretty_render_arg(arg, prec: int | None = None) -> str:
    """Render one ``Expr`` argument for the pretty printer.

    ``Expr`` children are rendered recursively, with ``prec`` as the parent
    precedence (a missing/falsy ``prec`` means 0). A ``FieldRef`` becomes
    ``role`` or ``role:name``, a ``ParamRef`` becomes ``param``, and anything
    else falls back to ``repr``.
    """
    if isinstance(arg, Expr):
        parent = prec if prec else 0
        return _pretty_expr(arg, parent)
    if isinstance(arg, FieldRef):
        label = f"{arg.role}"
        return label if arg.name is None else f"{label}:{arg.name}"
    return "param" if isinstance(arg, ParamRef) else repr(arg)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _pretty_wrap(text: str, prec: int, parent_prec: int) -> str:
|
|
81
|
+
if prec < parent_prec:
|
|
82
|
+
return f"({text})"
|
|
83
|
+
return text
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _pretty_expr(expr: Expr, parent_prec: int = 0) -> str:
    """Render ``expr`` as a compact, Python-like string (used by ``Expr.__repr__``).

    Infix ``+ - * @``, power ``**``, unary minus and postfix ``.T`` are
    parenthesized according to ``_PRECEDENCE``. Every other op is printed in
    call form ``op(arg, ...)`` with *all* of its arguments. Call form is already
    atomic, so neither the call nor its arguments get extra parentheses.

    Args:
        expr: Node to render.
        parent_prec: Precedence demanded by the enclosing context. The result is
            wrapped in parentheses when ``expr`` binds more loosely than this.

    Returns:
        The rendered string.
    """
    op = expr.op
    args = expr.args

    if op == "lit":
        text = repr(args[0])
        # A negative literal reads like a unary minus (e.g. `-2**2` means
        # `-(2**2)`), so give it the same precedence as `neg`.
        if text.startswith("-"):
            return _pretty_wrap(text, _PRECEDENCE["neg"], parent_prec)
        return text
    if op == "getattr":
        base = _pretty_render_arg(args[0], _PRECEDENCE.get("transpose", 50))
        return f"{base}.{args[1]}"
    if op == "neg":
        prec = _PRECEDENCE["neg"]
        inner = _pretty_render_arg(args[0], prec)
        return _pretty_wrap(f"-{inner}", prec, parent_prec)
    if op == "transpose":
        prec = _PRECEDENCE["transpose"]
        inner = _pretty_render_arg(args[0], prec)
        return _pretty_wrap(f"{inner}.T", prec, parent_prec)
    if op == "pow":
        prec = _PRECEDENCE["pow"]
        # `**` is right-associative and binds tighter than a unary minus on its
        # left (`-a**2` is `-(a**2)`). So a base that is itself a power or a
        # negation needs parentheses, while a same-precedence exponent does not.
        base = _pretty_render_arg(args[0], max(prec, _PRECEDENCE["neg"]) + 1)
        exp = _pretty_render_arg(args[1], prec)
        return _pretty_wrap(f"{base}**{exp}", prec, parent_prec)

    infix = {"add": "+", "sub": "-", "mul": "*", "matmul": "@"}
    if op in infix and len(args) == 2:
        prec = _PRECEDENCE[op]
        left = _pretty_render_arg(args[0], prec)
        # Left-associative: a same-precedence right operand needs parentheses
        # (`a - (b - c)`).
        right = _pretty_render_arg(args[1], prec + 1)
        return _pretty_wrap(f"{left} {infix[op]} {right}", prec, parent_prec)

    # Call form for everything else: val, grad, sym_grad, dot/sdot,
    # ddot (2 or 3 args), inner, action/gaction, matmul_std, outer, eye, det,
    # inv, log, measures, transpose_last2, einsum, ...
    name = "val" if op == "value" else op
    rendered = ", ".join(_pretty_render_arg(arg) for arg in args)
    return f"{name}({rendered})"
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _as_expr(obj) -> Expr | FieldRef | ParamRef:
    """Normalize inputs into Expr/FieldRef/ParamRef nodes.

    Symbolic nodes pass through unchanged. NumPy scalars are converted with
    ``.item()`` and wrapped as ``lit`` nodes. Python scalars (``bool``, ``int``,
    ``float``, ``complex``, ``str``) and hashable tuples are wrapped as ``lit``
    nodes as-is.

    Raises:
        TypeError: for a tuple that is not hashable, and for any other input
            (notably arrays, which must be passed via params instead).
    """
    if isinstance(obj, (Expr, FieldRef, ParamRef)):
        return obj
    # Check NumPy scalars before Python scalars. np.float64 / np.str_ subclass
    # float / str and would otherwise be stored unconverted, unlike every other
    # NumPy scalar type.
    if isinstance(obj, np.generic):
        return Expr("lit", obj.item())
    # complex is accepted like the other scalars (NumPy complex scalars already
    # were, via the branch above).
    if isinstance(obj, (int, float, complex, bool, str)):
        return Expr("lit", obj)
    if isinstance(obj, tuple):
        try:
            hash(obj)
        except TypeError as exc:
            raise TypeError(
                "Expr tuple literal must be hashable; use only immutable items."
            ) from exc
        return Expr("lit", obj)
    raise TypeError(
        "Expr literal must be a scalar or hashable tuple. "
        "Arrays are not allowed; pass them via params (ParamRef/params.xxx)."
    )
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
@dataclass(frozen=True, slots=True, init=False)
class Expr:
    """Expression tree node evaluated against a FormContext.

    Compile flow (recommended):
      - build an Expr via operators/refs
      - compile_* builds an EvalPlan (postorder nodes + index)
      - eval_with_plan(plan, ctx, params, u_elem) evaluates per element

    Expr.eval is a debug/single-shot path that creates a plan on demand.

    Nodes are immutable, and the generated dataclass ``__eq__``/``__hash__``
    compare them structurally via ``(op, args)``. Python operators build new
    nodes: ``+ - * @ ** | -x .T`` map to the ops ``add sub mul matmul pow
    inner neg transpose``. Non-node operands are normalized by ``_as_expr``
    (scalars become ``lit`` nodes).

    Attributes:
        op: Node kind; must be one of ``OpName``.
        args: Operands. These are child Expr/FieldRef/ParamRef nodes or plain
            payloads (e.g. a literal value, or the attribute name of ``getattr``).
    """

    op: OpName
    args: tuple[Any, ...]

    def __init__(self, op: OpName, *args):
        """Create a node; raise ValueError if ``op`` is not a known OpName."""
        if op not in _OP_NAMES:
            raise ValueError(f"Unknown Expr op: {op!r}")
        # The dataclass is frozen, so assign through object.__setattr__.
        object.__setattr__(self, "op", op)
        object.__setattr__(self, "args", args)

    def eval(self, ctx, params=None, u_elem=None):
        """Evaluate the expression against a context (debug/single-shot path)."""
        return _eval_expr(self, ctx, params, u_elem=u_elem)

    def children(self) -> tuple[Any, ...]:
        """Return direct child nodes (Expr/FieldRef/ParamRef) for traversal."""
        return tuple(arg for arg in self.args if isinstance(arg, (Expr, FieldRef, ParamRef)))

    def walk(self) -> Iterator[Any]:
        """Depth-first preorder walk over nodes, including leaf FieldRef/ParamRef."""
        yield self
        for child in self.children():
            if isinstance(child, Expr):
                yield from child.walk()
            else:
                yield child

    def postorder(self) -> Iterator[Any]:
        """Postorder walk over nodes, including leaf FieldRef/ParamRef."""
        for child in self.children():
            if isinstance(child, Expr):
                yield from child.postorder()
            else:
                yield child
        yield self

    def postorder_expr(self) -> Iterator["Expr"]:
        """Postorder walk over Expr nodes only (for eval planning)."""
        for arg in self.args:
            if isinstance(arg, Expr):
                yield from arg.postorder_expr()
        yield self

    def _binop(self, other, op):
        """Build ``Expr(op, self, other)`` with ``other`` normalized."""
        return Expr(op, self, _as_expr(other))

    def __add__(self, other):
        """Add expressions: `a + b`."""
        return self._binop(other, "add")

    def __radd__(self, other):
        """Right-add expressions: `1 + expr`."""
        return Expr("add", _as_expr(other), self)

    def __sub__(self, other):
        """Subtract expressions: `a - b`."""
        return self._binop(other, "sub")

    def __rsub__(self, other):
        """Right-subtract expressions: `1 - expr`."""
        return Expr("sub", _as_expr(other), self)

    def __mul__(self, other):
        """Multiply expressions: `a * b`."""
        return self._binop(other, "mul")

    def __rmul__(self, other):
        """Right-multiply expressions: `2 * expr`."""
        return Expr("mul", _as_expr(other), self)

    def __matmul__(self, other):
        """Matrix product: `a @ b` (FEM-specific contraction semantics)."""
        return self._binop(other, "matmul")

    def __rmatmul__(self, other):
        """Right-matmul: `A @ expr`."""
        return Expr("matmul", _as_expr(other), self)

    def __or__(self, other):
        """Tensor inner product: `a | b` (use .val/.grad for FieldRef)."""
        if isinstance(other, FieldRef):
            raise TypeError("Expr | FieldRef is not supported; use .val/.grad on the FieldRef.")
        return Expr("inner", self, _as_expr(other))

    def __ror__(self, other):
        """Tensor inner product: `a | b` (use .val/.grad for FieldRef)."""
        if isinstance(other, FieldRef):
            raise TypeError("FieldRef | Expr is not supported; use .val/.grad on the FieldRef.")
        return Expr("inner", _as_expr(other), self)

    def __pow__(self, power, modulo=None):
        """Power: `a ** p` (no modulo support)."""
        if modulo is not None:
            raise ValueError("modulo is not supported for Expr exponentiation.")
        return Expr("pow", self, _as_expr(power))

    def __rpow__(self, other):
        """Right-power: `2 ** expr` (the scalar becomes the base)."""
        return Expr("pow", _as_expr(other), self)

    def __neg__(self):
        """Unary negation: `-expr`."""
        return Expr("neg", self)

    @property
    def T(self):
        """Transpose view: `expr.T`."""
        return Expr("transpose", self)

    def __repr__(self) -> str:
        return self.pretty()

    def pretty(self) -> str:
        """Return a compact, Python-like rendering of the expression tree."""
        return _pretty_expr(self)
|
|
301
|
+
|
|
67
302
|
|
|
68
|
-
@dataclass(frozen=True)
|
|
69
|
-
class FieldRef
|
|
303
|
+
@dataclass(frozen=True, slots=True)
|
|
304
|
+
class FieldRef:
|
|
70
305
|
"""Symbolic reference to trial/test/unknown field, optionally by name."""
|
|
71
306
|
|
|
72
307
|
role: str
|
|
73
308
|
name: str | None = None
|
|
74
309
|
|
|
75
|
-
def __init__(self, role: str, name: str | None = None):
|
|
76
|
-
object.__setattr__(self, "role", role)
|
|
77
|
-
object.__setattr__(self, "name", name)
|
|
78
|
-
super().__init__("field", role, name)
|
|
79
|
-
|
|
80
310
|
@property
|
|
81
311
|
def val(self):
|
|
82
312
|
return Expr("value", self)
|
|
@@ -91,12 +321,18 @@ class FieldRef(Expr):
|
|
|
91
321
|
|
|
92
322
|
def __mul__(self, other):
|
|
93
323
|
if isinstance(other, FieldRef):
|
|
94
|
-
|
|
324
|
+
raise TypeError(
|
|
325
|
+
"FieldRef * FieldRef is ambiguous; use outer(v, u) (test, trial), "
|
|
326
|
+
"action(v, s), or dot(v, q)."
|
|
327
|
+
)
|
|
95
328
|
return Expr("mul", Expr("value", self), _as_expr(other))
|
|
96
329
|
|
|
97
330
|
def __rmul__(self, other):
|
|
98
331
|
if isinstance(other, FieldRef):
|
|
99
|
-
|
|
332
|
+
raise TypeError(
|
|
333
|
+
"FieldRef * FieldRef is ambiguous; use outer(v, u) (test, trial), "
|
|
334
|
+
"action(v, s), or dot(v, q)."
|
|
335
|
+
)
|
|
100
336
|
return Expr("mul", _as_expr(other), Expr("value", self))
|
|
101
337
|
|
|
102
338
|
def __add__(self, other):
|
|
@@ -113,22 +349,23 @@ class FieldRef(Expr):
|
|
|
113
349
|
|
|
114
350
|
def __or__(self, other):
|
|
115
351
|
if isinstance(other, FieldRef):
|
|
116
|
-
|
|
117
|
-
|
|
352
|
+
raise TypeError(
|
|
353
|
+
"FieldRef | FieldRef is not supported; use outer(test, trial) for basis kernels."
|
|
354
|
+
)
|
|
355
|
+
return Expr("dot", self, _as_expr(other))
|
|
118
356
|
|
|
119
357
|
def __ror__(self, other):
|
|
120
358
|
if isinstance(other, FieldRef):
|
|
121
|
-
|
|
122
|
-
|
|
359
|
+
raise TypeError(
|
|
360
|
+
"FieldRef | FieldRef is not supported; use outer(test, trial) for basis kernels."
|
|
361
|
+
)
|
|
362
|
+
return Expr("dot", _as_expr(other), self)
|
|
123
363
|
|
|
124
364
|
|
|
125
|
-
@dataclass(frozen=True)
|
|
126
|
-
class ParamRef
|
|
365
|
+
@dataclass(frozen=True, slots=True)
|
|
366
|
+
class ParamRef:
|
|
127
367
|
"""Symbolic reference to params passed into the kernel."""
|
|
128
368
|
|
|
129
|
-
def __init__(self):
|
|
130
|
-
super().__init__("param")
|
|
131
|
-
|
|
132
369
|
def __getattr__(self, name: str):
|
|
133
370
|
return Expr("getattr", self, name)
|
|
134
371
|
|
|
@@ -179,13 +416,11 @@ def param_ref() -> ParamRef:
|
|
|
179
416
|
return ParamRef()
|
|
180
417
|
|
|
181
418
|
|
|
182
|
-
def
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
def _eval_field(obj: Any, ctx, params):
|
|
419
|
+
def _eval_field(
|
|
420
|
+
obj: Any,
|
|
421
|
+
ctx: VolumeContext | SurfaceContext,
|
|
422
|
+
params: ParamsLike,
|
|
423
|
+
) -> FormFieldLike:
|
|
189
424
|
if isinstance(obj, FieldRef):
|
|
190
425
|
if obj.name is not None:
|
|
191
426
|
mixed_fields = getattr(ctx, "fields", None)
|
|
@@ -226,25 +461,23 @@ def _eval_field(obj: Any, ctx, params):
|
|
|
226
461
|
if obj.role == "unknown":
|
|
227
462
|
return getattr(ctx, "unknown", ctx.trial)
|
|
228
463
|
raise ValueError(f"Unknown field role: {obj.role}")
|
|
229
|
-
if isinstance(obj, Expr):
|
|
230
|
-
val = obj.eval(ctx, params)
|
|
231
|
-
if hasattr(val, "N"):
|
|
232
|
-
return val
|
|
233
464
|
raise TypeError("Expected a field reference for this operator.")
|
|
234
465
|
|
|
235
466
|
|
|
236
|
-
def _eval_value(obj: Any, ctx, params, u_elem=None):
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
467
|
+
# def _eval_value(obj: Any, ctx, params, u_elem=None):
|
|
468
|
+
# if isinstance(obj, FieldRef):
|
|
469
|
+
# field = _eval_field(obj, ctx, params)
|
|
470
|
+
# if obj.role == "unknown":
|
|
471
|
+
# return _eval_unknown_value(obj, field, u_elem)
|
|
472
|
+
# return field.N
|
|
473
|
+
# if isinstance(obj, ParamRef):
|
|
474
|
+
# return params
|
|
475
|
+
# if isinstance(obj, Expr):
|
|
476
|
+
# return obj.eval(ctx, params, u_elem=u_elem)
|
|
477
|
+
# return obj
|
|
245
478
|
|
|
246
479
|
|
|
247
|
-
def _extract_unknown_elem(field_ref: FieldRef, u_elem):
|
|
480
|
+
def _extract_unknown_elem(field_ref: FieldRef, u_elem: UElement):
|
|
248
481
|
if u_elem is None:
|
|
249
482
|
raise ValueError("u_elem is required to evaluate unknown field value.")
|
|
250
483
|
if isinstance(u_elem, dict):
|
|
@@ -255,7 +488,17 @@ def _extract_unknown_elem(field_ref: FieldRef, u_elem):
|
|
|
255
488
|
return u_elem
|
|
256
489
|
|
|
257
490
|
|
|
258
|
-
def
|
|
491
|
+
def _basis_outer(test: FieldRef, trial: FieldRef, ctx, params):
    """Outer product of the test and trial ``N`` arrays.

    Both fields must be scalar: ``value_dim`` is 1, which is also the default
    when the attribute is missing. The result is
    ``einsum("qi,qj->qij", N_test, N_trial)``; the shared leading axis ``q`` is
    presumably the quadrature-point axis (TODO confirm).

    Raises:
        ValueError: if either field is vector/tensor valued.
    """
    fields = (_eval_field(test, ctx, params), _eval_field(trial, ctx, params))
    if any(getattr(f, "value_dim", 1) != 1 for f in fields):
        raise ValueError(
            "inner/outer is only defined for scalar fields; use dot/action/einsum for vector/tensor cases."
        )
    test_field, trial_field = fields
    return jnp.einsum("qi,qj->qij", test_field.N, trial_field.N)
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
def _eval_unknown_value(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
|
|
259
502
|
u_local = _extract_unknown_elem(field_ref, u_elem)
|
|
260
503
|
value_dim = int(getattr(field, "value_dim", 1))
|
|
261
504
|
if value_dim == 1:
|
|
@@ -264,7 +507,7 @@ def _eval_unknown_value(field_ref: FieldRef, field, u_elem):
|
|
|
264
507
|
return jnp.einsum("qa,ai->qi", field.N, u_nodes)
|
|
265
508
|
|
|
266
509
|
|
|
267
|
-
def _eval_unknown_grad(field_ref: FieldRef, field, u_elem):
|
|
510
|
+
def _eval_unknown_grad(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
|
|
268
511
|
u_local = _extract_unknown_elem(field_ref, u_elem)
|
|
269
512
|
if u_local is None:
|
|
270
513
|
raise ValueError("u_elem is required to evaluate unknown field gradient.")
|
|
@@ -285,6 +528,15 @@ def sym_grad(field) -> Expr:
|
|
|
285
528
|
return Expr("sym_grad", _as_expr(field))
|
|
286
529
|
|
|
287
530
|
|
|
531
|
+
def outer(a, b) -> Expr:
    """Outer product of scalar fields: `outer(v, u)` (test, trial).

    Raises:
        TypeError: unless ``a`` is the test FieldRef and ``b`` the trial FieldRef.
    """
    both_refs = isinstance(a, FieldRef) and isinstance(b, FieldRef)
    if not both_refs:
        raise TypeError("outer expects FieldRef operands.")
    if (a.role, b.role) != ("test", "trial"):
        raise TypeError("outer expects outer(test, trial).")
    return Expr("outer", a, b)
|
|
538
|
+
|
|
539
|
+
|
|
288
540
|
def dot(a, b) -> Expr:
    """Dot product or vector load helper.

    Both operands are normalized with ``_as_expr`` (left first) and stored in
    a ``dot`` node.
    """
    lhs = _as_expr(a)
    rhs = _as_expr(b)
    return Expr("dot", lhs, rhs)
|
|
@@ -303,7 +555,7 @@ def ddot(a, b, c=None) -> Expr:
|
|
|
303
555
|
|
|
304
556
|
|
|
305
557
|
def inner(a, b) -> Expr:
    """Inner product over the last axis (tensor-level).

    Builds the same ``inner`` node as ``a | b`` on Expr operands.
    """
    lhs = _as_expr(a)
    rhs = _as_expr(b)
    return Expr("inner", lhs, rhs)
|
|
308
560
|
|
|
309
561
|
|
|
@@ -363,7 +615,12 @@ def transpose_last2(a) -> Expr:
|
|
|
363
615
|
|
|
364
616
|
|
|
365
617
|
def matmul(a, b) -> Expr:
    """FEM-specific batched contraction (same semantics as `@`).

    Builds a ``matmul`` node; use ``matmul_std`` for an ordinary
    ``jnp.matmul`` product.
    """
    lhs = _as_expr(a)
    rhs = _as_expr(b)
    return Expr("matmul", lhs, rhs)
|
|
620
|
+
|
|
621
|
+
|
|
622
|
+
def matmul_std(a, b) -> Expr:
    """Standard matrix product (`jnp.matmul` semantics).

    Unlike ``@``/``matmul``, no FEM-specific contraction is applied.
    """
    lhs = _as_expr(a)
    rhs = _as_expr(b)
    return Expr("matmul_std", lhs, rhs)
|
|
368
625
|
|
|
369
626
|
|
|
@@ -374,9 +631,22 @@ def einsum(subscripts: str, *args) -> Expr:
|
|
|
374
631
|
|
|
375
632
|
def _call_user(fn, *args, params):
|
|
376
633
|
try:
|
|
634
|
+
sig = inspect.signature(fn)
|
|
635
|
+
except (TypeError, ValueError):
|
|
636
|
+
return fn(*args, params)
|
|
637
|
+
|
|
638
|
+
params_list = list(sig.parameters.values())
|
|
639
|
+
if any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in params_list):
|
|
377
640
|
return fn(*args, params)
|
|
378
|
-
|
|
379
|
-
|
|
641
|
+
positional = [
|
|
642
|
+
p
|
|
643
|
+
for p in params_list
|
|
644
|
+
if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
|
|
645
|
+
]
|
|
646
|
+
max_positional = len(positional)
|
|
647
|
+
if len(args) + 1 <= max_positional:
|
|
648
|
+
return fn(*args, params)
|
|
649
|
+
return fn(*args)
|
|
380
650
|
|
|
381
651
|
|
|
382
652
|
def compile_bilinear(fn):
|
|
@@ -387,19 +657,26 @@ def compile_bilinear(fn):
|
|
|
387
657
|
u = trial_ref()
|
|
388
658
|
v = test_ref()
|
|
389
659
|
p = param_ref()
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
660
|
+
expr = _call_user(fn, u, v, params=p)
|
|
661
|
+
expr = _as_expr(expr)
|
|
662
|
+
if not isinstance(expr, Expr):
|
|
663
|
+
raise TypeError("Bilinear form must return an Expr.")
|
|
394
664
|
|
|
395
|
-
|
|
396
|
-
|
|
665
|
+
volume_count = _count_op(expr, "volume_measure")
|
|
666
|
+
surface_count = _count_op(expr, "surface_measure")
|
|
667
|
+
if volume_count == 0:
|
|
397
668
|
raise ValueError("Volume bilinear form must include dOmega().")
|
|
669
|
+
if volume_count > 1:
|
|
670
|
+
raise ValueError("Volume bilinear form must include dOmega() exactly once.")
|
|
671
|
+
if surface_count > 0:
|
|
672
|
+
raise ValueError("Volume bilinear form must not include ds().")
|
|
673
|
+
|
|
674
|
+
plan = make_eval_plan(expr)
|
|
398
675
|
|
|
399
676
|
def _form(ctx, params):
|
|
400
|
-
return
|
|
677
|
+
return eval_with_plan(plan, ctx, params)
|
|
401
678
|
|
|
402
|
-
_form._includes_measure =
|
|
679
|
+
_form._includes_measure = True
|
|
403
680
|
return _form
|
|
404
681
|
|
|
405
682
|
|
|
@@ -410,19 +687,26 @@ def compile_linear(fn):
|
|
|
410
687
|
else:
|
|
411
688
|
v = test_ref()
|
|
412
689
|
p = param_ref()
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
690
|
+
expr = _call_user(fn, v, params=p)
|
|
691
|
+
expr = _as_expr(expr)
|
|
692
|
+
if not isinstance(expr, Expr):
|
|
693
|
+
raise TypeError("Linear form must return an Expr.")
|
|
417
694
|
|
|
418
|
-
|
|
419
|
-
|
|
695
|
+
volume_count = _count_op(expr, "volume_measure")
|
|
696
|
+
surface_count = _count_op(expr, "surface_measure")
|
|
697
|
+
if volume_count == 0:
|
|
420
698
|
raise ValueError("Volume linear form must include dOmega().")
|
|
699
|
+
if volume_count > 1:
|
|
700
|
+
raise ValueError("Volume linear form must include dOmega() exactly once.")
|
|
701
|
+
if surface_count > 0:
|
|
702
|
+
raise ValueError("Volume linear form must not include ds().")
|
|
703
|
+
|
|
704
|
+
plan = make_eval_plan(expr)
|
|
421
705
|
|
|
422
706
|
def _form(ctx, params):
|
|
423
|
-
return
|
|
707
|
+
return eval_with_plan(plan, ctx, params)
|
|
424
708
|
|
|
425
|
-
_form._includes_measure =
|
|
709
|
+
_form._includes_measure = True
|
|
426
710
|
return _form
|
|
427
711
|
|
|
428
712
|
|
|
@@ -434,6 +718,340 @@ def _expr_contains(expr: Expr, op: str) -> bool:
|
|
|
434
718
|
return any(_expr_contains(arg, op) for arg in expr.args if isinstance(arg, Expr))
|
|
435
719
|
|
|
436
720
|
|
|
721
|
+
def _count_op(expr: Expr, op: str) -> int:
    """Count the nodes whose op equals ``op`` in the tree rooted at ``expr``.

    Non-Expr input counts as 0. Only Expr children are descended into, and a
    subtree referenced twice is counted twice. compile_bilinear/compile_linear
    use this to require exactly one ``volume_measure`` and no ``surface_measure``.
    """
    if not isinstance(expr, Expr):
        return 0
    own = int(expr.op == op)
    return own + sum(_count_op(child, op) for child in expr.args if isinstance(child, Expr))
|
|
729
|
+
|
|
730
|
+
|
|
731
|
+
@dataclass(frozen=True, slots=True)
|
|
732
|
+
class EvalPlan:
|
|
733
|
+
expr: Expr
|
|
734
|
+
nodes: tuple[Expr, ...]
|
|
735
|
+
index: dict[int, int]
|
|
736
|
+
|
|
737
|
+
|
|
738
|
+
def _validate_eval_plan(nodes: tuple[Expr, ...]) -> None:
    """Reject malformed uses of bare FieldRef operands before evaluation.

    A bare ``FieldRef`` is only accepted by a few ops:
      * ``value``/``grad``/``sym_grad``: exactly one argument, a FieldRef;
      * ``dot``/``sdot``: two arguments. If a FieldRef appears, it must be the
        first argument, and the second must not be a FieldRef;
      * ``action``/``gaction``: a FieldRef first and a non-FieldRef second;
      * ``outer``: ``outer(test, trial)``.
    Every other op must not receive a FieldRef directly.

    Raises:
        TypeError: on any violation. The exception is ``action`` with a FieldRef
            second argument, which raises ValueError.
    """
    unary_field_ops = {"value", "grad", "sym_grad"}
    dot_ops = {"dot", "sdot"}
    action_ops = {"action", "gaction"}
    accepts_fieldref = unary_field_ops | dot_ops | action_ops | {"outer"}

    for node in nodes:
        op, args = node.op, node.args
        is_ref = [isinstance(arg, FieldRef) for arg in args]

        if op not in accepts_fieldref:
            if any(is_ref):
                raise TypeError(f"{op} cannot take FieldRef directly; wrap with .val/.grad/.sym_grad.")
            continue

        if op in unary_field_ops:
            if len(args) != 1 or not is_ref[0]:
                raise TypeError(f"{op} expects FieldRef.")
        elif op in dot_ops:
            if len(args) != 2:
                raise TypeError(f"{op} expects two arguments.")
            if any(is_ref) and not is_ref[0]:
                raise TypeError(f"{op} expects FieldRef as the first argument.")
            if is_ref[1]:
                raise TypeError(f"{op} expects an expression for the second argument; use .val/.grad.")
        elif op in action_ops:
            if len(args) != 2 or not is_ref[0]:
                raise TypeError(f"{op} expects FieldRef as the first argument.")
            if is_ref[1]:
                if op == "action":
                    raise ValueError("action expects a scalar expression; use u.val for unknowns.")
                raise TypeError("gaction expects an expression for the second argument; use .grad.")
        else:  # "outer"
            if len(args) != 2 or not all(is_ref):
                raise TypeError("outer expects two FieldRef operands.")
            if args[0].role != "test" or args[1].role != "trial":
                raise TypeError("outer expects outer(test, trial).")
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
def make_eval_plan(expr: Expr) -> EvalPlan:
    """Build an EvalPlan for ``expr`` after validating FieldRef usage.

    Nodes are listed in postorder, so each node's Expr children come before it
    and the root is last. A subtree *object* reused in several places, e.g.
    ``a = u.val; a * a``, is listed only once. ``eval_with_plan`` therefore
    evaluates it once instead of once per occurrence; previously each
    occurrence was recomputed even though only the first result was ever read.

    Raises:
        TypeError / ValueError: propagated from ``_validate_eval_plan``.
    """
    ordered: list[Expr] = []
    index: dict[int, int] = {}
    for node in expr.postorder_expr():
        key = id(node)
        if key in index:
            continue  # shared subtree already scheduled
        index[key] = len(ordered)
        ordered.append(node)
    nodes = tuple(ordered)
    _validate_eval_plan(nodes)
    # plan.expr keeps every node alive, so the id() keys remain valid.
    return EvalPlan(expr=expr, nodes=nodes, index=index)
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
def eval_with_plan(
|
|
790
|
+
plan: EvalPlan,
|
|
791
|
+
ctx: VolumeContext | SurfaceContext,
|
|
792
|
+
params: ParamsLike,
|
|
793
|
+
u_elem: UElement | None = None,
|
|
794
|
+
):
|
|
795
|
+
nodes = plan.nodes
|
|
796
|
+
index = plan.index
|
|
797
|
+
vals: list[Any] = [None] * len(nodes)
|
|
798
|
+
|
|
799
|
+
def get(obj):
|
|
800
|
+
if isinstance(obj, Expr):
|
|
801
|
+
return vals[index[id(obj)]]
|
|
802
|
+
if isinstance(obj, FieldRef):
|
|
803
|
+
raise TypeError(
|
|
804
|
+
"FieldRef must be wrapped with .val/.grad/.sym_grad or used as the first arg of dot/action."
|
|
805
|
+
)
|
|
806
|
+
if isinstance(obj, ParamRef):
|
|
807
|
+
return params
|
|
808
|
+
return obj
|
|
809
|
+
|
|
810
|
+
for i, node in enumerate(nodes):
|
|
811
|
+
op = node.op
|
|
812
|
+
args = node.args
|
|
813
|
+
|
|
814
|
+
if op == "lit":
|
|
815
|
+
vals[i] = args[0]
|
|
816
|
+
continue
|
|
817
|
+
if op == "getattr":
|
|
818
|
+
base = get(args[0])
|
|
819
|
+
name = args[1]
|
|
820
|
+
if isinstance(base, dict):
|
|
821
|
+
vals[i] = base[name]
|
|
822
|
+
else:
|
|
823
|
+
vals[i] = getattr(base, name)
|
|
824
|
+
continue
|
|
825
|
+
if op == "value":
|
|
826
|
+
ref = args[0]
|
|
827
|
+
assert isinstance(ref, FieldRef)
|
|
828
|
+
field = _eval_field(ref, ctx, params)
|
|
829
|
+
if ref.role == "unknown":
|
|
830
|
+
vals[i] = _eval_unknown_value(ref, field, u_elem)
|
|
831
|
+
else:
|
|
832
|
+
vals[i] = field.N
|
|
833
|
+
continue
|
|
834
|
+
if op == "grad":
|
|
835
|
+
ref = args[0]
|
|
836
|
+
assert isinstance(ref, FieldRef)
|
|
837
|
+
field = _eval_field(ref, ctx, params)
|
|
838
|
+
if ref.role == "unknown":
|
|
839
|
+
vals[i] = _eval_unknown_grad(ref, field, u_elem)
|
|
840
|
+
else:
|
|
841
|
+
vals[i] = field.gradN
|
|
842
|
+
continue
|
|
843
|
+
if op == "pow":
|
|
844
|
+
base = get(args[0])
|
|
845
|
+
exp = get(args[1])
|
|
846
|
+
vals[i] = base**exp
|
|
847
|
+
continue
|
|
848
|
+
if op == "eye":
|
|
849
|
+
vals[i] = jnp.eye(int(args[0]))
|
|
850
|
+
continue
|
|
851
|
+
if op == "det":
|
|
852
|
+
vals[i] = jnp.linalg.det(get(args[0]))
|
|
853
|
+
continue
|
|
854
|
+
if op == "inv":
|
|
855
|
+
vals[i] = jnp.linalg.inv(get(args[0]))
|
|
856
|
+
continue
|
|
857
|
+
if op == "transpose":
|
|
858
|
+
vals[i] = jnp.swapaxes(get(args[0]), -1, -2)
|
|
859
|
+
continue
|
|
860
|
+
if op == "log":
|
|
861
|
+
vals[i] = jnp.log(get(args[0]))
|
|
862
|
+
continue
|
|
863
|
+
if op == "surface_normal":
|
|
864
|
+
normal = getattr(ctx, "normal", None)
|
|
865
|
+
if normal is None:
|
|
866
|
+
raise ValueError("surface normal is not available in context")
|
|
867
|
+
vals[i] = normal
|
|
868
|
+
continue
|
|
869
|
+
if op == "surface_measure":
|
|
870
|
+
if not hasattr(ctx, "w") or not hasattr(ctx, "detJ"):
|
|
871
|
+
raise TypeError("surface measure requires SurfaceContext.")
|
|
872
|
+
vals[i] = ctx.w * ctx.detJ
|
|
873
|
+
continue
|
|
874
|
+
if op == "volume_measure":
|
|
875
|
+
if not hasattr(ctx, "w") or not hasattr(ctx, "test"):
|
|
876
|
+
raise TypeError("volume measure requires VolumeContext.")
|
|
877
|
+
vals[i] = ctx.w * ctx.test.detJ
|
|
878
|
+
continue
|
|
879
|
+
if op == "sym_grad":
|
|
880
|
+
ref = args[0]
|
|
881
|
+
assert isinstance(ref, FieldRef)
|
|
882
|
+
field = _eval_field(ref, ctx, params)
|
|
883
|
+
if ref.role == "unknown":
|
|
884
|
+
if u_elem is None:
|
|
885
|
+
raise ValueError("u_elem is required to evaluate unknown sym_grad.")
|
|
886
|
+
u_local = _extract_unknown_elem(ref, u_elem)
|
|
887
|
+
vals[i] = _ops.sym_grad_u(field, u_local)
|
|
888
|
+
else:
|
|
889
|
+
vals[i] = _ops.sym_grad(field)
|
|
890
|
+
continue
|
|
891
|
+
if op == "outer":
|
|
892
|
+
a, b = args
|
|
893
|
+
if not isinstance(a, FieldRef) or not isinstance(b, FieldRef):
|
|
894
|
+
raise TypeError("outer expects FieldRef operands.")
|
|
895
|
+
test, trial = a, b
|
|
896
|
+
vals[i] = _basis_outer(test, trial, ctx, params)
|
|
897
|
+
continue
|
|
898
|
+
if op == "add":
|
|
899
|
+
vals[i] = get(args[0]) + get(args[1])
|
|
900
|
+
continue
|
|
901
|
+
if op == "sub":
|
|
902
|
+
vals[i] = get(args[0]) - get(args[1])
|
|
903
|
+
continue
|
|
904
|
+
if op == "mul":
|
|
905
|
+
a = get(args[0])
|
|
906
|
+
b = get(args[1])
|
|
907
|
+
if hasattr(a, "ndim") and hasattr(b, "ndim"):
|
|
908
|
+
if a.ndim == 1 and b.ndim == 2 and a.shape[0] == b.shape[0]:
|
|
909
|
+
a = a[:, None]
|
|
910
|
+
elif b.ndim == 1 and a.ndim == 2 and b.shape[0] == a.shape[0]:
|
|
911
|
+
b = b[:, None]
|
|
912
|
+
elif a.ndim >= 2 and b.ndim == 1 and a.shape[0] == b.shape[0]:
|
|
913
|
+
b = b.reshape((b.shape[0],) + (1,) * (a.ndim - 1))
|
|
914
|
+
elif b.ndim >= 2 and a.ndim == 1 and b.shape[0] == a.shape[0]:
|
|
915
|
+
a = a.reshape((a.shape[0],) + (1,) * (b.ndim - 1))
|
|
916
|
+
vals[i] = a * b
|
|
917
|
+
continue
|
|
918
|
+
if op == "matmul":
|
|
919
|
+
a = get(args[0])
|
|
920
|
+
b = get(args[1])
|
|
921
|
+
if (
|
|
922
|
+
hasattr(a, "ndim")
|
|
923
|
+
and hasattr(b, "ndim")
|
|
924
|
+
and a.ndim == 3
|
|
925
|
+
and b.ndim == 3
|
|
926
|
+
and a.shape[0] == b.shape[0]
|
|
927
|
+
and a.shape[-1] == b.shape[-1]
|
|
928
|
+
):
|
|
929
|
+
vals[i] = jnp.einsum("qia,qja->qij", a, b)
|
|
930
|
+
else:
|
|
931
|
+
raise TypeError(
|
|
932
|
+
"Expr '@' (matmul) is FEM-specific; use matmul_std(a, b) for standard matmul."
|
|
933
|
+
)
|
|
934
|
+
continue
|
|
935
|
+
if op == "matmul_std":
|
|
936
|
+
a = get(args[0])
|
|
937
|
+
b = get(args[1])
|
|
938
|
+
vals[i] = jnp.matmul(a, b)
|
|
939
|
+
continue
|
|
940
|
+
if op == "neg":
|
|
941
|
+
vals[i] = -get(args[0])
|
|
942
|
+
continue
|
|
943
|
+
if op == "dot":
|
|
944
|
+
ref = args[0]
|
|
945
|
+
if isinstance(ref, FieldRef):
|
|
946
|
+
vals[i] = _ops.dot(_eval_field(ref, ctx, params), get(args[1]))
|
|
947
|
+
else:
|
|
948
|
+
a = get(args[0])
|
|
949
|
+
b = get(args[1])
|
|
950
|
+
if (
|
|
951
|
+
hasattr(a, "ndim")
|
|
952
|
+
and hasattr(b, "ndim")
|
|
953
|
+
and a.ndim == 3
|
|
954
|
+
and b.ndim == 3
|
|
955
|
+
and a.shape[-1] == b.shape[-1]
|
|
956
|
+
):
|
|
957
|
+
vals[i] = jnp.einsum("qia,qja->qij", a, b)
|
|
958
|
+
else:
|
|
959
|
+
vals[i] = jnp.matmul(a, b)
|
|
960
|
+
continue
|
|
961
|
+
if op == "sdot":
|
|
962
|
+
ref = args[0]
|
|
963
|
+
if isinstance(ref, FieldRef):
|
|
964
|
+
vals[i] = _ops.dot(_eval_field(ref, ctx, params), get(args[1]))
|
|
965
|
+
else:
|
|
966
|
+
a = get(args[0])
|
|
967
|
+
b = get(args[1])
|
|
968
|
+
if (
|
|
969
|
+
hasattr(a, "ndim")
|
|
970
|
+
and hasattr(b, "ndim")
|
|
971
|
+
and a.ndim == 3
|
|
972
|
+
and b.ndim == 3
|
|
973
|
+
and a.shape[-1] == b.shape[-1]
|
|
974
|
+
):
|
|
975
|
+
vals[i] = jnp.einsum("qia,qja->qij", a, b)
|
|
976
|
+
else:
|
|
977
|
+
vals[i] = jnp.matmul(a, b)
|
|
978
|
+
continue
|
|
979
|
+
if op == "ddot":
|
|
980
|
+
if len(args) == 2:
|
|
981
|
+
a = get(args[0])
|
|
982
|
+
b = get(args[1])
|
|
983
|
+
if (
|
|
984
|
+
hasattr(a, "ndim")
|
|
985
|
+
and hasattr(b, "ndim")
|
|
986
|
+
and a.ndim == 3
|
|
987
|
+
and b.ndim == 3
|
|
988
|
+
and a.shape[0] == b.shape[0]
|
|
989
|
+
and a.shape[1] == b.shape[1]
|
|
990
|
+
):
|
|
991
|
+
vals[i] = jnp.einsum("qik,qim->qkm", a, b)
|
|
992
|
+
else:
|
|
993
|
+
vals[i] = _ops.ddot(a, b)
|
|
994
|
+
else:
|
|
995
|
+
vals[i] = _ops.ddot(get(args[0]), get(args[1]), get(args[2]))
|
|
996
|
+
continue
|
|
997
|
+
if op == "inner":
|
|
998
|
+
a = get(args[0])
|
|
999
|
+
b = get(args[1])
|
|
1000
|
+
vals[i] = jnp.einsum("...i,...i->...", a, b)
|
|
1001
|
+
continue
|
|
1002
|
+
if op == "action":
|
|
1003
|
+
ref = args[0]
|
|
1004
|
+
assert isinstance(ref, FieldRef)
|
|
1005
|
+
if isinstance(args[1], FieldRef):
|
|
1006
|
+
raise ValueError("action expects a scalar expression; use u.val for unknowns.")
|
|
1007
|
+
v_field = _eval_field(ref, ctx, params)
|
|
1008
|
+
s = get(args[1])
|
|
1009
|
+
value_dim = int(getattr(v_field, "value_dim", 1))
|
|
1010
|
+
# action maps a test field with a scalar/vector expression into nodal space.
|
|
1011
|
+
if value_dim == 1:
|
|
1012
|
+
if v_field.N.ndim != 2:
|
|
1013
|
+
raise ValueError("action expects scalar test field with N shape (q, ndofs).")
|
|
1014
|
+
if hasattr(s, "ndim") and s.ndim not in (0, 1):
|
|
1015
|
+
raise ValueError("action expects scalar s with shape (q,) or scalar.")
|
|
1016
|
+
vals[i] = v_field.N * s
|
|
1017
|
+
else:
|
|
1018
|
+
if hasattr(s, "ndim") and s.ndim not in (1, 2):
|
|
1019
|
+
raise ValueError("action expects vector s with shape (q, dim) or (dim,).")
|
|
1020
|
+
vals[i] = _ops.dot(v_field, s)
|
|
1021
|
+
continue
|
|
1022
|
+
if op == "gaction":
|
|
1023
|
+
ref = args[0]
|
|
1024
|
+
assert isinstance(ref, FieldRef)
|
|
1025
|
+
v_field = _eval_field(ref, ctx, params)
|
|
1026
|
+
q = get(args[1])
|
|
1027
|
+
# gaction maps a flux-like expression to nodal space via test gradients.
|
|
1028
|
+
if v_field.gradN.ndim != 3:
|
|
1029
|
+
raise ValueError("gaction expects test gradient with shape (q, ndofs, dim).")
|
|
1030
|
+
if not hasattr(q, "ndim"):
|
|
1031
|
+
raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
|
|
1032
|
+
if q.ndim == 2:
|
|
1033
|
+
vals[i] = jnp.einsum("qaj,qj->qa", v_field.gradN, q)
|
|
1034
|
+
elif q.ndim == 3:
|
|
1035
|
+
if int(getattr(v_field, "value_dim", 1)) == 1:
|
|
1036
|
+
raise ValueError("gaction tensor flux requires vector test field.")
|
|
1037
|
+
vals[i] = jnp.einsum("qij,qaj->qai", q, v_field.gradN).reshape(q.shape[0], -1)
|
|
1038
|
+
else:
|
|
1039
|
+
raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
|
|
1040
|
+
continue
|
|
1041
|
+
if op == "transpose_last2":
|
|
1042
|
+
vals[i] = _ops.transpose_last2(get(args[0]))
|
|
1043
|
+
continue
|
|
1044
|
+
if op == "einsum":
|
|
1045
|
+
subscripts = args[0]
|
|
1046
|
+
operands = [get(arg) for arg in args[1:]]
|
|
1047
|
+
vals[i] = jnp.einsum(subscripts, *operands)
|
|
1048
|
+
continue
|
|
1049
|
+
|
|
1050
|
+
raise ValueError(f"Unknown Expr op: {op}")
|
|
1051
|
+
|
|
1052
|
+
return vals[index[id(plan.expr)]]
|
|
1053
|
+
|
|
1054
|
+
|
|
437
1055
|
def compile_surface_linear(fn):
|
|
438
1056
|
"""get_compiled a surface linear form into a kernel (ctx, params) -> ndarray."""
|
|
439
1057
|
if isinstance(fn, Expr):
|
|
@@ -441,26 +1059,27 @@ def compile_surface_linear(fn):
|
|
|
441
1059
|
else:
|
|
442
1060
|
v = test_ref()
|
|
443
1061
|
p = param_ref()
|
|
444
|
-
expr =
|
|
445
|
-
try:
|
|
446
|
-
expr = fn(v, p)
|
|
447
|
-
except TypeError:
|
|
448
|
-
try:
|
|
449
|
-
expr = fn(v)
|
|
450
|
-
except TypeError:
|
|
451
|
-
expr = None
|
|
1062
|
+
expr = _call_user(fn, v, params=p)
|
|
452
1063
|
|
|
1064
|
+
expr = _as_expr(expr)
|
|
453
1065
|
if not isinstance(expr, Expr):
|
|
454
1066
|
raise ValueError("Surface linear form must return an Expr; use ds() in the expression.")
|
|
455
1067
|
|
|
456
|
-
|
|
457
|
-
|
|
1068
|
+
surface_count = _count_op(expr, "surface_measure")
|
|
1069
|
+
volume_count = _count_op(expr, "volume_measure")
|
|
1070
|
+
if surface_count == 0:
|
|
458
1071
|
raise ValueError("Surface linear form must include ds().")
|
|
1072
|
+
if surface_count > 1:
|
|
1073
|
+
raise ValueError("Surface linear form must include ds() exactly once.")
|
|
1074
|
+
if volume_count > 0:
|
|
1075
|
+
raise ValueError("Surface linear form must not include dOmega().")
|
|
1076
|
+
|
|
1077
|
+
plan = make_eval_plan(expr)
|
|
459
1078
|
|
|
460
1079
|
def _form(ctx, params):
|
|
461
|
-
return
|
|
1080
|
+
return eval_with_plan(plan, ctx, params)
|
|
462
1081
|
|
|
463
|
-
_form._includes_measure =
|
|
1082
|
+
_form._includes_measure = True # type: ignore[attr-defined]
|
|
464
1083
|
return _form
|
|
465
1084
|
|
|
466
1085
|
|
|
@@ -524,25 +1143,33 @@ def compile_residual(fn):
|
|
|
524
1143
|
v = test_ref()
|
|
525
1144
|
u = unknown_ref()
|
|
526
1145
|
p = param_ref()
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
1146
|
+
expr = _call_user(fn, v, u, params=p)
|
|
1147
|
+
expr = _as_expr(expr)
|
|
1148
|
+
if not isinstance(expr, Expr):
|
|
1149
|
+
raise TypeError("Residual form must return an Expr.")
|
|
531
1150
|
|
|
532
|
-
|
|
533
|
-
|
|
1151
|
+
volume_count = _count_op(expr, "volume_measure")
|
|
1152
|
+
surface_count = _count_op(expr, "surface_measure")
|
|
1153
|
+
if volume_count == 0:
|
|
534
1154
|
raise ValueError("Volume residual form must include dOmega().")
|
|
1155
|
+
if volume_count > 1:
|
|
1156
|
+
raise ValueError("Volume residual form must include dOmega() exactly once.")
|
|
1157
|
+
if surface_count > 0:
|
|
1158
|
+
raise ValueError("Volume residual form must not include ds().")
|
|
1159
|
+
|
|
1160
|
+
plan = make_eval_plan(expr)
|
|
535
1161
|
|
|
536
1162
|
def _form(ctx, u_elem, params):
|
|
537
|
-
return
|
|
1163
|
+
return eval_with_plan(plan, ctx, params, u_elem=u_elem)
|
|
538
1164
|
|
|
539
|
-
_form._includes_measure =
|
|
1165
|
+
_form._includes_measure = True
|
|
540
1166
|
return _form
|
|
541
1167
|
|
|
542
1168
|
|
|
543
1169
|
def compile_mixed_residual(residuals: dict[str, Callable]):
|
|
544
1170
|
"""get_compiled mixed residuals keyed by field name."""
|
|
545
1171
|
compiled = {}
|
|
1172
|
+
plans = {}
|
|
546
1173
|
includes_measure = {}
|
|
547
1174
|
for name, fn in residuals.items():
|
|
548
1175
|
if isinstance(fn, Expr):
|
|
@@ -551,17 +1178,24 @@ def compile_mixed_residual(residuals: dict[str, Callable]):
|
|
|
551
1178
|
v = test_ref(name)
|
|
552
1179
|
u = unknown_ref(name)
|
|
553
1180
|
p = param_ref()
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
compiled[name] =
|
|
559
|
-
|
|
560
|
-
|
|
1181
|
+
expr = _call_user(fn, v, u, params=p)
|
|
1182
|
+
expr = _as_expr(expr)
|
|
1183
|
+
if not isinstance(expr, Expr):
|
|
1184
|
+
raise TypeError(f"Mixed residual '{name}' must return an Expr.")
|
|
1185
|
+
compiled[name] = expr
|
|
1186
|
+
plans[name] = make_eval_plan(expr)
|
|
1187
|
+
volume_count = _count_op(compiled[name], "volume_measure")
|
|
1188
|
+
surface_count = _count_op(compiled[name], "surface_measure")
|
|
1189
|
+
includes_measure[name] = volume_count == 1
|
|
1190
|
+
if volume_count == 0:
|
|
561
1191
|
raise ValueError(f"Mixed residual '{name}' must include dOmega().")
|
|
1192
|
+
if volume_count > 1:
|
|
1193
|
+
raise ValueError(f"Mixed residual '{name}' must include dOmega() exactly once.")
|
|
1194
|
+
if surface_count > 0:
|
|
1195
|
+
raise ValueError(f"Mixed residual '{name}' must not include ds().")
|
|
562
1196
|
|
|
563
1197
|
def _form(ctx, u_elem, params):
|
|
564
|
-
return {name:
|
|
1198
|
+
return {name: eval_with_plan(plan, ctx, params, u_elem=u_elem) for name, plan in plans.items()}
|
|
565
1199
|
|
|
566
1200
|
_form._includes_measure = includes_measure
|
|
567
1201
|
return _form
|
|
@@ -579,220 +1213,14 @@ class MixedWeakForm:
|
|
|
579
1213
|
return compile_mixed_residual(self.residuals)
|
|
580
1214
|
|
|
581
1215
|
|
|
582
|
-
def _eval_expr(
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
if op == "getattr":
|
|
591
|
-
base = _eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
592
|
-
name = args[1]
|
|
593
|
-
if isinstance(base, dict):
|
|
594
|
-
return base[name]
|
|
595
|
-
return getattr(base, name)
|
|
596
|
-
if op == "field":
|
|
597
|
-
role, name = args
|
|
598
|
-
if name is not None:
|
|
599
|
-
if role == "trial" and getattr(ctx, "trial_fields", None) is not None:
|
|
600
|
-
if name in ctx.trial_fields:
|
|
601
|
-
return ctx.trial_fields[name]
|
|
602
|
-
if role == "test" and getattr(ctx, "test_fields", None) is not None:
|
|
603
|
-
if name in ctx.test_fields:
|
|
604
|
-
return ctx.test_fields[name]
|
|
605
|
-
if role == "unknown" and getattr(ctx, "unknown_fields", None) is not None:
|
|
606
|
-
if name in ctx.unknown_fields:
|
|
607
|
-
return ctx.unknown_fields[name]
|
|
608
|
-
fields = getattr(ctx, "fields", None)
|
|
609
|
-
if fields is not None and name in fields:
|
|
610
|
-
group = fields[name]
|
|
611
|
-
if isinstance(group, dict):
|
|
612
|
-
if role in group:
|
|
613
|
-
return group[role]
|
|
614
|
-
if "field" in group:
|
|
615
|
-
return group["field"]
|
|
616
|
-
return group
|
|
617
|
-
if role == "trial":
|
|
618
|
-
return ctx.trial
|
|
619
|
-
if role == "test":
|
|
620
|
-
return ctx.test
|
|
621
|
-
if role == "unknown":
|
|
622
|
-
return getattr(ctx, "unknown", ctx.trial)
|
|
623
|
-
raise ValueError(f"Unknown field role: {role}")
|
|
624
|
-
if op == "value":
|
|
625
|
-
field = _eval_field(args[0], ctx, params)
|
|
626
|
-
if isinstance(args[0], FieldRef) and args[0].role == "unknown":
|
|
627
|
-
return _eval_unknown_value(args[0], field, u_elem)
|
|
628
|
-
return field.N
|
|
629
|
-
if op == "grad":
|
|
630
|
-
field = _eval_field(args[0], ctx, params)
|
|
631
|
-
if isinstance(args[0], FieldRef) and args[0].role == "unknown":
|
|
632
|
-
return _eval_unknown_grad(args[0], field, u_elem)
|
|
633
|
-
return field.gradN
|
|
634
|
-
if op == "pow":
|
|
635
|
-
base = _eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
636
|
-
exp = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
637
|
-
return base**exp
|
|
638
|
-
if op == "eye":
|
|
639
|
-
return jnp.eye(int(args[0]))
|
|
640
|
-
if op == "det":
|
|
641
|
-
return jnp.linalg.det(_eval_value(args[0], ctx, params, u_elem=u_elem))
|
|
642
|
-
if op == "inv":
|
|
643
|
-
return jnp.linalg.inv(_eval_value(args[0], ctx, params, u_elem=u_elem))
|
|
644
|
-
if op == "transpose":
|
|
645
|
-
return jnp.swapaxes(_eval_value(args[0], ctx, params, u_elem=u_elem), -1, -2)
|
|
646
|
-
if op == "log":
|
|
647
|
-
return jnp.log(_eval_value(args[0], ctx, params, u_elem=u_elem))
|
|
648
|
-
if op == "surface_normal":
|
|
649
|
-
normal = getattr(ctx, "normal", None)
|
|
650
|
-
if normal is None:
|
|
651
|
-
raise ValueError("surface normal is not available in context")
|
|
652
|
-
return normal
|
|
653
|
-
if op == "surface_measure":
|
|
654
|
-
if not hasattr(ctx, "w") or not hasattr(ctx, "detJ"):
|
|
655
|
-
raise ValueError("surface measure requires surface context with w and detJ.")
|
|
656
|
-
return ctx.w * ctx.detJ
|
|
657
|
-
if op == "volume_measure":
|
|
658
|
-
if not hasattr(ctx, "w") or not hasattr(ctx, "test"):
|
|
659
|
-
raise ValueError("volume measure requires FormContext with w and test.detJ.")
|
|
660
|
-
return ctx.w * ctx.test.detJ
|
|
661
|
-
if op == "sym_grad":
|
|
662
|
-
field = _eval_field(args[0], ctx, params)
|
|
663
|
-
if isinstance(args[0], FieldRef) and args[0].role == "unknown":
|
|
664
|
-
if u_elem is None:
|
|
665
|
-
raise ValueError("u_elem is required to evaluate unknown sym_grad.")
|
|
666
|
-
u_local = _extract_unknown_elem(args[0], u_elem)
|
|
667
|
-
return _ops.sym_grad_u(field, u_local)
|
|
668
|
-
return _ops.sym_grad(field)
|
|
669
|
-
if op == "outer":
|
|
670
|
-
a, b = args
|
|
671
|
-
if not isinstance(a, FieldRef) or not isinstance(b, FieldRef):
|
|
672
|
-
raise TypeError("outer expects FieldRef operands.")
|
|
673
|
-
if a.role == b.role:
|
|
674
|
-
raise ValueError("outer requires one trial and one test field.")
|
|
675
|
-
test = a if a.role == "test" else b
|
|
676
|
-
trial = b if a.role == "test" else a
|
|
677
|
-
v_field = _eval_field(test, ctx, params)
|
|
678
|
-
u_field = _eval_field(trial, ctx, params)
|
|
679
|
-
if getattr(v_field, "value_dim", 1) != 1 or getattr(u_field, "value_dim", 1) != 1:
|
|
680
|
-
raise ValueError("u*v is only defined for scalar fields; use dot/inner for vectors.")
|
|
681
|
-
vN = v_field.N
|
|
682
|
-
uN = u_field.N
|
|
683
|
-
return jnp.einsum("qi,qj->qij", vN, uN)
|
|
684
|
-
if op == "add":
|
|
685
|
-
return _eval_value(args[0], ctx, params, u_elem=u_elem) + _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
686
|
-
if op == "sub":
|
|
687
|
-
return _eval_value(args[0], ctx, params, u_elem=u_elem) - _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
688
|
-
if op == "mul":
|
|
689
|
-
a = _eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
690
|
-
b = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
691
|
-
if hasattr(a, "ndim") and hasattr(b, "ndim"):
|
|
692
|
-
if a.ndim == 1 and b.ndim == 2 and a.shape[0] == b.shape[0]:
|
|
693
|
-
a = a[:, None]
|
|
694
|
-
elif b.ndim == 1 and a.ndim == 2 and b.shape[0] == a.shape[0]:
|
|
695
|
-
b = b[:, None]
|
|
696
|
-
elif a.ndim >= 2 and b.ndim == 1 and a.shape[0] == b.shape[0]:
|
|
697
|
-
b = b.reshape((b.shape[0],) + (1,) * (a.ndim - 1))
|
|
698
|
-
elif b.ndim >= 2 and a.ndim == 1 and b.shape[0] == a.shape[0]:
|
|
699
|
-
a = a.reshape((a.shape[0],) + (1,) * (b.ndim - 1))
|
|
700
|
-
return a * b
|
|
701
|
-
if op == "matmul":
|
|
702
|
-
a = _eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
703
|
-
b = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
704
|
-
if (
|
|
705
|
-
hasattr(a, "ndim")
|
|
706
|
-
and hasattr(b, "ndim")
|
|
707
|
-
and a.ndim == 3
|
|
708
|
-
and b.ndim == 3
|
|
709
|
-
and a.shape[0] == b.shape[0]
|
|
710
|
-
and a.shape[-1] == b.shape[-1]
|
|
711
|
-
):
|
|
712
|
-
return jnp.einsum("qia,qja->qij", a, b)
|
|
713
|
-
return a @ b
|
|
714
|
-
if op == "matmul_std":
|
|
715
|
-
a = _eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
716
|
-
b = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
717
|
-
return jnp.matmul(a, b)
|
|
718
|
-
if op == "neg":
|
|
719
|
-
return -_eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
720
|
-
if op == "dot":
|
|
721
|
-
if isinstance(args[0], FieldRef):
|
|
722
|
-
return _ops.dot(_eval_field(args[0], ctx, params), _eval_value(args[1], ctx, params, u_elem=u_elem))
|
|
723
|
-
a = _eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
724
|
-
b = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
725
|
-
if hasattr(a, "ndim") and hasattr(b, "ndim") and a.ndim == 3 and b.ndim == 3 and a.shape[-1] == b.shape[-1]:
|
|
726
|
-
return jnp.einsum("qia,qja->qij", a, b)
|
|
727
|
-
return jnp.matmul(a, b)
|
|
728
|
-
if op == "sdot":
|
|
729
|
-
if isinstance(args[0], FieldRef):
|
|
730
|
-
return _ops.dot(_eval_field(args[0], ctx, params), _eval_value(args[1], ctx, params, u_elem=u_elem))
|
|
731
|
-
a = _eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
732
|
-
b = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
733
|
-
if hasattr(a, "ndim") and hasattr(b, "ndim") and a.ndim == 3 and b.ndim == 3 and a.shape[-1] == b.shape[-1]:
|
|
734
|
-
return jnp.einsum("qia,qja->qij", a, b)
|
|
735
|
-
return jnp.matmul(a, b)
|
|
736
|
-
if op == "ddot":
|
|
737
|
-
if len(args) == 2:
|
|
738
|
-
a = _eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
739
|
-
b = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
740
|
-
if (
|
|
741
|
-
hasattr(a, "ndim")
|
|
742
|
-
and hasattr(b, "ndim")
|
|
743
|
-
and a.ndim == 3
|
|
744
|
-
and b.ndim == 3
|
|
745
|
-
and a.shape[0] == b.shape[0]
|
|
746
|
-
and a.shape[1] == b.shape[1]
|
|
747
|
-
):
|
|
748
|
-
return jnp.einsum("qik,qim->qkm", a, b)
|
|
749
|
-
return _ops.ddot(a, b)
|
|
750
|
-
return _ops.ddot(
|
|
751
|
-
_eval_value(args[0], ctx, params, u_elem=u_elem),
|
|
752
|
-
_eval_value(args[1], ctx, params, u_elem=u_elem),
|
|
753
|
-
_eval_value(args[2], ctx, params, u_elem=u_elem),
|
|
754
|
-
)
|
|
755
|
-
if op == "inner":
|
|
756
|
-
a = _eval_value(args[0], ctx, params, u_elem=u_elem)
|
|
757
|
-
b = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
758
|
-
return jnp.einsum("...i,...i->...", a, b)
|
|
759
|
-
if op == "action":
|
|
760
|
-
if isinstance(args[1], FieldRef):
|
|
761
|
-
raise ValueError("action expects a scalar expression; use u.val for unknowns.")
|
|
762
|
-
v_field = _eval_field(args[0], ctx, params)
|
|
763
|
-
s = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
764
|
-
value_dim = int(getattr(v_field, "value_dim", 1))
|
|
765
|
-
if value_dim == 1:
|
|
766
|
-
if v_field.N.ndim != 2:
|
|
767
|
-
raise ValueError("action expects scalar test field with N shape (q, ndofs).")
|
|
768
|
-
if hasattr(s, "ndim") and s.ndim not in (0, 1):
|
|
769
|
-
raise ValueError("action expects scalar s with shape (q,) or scalar.")
|
|
770
|
-
return v_field.N * s
|
|
771
|
-
if hasattr(s, "ndim") and s.ndim not in (1, 2):
|
|
772
|
-
raise ValueError("action expects vector s with shape (q, dim) or (dim,).")
|
|
773
|
-
return _ops.dot(v_field, s)
|
|
774
|
-
if op == "gaction":
|
|
775
|
-
v_field = _eval_field(args[0], ctx, params)
|
|
776
|
-
q = _eval_value(args[1], ctx, params, u_elem=u_elem)
|
|
777
|
-
if v_field.gradN.ndim != 3:
|
|
778
|
-
raise ValueError("gaction expects test gradient with shape (q, ndofs, dim).")
|
|
779
|
-
if not hasattr(q, "ndim"):
|
|
780
|
-
raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
|
|
781
|
-
if q.ndim == 2:
|
|
782
|
-
return jnp.einsum("qaj,qj->qa", v_field.gradN, q)
|
|
783
|
-
if q.ndim == 3:
|
|
784
|
-
if int(getattr(v_field, "value_dim", 1)) == 1:
|
|
785
|
-
raise ValueError("gaction tensor flux requires vector test field.")
|
|
786
|
-
return jnp.einsum("qij,qaj->qai", q, v_field.gradN).reshape(q.shape[0], -1)
|
|
787
|
-
raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
|
|
788
|
-
if op == "transpose_last2":
|
|
789
|
-
return _ops.transpose_last2(_eval_value(args[0], ctx, params, u_elem=u_elem))
|
|
790
|
-
if op == "einsum":
|
|
791
|
-
subscripts = args[0]
|
|
792
|
-
operands = [_eval_value(arg, ctx, params, u_elem=u_elem) for arg in args[1:]]
|
|
793
|
-
return jnp.einsum(subscripts, *operands)
|
|
794
|
-
|
|
795
|
-
raise ValueError(f"Unknown Expr op: {op}")
|
|
1216
|
+
def _eval_expr(
|
|
1217
|
+
expr: Expr,
|
|
1218
|
+
ctx: VolumeContext | SurfaceContext,
|
|
1219
|
+
params: ParamsLike,
|
|
1220
|
+
u_elem: UElement | None = None,
|
|
1221
|
+
):
|
|
1222
|
+
plan = make_eval_plan(expr)
|
|
1223
|
+
return eval_with_plan(plan, ctx, params, u_elem=u_elem)
|
|
796
1224
|
|
|
797
1225
|
|
|
798
1226
|
__all__ = [
|
|
@@ -812,6 +1240,7 @@ __all__ = [
|
|
|
812
1240
|
"compile_mixed_residual",
|
|
813
1241
|
"grad",
|
|
814
1242
|
"sym_grad",
|
|
1243
|
+
"outer",
|
|
815
1244
|
"dot",
|
|
816
1245
|
"ddot",
|
|
817
1246
|
"inner",
|
|
@@ -824,5 +1253,6 @@ __all__ = [
|
|
|
824
1253
|
"log",
|
|
825
1254
|
"transpose_last2",
|
|
826
1255
|
"matmul",
|
|
1256
|
+
"matmul_std",
|
|
827
1257
|
"einsum",
|
|
828
1258
|
]
|