fluxfem 0.1.3a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fluxfem might be problematic. Click here for more details.
- fluxfem/__init__.py +343 -0
- fluxfem/core/__init__.py +318 -0
- fluxfem/core/assembly.py +788 -0
- fluxfem/core/basis.py +996 -0
- fluxfem/core/data.py +64 -0
- fluxfem/core/dtypes.py +4 -0
- fluxfem/core/forms.py +234 -0
- fluxfem/core/interp.py +55 -0
- fluxfem/core/solver.py +113 -0
- fluxfem/core/space.py +419 -0
- fluxfem/core/weakform.py +828 -0
- fluxfem/helpers_ts.py +11 -0
- fluxfem/helpers_wf.py +44 -0
- fluxfem/mesh/__init__.py +29 -0
- fluxfem/mesh/base.py +244 -0
- fluxfem/mesh/hex.py +327 -0
- fluxfem/mesh/io.py +87 -0
- fluxfem/mesh/predicate.py +45 -0
- fluxfem/mesh/surface.py +257 -0
- fluxfem/mesh/tet.py +246 -0
- fluxfem/physics/__init__.py +53 -0
- fluxfem/physics/diffusion.py +18 -0
- fluxfem/physics/elasticity/__init__.py +39 -0
- fluxfem/physics/elasticity/hyperelastic.py +99 -0
- fluxfem/physics/elasticity/linear.py +58 -0
- fluxfem/physics/elasticity/materials.py +32 -0
- fluxfem/physics/elasticity/stress.py +46 -0
- fluxfem/physics/operators.py +109 -0
- fluxfem/physics/postprocess.py +113 -0
- fluxfem/solver/__init__.py +47 -0
- fluxfem/solver/bc.py +439 -0
- fluxfem/solver/cg.py +326 -0
- fluxfem/solver/dirichlet.py +126 -0
- fluxfem/solver/history.py +31 -0
- fluxfem/solver/newton.py +400 -0
- fluxfem/solver/result.py +62 -0
- fluxfem/solver/solve_runner.py +534 -0
- fluxfem/solver/solver.py +148 -0
- fluxfem/solver/sparse.py +188 -0
- fluxfem/tools/__init__.py +7 -0
- fluxfem/tools/jit.py +51 -0
- fluxfem/tools/timer.py +659 -0
- fluxfem/tools/visualizer.py +101 -0
- fluxfem-0.1.3a0.dist-info/LICENSE +201 -0
- fluxfem-0.1.3a0.dist-info/METADATA +125 -0
- fluxfem-0.1.3a0.dist-info/RECORD +47 -0
- fluxfem-0.1.3a0.dist-info/WHEEL +4 -0
fluxfem/core/weakform.py
ADDED
|
@@ -0,0 +1,828 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Callable
|
|
5
|
+
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
import jax
|
|
8
|
+
|
|
9
|
+
from ..physics import operators as _ops
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Expr:
    """Expression tree node evaluated against a FormContext.

    A node stores an operator tag (``op``) and a tuple of operands (``args``).
    Operands are other ``Expr`` nodes or raw Python values whose meaning
    depends on ``op``. The Python operators below only build new nodes;
    nothing is computed until :meth:`eval` is called.

    Operator mapping: ``+`` -> "add", ``-`` -> "sub", ``*`` -> "mul",
    ``@`` -> "matmul", ``|`` -> "inner", ``**`` -> "pow", unary ``-`` -> "neg",
    ``.T`` -> "transpose".
    """

    def __init__(self, op: str, *args):
        # object.__setattr__ keeps this initialiser usable from the frozen
        # dataclass subclasses in this module, whose __setattr__ raises.
        object.__setattr__(self, "op", op)
        object.__setattr__(self, "args", args)

    def __repr__(self):
        """Return ``ClassName(op, arg, ...)`` so nested trees are readable."""
        parts = ", ".join(repr(item) for item in (self.op, *self.args))
        return f"{type(self).__name__}({parts})"

    def eval(self, ctx, params=None, u_elem=None):
        """Evaluate this node against ``ctx``.

        ``params`` is what "param" nodes resolve to; ``u_elem`` carries the
        element unknowns needed by nodes that reference an "unknown" field.
        """
        return _eval_expr(self, ctx, params, u_elem=u_elem)

    def _binop(self, other, op):
        """Build ``Expr(op, self, other)``, wrapping ``other`` as a literal if needed."""
        return Expr(op, self, _as_expr(other))

    def __add__(self, other):
        return self._binop(other, "add")

    def __radd__(self, other):
        return _as_expr(other)._binop(self, "add")

    def __sub__(self, other):
        return self._binop(other, "sub")

    def __rsub__(self, other):
        return _as_expr(other)._binop(self, "sub")

    def __mul__(self, other):
        return self._binop(other, "mul")

    def __rmul__(self, other):
        return _as_expr(other)._binop(self, "mul")

    def __matmul__(self, other):
        return self._binop(other, "matmul")

    def __rmatmul__(self, other):
        return _as_expr(other)._binop(self, "matmul")

    def __or__(self, other):
        return self._binop(other, "inner")

    def __ror__(self, other):
        return _as_expr(other)._binop(self, "inner")

    def __pow__(self, power, modulo=None):
        """Build a "pow" node; three-argument ``pow`` is rejected with ValueError."""
        if modulo is not None:
            raise ValueError("modulo is not supported for Expr exponentiation.")
        return Expr("pow", self, _as_expr(power))

    def __rpow__(self, base, modulo=None):
        """Build a "pow" node for ``base ** self`` where ``base`` is not an Expr."""
        if modulo is not None:
            raise ValueError("modulo is not supported for Expr exponentiation.")
        return Expr("pow", _as_expr(base), self)

    def __neg__(self):
        return Expr("neg", self)

    @property
    def T(self):
        """Node that swaps the last two axes of this expression's value."""
        return Expr("transpose", self)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass(frozen=True)
class FieldRef(Expr):
    """Symbolic handle for a trial, test or unknown field, optionally named.

    The node itself is ``Expr("field", role, name)``. Arithmetic on a bare
    reference wraps it in a "value" node first, while ``ref * other_ref``
    builds an "outer" node and ``ref | other_ref`` an "inner" node. ``|``
    with a non-field operand builds an "sdot" node.
    """

    role: str
    name: str | None = None

    def __init__(self, role: str, name: str | None = None):
        # Frozen dataclass: attributes have to be written through object.__setattr__.
        for attr, value in (("role", role), ("name", name)):
            object.__setattr__(self, attr, value)
        super().__init__("field", role, name)

    @property
    def val(self):
        """Node for the field value."""
        node = Expr("value", self)
        return node

    @property
    def grad(self):
        """Node for the field gradient."""
        node = Expr("grad", self)
        return node

    @property
    def sym_grad(self):
        """Node for the symmetric gradient of the field."""
        node = Expr("sym_grad", self)
        return node

    def __mul__(self, other):
        if not isinstance(other, FieldRef):
            return Expr("mul", Expr("value", self), _as_expr(other))
        return Expr("outer", self, other)

    def __rmul__(self, other):
        if not isinstance(other, FieldRef):
            return Expr("mul", _as_expr(other), Expr("value", self))
        return Expr("outer", other, self)

    def __add__(self, other):
        lhs = Expr("value", self)
        return Expr("add", lhs, _as_expr(other))

    def __radd__(self, other):
        lhs = _as_expr(other)
        return Expr("add", lhs, Expr("value", self))

    def __sub__(self, other):
        lhs = Expr("value", self)
        return Expr("sub", lhs, _as_expr(other))

    def __rsub__(self, other):
        lhs = _as_expr(other)
        return Expr("sub", lhs, Expr("value", self))

    def __or__(self, other):
        if not isinstance(other, FieldRef):
            return Expr("sdot", self, _as_expr(other))
        return Expr("inner", self, other)

    def __ror__(self, other):
        if not isinstance(other, FieldRef):
            return Expr("sdot", _as_expr(other), self)
        return Expr("inner", other, self)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@dataclass(frozen=True)
class ParamRef(Expr):
    """Symbolic reference to params passed into the kernel.

    Attribute access on a ``ParamRef`` does not read anything; it builds an
    ``Expr("getattr", self, name)`` node that is resolved against the real
    params object at evaluation time.
    """

    def __init__(self):
        super().__init__("param")

    def __getattr__(self, name: str):
        """Return a deferred attribute lookup node for ``name``.

        Raises:
            AttributeError: for dunder names. Protocol probes such as
                ``__deepcopy__`` or ``__array__`` must see "not implemented";
                returning an Expr for them makes ``copy.deepcopy`` call a
                non-callable object.
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return Expr("getattr", self, name)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
@jax.tree_util.register_pytree_node_class
class Params:
    """Simple params container with attribute access (JAX pytree).

    Keyword arguments given to the constructor are stored in a private dict
    and exposed both as attributes (``p.name``) and items (``p["name"]``).
    As a pytree, the values are the children and the sorted key tuple is the
    auxiliary data.
    """

    def __init__(self, **kwargs):
        self._data = dict(kwargs)

    def __getattr__(self, name: str):
        """Look up ``name`` among the stored params.

        ``__getattr__`` only runs when normal lookup fails, so it is also hit
        for ``_data`` itself on instances created without ``__init__`` (as
        ``copy`` and ``pickle`` do). Reading the storage through
        ``self.__dict__`` avoids the infinite recursion that ``self._data``
        would trigger there.

        Raises:
            AttributeError: if ``name`` is not a stored key, or the storage
                has not been initialised yet.
        """
        try:
            return self.__dict__["_data"][name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __getitem__(self, key: str):
        """Return the stored value for ``key`` (KeyError if absent)."""
        return self._data[key]

    def tree_flatten(self):
        """Return ``(values, keys)`` with keys sorted for a stable ordering."""
        keys = tuple(sorted(self._data.keys()))
        values = tuple(self._data[k] for k in keys)
        return values, keys

    @classmethod
    def tree_unflatten(cls, keys, values):
        """Rebuild a ``Params`` from the aux ``keys`` and the child ``values``."""
        return cls(**dict(zip(keys, values)))
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def trial_ref(name: str | None = "u") -> FieldRef:
    """Build a ``FieldRef`` with role "trial" and the given name."""
    ref = FieldRef("trial", name)
    return ref
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def test_ref(name: str | None = "v") -> FieldRef:
    """Build a ``FieldRef`` with role "test" and the given name."""
    ref = FieldRef("test", name)
    return ref
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def unknown_ref(name: str | None = "u") -> FieldRef:
    """Build a ``FieldRef`` with role "unknown" (the current solution) and the given name."""
    ref = FieldRef("unknown", name)
    return ref
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def param_ref() -> ParamRef:
    """Build a fresh symbolic handle for the params object."""
    ref = ParamRef()
    return ref
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _as_expr(obj) -> Expr:
    """Return ``obj`` unchanged if it is an ``Expr``, else wrap it in a "lit" node."""
    return obj if isinstance(obj, Expr) else Expr("lit", obj)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _eval_field(obj: Any, ctx, params):
    """Resolve ``obj`` to a concrete field object taken from ``ctx``.

    For a named ``FieldRef`` the lookup order is:

    1. ``ctx.fields[name]`` when that entry exposes the role as an attribute
       (an "unknown" attribute that is ``None`` falls back to ``.trial``);
    2. the role-specific mapping ``ctx.trial_fields`` / ``ctx.test_fields`` /
       ``ctx.unknown_fields``;
    3. ``ctx.fields[name]`` again: a dict entry is indexed by role, then by
       "field"; otherwise the entry itself is returned.

    If none applies (or the reference is unnamed) the role is read directly
    from ``ctx``: ``trial``, ``test`` (or ``v``), ``unknown`` (or ``trial``).

    Any other ``Expr`` is evaluated and accepted only if the result has an
    ``N`` attribute.

    Raises:
        ValueError: unknown role, or a "test" role on a context that has
            neither ``test`` nor ``v``.
        TypeError: ``obj`` does not resolve to a field.
    """
    if isinstance(obj, FieldRef):
        role = obj.role
        name = obj.name
        if name is not None:
            by_name = getattr(ctx, "fields", None)
            if by_name is not None and name in by_name:
                entry = by_name[name]
                if hasattr(entry, "trial") and role == "trial":
                    return entry.trial
                if hasattr(entry, "test") and role == "test":
                    return entry.test
                if hasattr(entry, "unknown") and role == "unknown":
                    if entry.unknown is not None:
                        return entry.unknown
                    return entry.trial

            # Role-specific name -> field mappings.
            for candidate in ("trial", "test", "unknown"):
                if role != candidate:
                    continue
                table = getattr(ctx, candidate + "_fields", None)
                if table is not None and name in table:
                    return table[name]

            by_name = getattr(ctx, "fields", None)
            if by_name is not None and name in by_name:
                entry = by_name[name]
                if isinstance(entry, dict):
                    for key in (role, "field"):
                        if key in entry:
                            return entry[key]
                return entry

        # No named match: fall back to the single-field attributes on ctx.
        if role == "trial":
            return ctx.trial
        if role == "test":
            for attr in ("test", "v"):
                if hasattr(ctx, attr):
                    return getattr(ctx, attr)
            raise ValueError("Surface context is missing test field.")
        if role == "unknown":
            return getattr(ctx, "unknown", ctx.trial)
        raise ValueError(f"Unknown field role: {role}")

    if isinstance(obj, Expr):
        candidate = obj.eval(ctx, params)
        if hasattr(candidate, "N"):
            return candidate
    raise TypeError("Expected a field reference for this operator.")
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def _eval_value(obj: Any, ctx, params, u_elem=None):
    """Evaluate an operand to a plain value.

    A ``FieldRef`` yields its ``N`` attribute, or for role "unknown" the
    field value computed from ``u_elem``. Any other ``Expr`` is evaluated
    recursively. Non-``Expr`` objects are returned as they are.
    """
    if not isinstance(obj, Expr):
        return obj
    if not isinstance(obj, FieldRef):
        return obj.eval(ctx, params, u_elem=u_elem)
    resolved = _eval_field(obj, ctx, params)
    if obj.role == "unknown":
        return _eval_unknown_value(obj, resolved, u_elem)
    return resolved.N
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _extract_unknown_elem(field_ref: FieldRef, u_elem):
|
|
248
|
+
if u_elem is None:
|
|
249
|
+
raise ValueError("u_elem is required to evaluate unknown field value.")
|
|
250
|
+
if isinstance(u_elem, dict):
|
|
251
|
+
name = field_ref.name or "u"
|
|
252
|
+
if name not in u_elem:
|
|
253
|
+
raise ValueError(f"u_elem is missing key '{name}'.")
|
|
254
|
+
return u_elem[name]
|
|
255
|
+
return u_elem
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def _eval_unknown_value(field_ref: FieldRef, field, u_elem):
    """Contract ``field.N`` with the element unknowns to get the field value.

    With ``value_dim == 1`` (the default when the attribute is absent) the
    contraction is ``"qa,a->q"``. Otherwise the unknowns are reshaped to
    ``(-1, value_dim)`` and contracted with ``"qa,ai->qi"``.
    """
    dofs = _extract_unknown_elem(field_ref, u_elem)
    n_comp = int(getattr(field, "value_dim", 1))
    if n_comp != 1:
        per_node = dofs.reshape((-1, n_comp))
        return jnp.einsum("qa,ai->qi", field.N, per_node)
    return jnp.einsum("qa,a->q", field.N, dofs)
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _eval_unknown_grad(field_ref: FieldRef, field, u_elem):
    """Contract ``field.gradN`` with the element unknowns to get the gradient.

    With ``value_dim == 1`` the contraction is ``"qaj,a->qj"``. Otherwise the
    unknowns are reshaped to ``(-1, value_dim)`` and contracted with
    ``"qaj,ai->qij"``.

    Raises:
        ValueError: the unknowns for this field are missing or ``None``.
    """
    dofs = _extract_unknown_elem(field_ref, u_elem)
    # A dict u_elem may still map the field name to None.
    if dofs is None:
        raise ValueError("u_elem is required to evaluate unknown field gradient.")
    n_comp = int(getattr(field, "value_dim", 1))
    if n_comp != 1:
        per_node = dofs.reshape((-1, n_comp))
        return jnp.einsum("qaj,ai->qij", field.gradN, per_node)
    return jnp.einsum("qaj,a->qj", field.gradN, dofs)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def grad(field) -> Expr:
    """Build a "grad" node for ``field`` (basis gradients, or the gradient of an unknown)."""
    operand = _as_expr(field)
    return Expr("grad", operand)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def sym_grad(field) -> Expr:
    """Build a "sym_grad" node (symmetric gradient) for ``field``."""
    operand = _as_expr(field)
    return Expr("sym_grad", operand)
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def dot(a, b) -> Expr:
    """Build a "dot" node: a dot product, or a vector load when ``a`` is a field."""
    lhs, rhs = _as_expr(a), _as_expr(b)
    return Expr("dot", lhs, rhs)
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def sdot(a, b) -> Expr:
    """Build an "sdot" node: the surface counterpart of :func:`dot`."""
    lhs, rhs = _as_expr(a), _as_expr(b)
    return Expr("sdot", lhs, rhs)
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def ddot(a, b, c=None) -> Expr:
    """Build a "ddot" node: double contraction of ``a`` and ``b``, or ``a^T b c`` when ``c`` is given."""
    operands = [_as_expr(a), _as_expr(b)]
    if c is not None:
        operands.append(_as_expr(c))
    return Expr("ddot", *operands)
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def inner(a, b) -> Expr:
    """Build an "inner" node: a product summed over the last axis."""
    lhs, rhs = _as_expr(a), _as_expr(b)
    return Expr("inner", lhs, rhs)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def action(v, s) -> Expr:
    """Build an "action" node: test function ``v`` applied to the expression ``s``."""
    field, source = _as_expr(v), _as_expr(s)
    return Expr("action", field, source)
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def gaction(v, q) -> Expr:
    """Build a "gaction" node: gradient of test function ``v`` contracted with flux ``q``."""
    field, flux = _as_expr(v), _as_expr(q)
    return Expr("gaction", field, flux)
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def normal() -> Expr:
    """Build a "surface_normal" node, resolved from the context's ``normal`` attribute."""
    node = Expr("surface_normal")
    return node
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def ds() -> Expr:
    """Build a "surface_measure" node, evaluated as ``ctx.w * ctx.detJ``."""
    node = Expr("surface_measure")
    return node
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def dOmega() -> Expr:
    """Build a "volume_measure" node, evaluated as ``ctx.w * ctx.test.detJ``."""
    node = Expr("volume_measure")
    return node
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def I(dim: int) -> Expr:
    """Build an "eye" node: the ``dim`` x ``dim`` identity matrix."""
    node = Expr("eye", dim)
    return node
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def det(a) -> Expr:
    """Build a "det" node: determinant of a square matrix."""
    operand = _as_expr(a)
    return Expr("det", operand)
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def inv(a) -> Expr:
    """Build an "inv" node: matrix inverse."""
    operand = _as_expr(a)
    return Expr("inv", operand)
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def transpose(a) -> Expr:
    """Build a "transpose" node: the last two axes swapped."""
    operand = _as_expr(a)
    return Expr("transpose", operand)
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def log(a) -> Expr:
    """Build a "log" node: natural logarithm."""
    operand = _as_expr(a)
    return Expr("log", operand)
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def transpose_last2(a) -> Expr:
    """Build a "transpose_last2" node: the last two axes swapped via the operators module."""
    operand = _as_expr(a)
    return Expr("transpose_last2", operand)
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def matmul(a, b) -> Expr:
    """Build a "matmul_std" node: plain matrix product, without the 3D contraction ``@`` applies."""
    lhs, rhs = _as_expr(a), _as_expr(b)
    return Expr("matmul_std", lhs, rhs)
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def einsum(subscripts: str, *args) -> Expr:
    """Build an "einsum" node; operands may be ``Expr`` nodes or raw values."""
    operands = tuple(_as_expr(item) for item in args)
    return Expr("einsum", subscripts, *operands)
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def _call_user(fn, *args, params):
|
|
376
|
+
try:
|
|
377
|
+
return fn(*args, params)
|
|
378
|
+
except TypeError:
|
|
379
|
+
return fn(*args)
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def compile_bilinear(fn):
    """Compile a bilinear weak form into a kernel ``(ctx, params) -> value``.

    ``fn`` is either a ready-made ``Expr`` or a callable taking
    ``(u, v, params)`` or ``(u, v)`` symbolic references and returning an
    ``Expr``. The returned kernel carries an ``_includes_measure`` attribute.

    Raises:
        ValueError: the expression does not contain ``dOmega()``.
    """
    if isinstance(fn, Expr):
        expr = fn
    else:
        trial = trial_ref()
        test = test_ref()
        expr = _call_user(fn, trial, test, params=param_ref())

    has_measure = _expr_contains(expr, "volume_measure")
    if not has_measure:
        raise ValueError("Volume bilinear form must include dOmega().")

    def _form(ctx, params):
        node = _as_expr(expr)
        return node.eval(ctx, params)

    _form._includes_measure = has_measure
    return _form
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
def compile_linear(fn):
    """Compile a volume linear weak form into a kernel ``(ctx, params) -> value``.

    ``fn`` is either a ready-made ``Expr`` or a callable taking
    ``(v, params)`` or ``(v)`` and returning an ``Expr``. The returned kernel
    carries an ``_includes_measure`` attribute.

    Raises:
        ValueError: the expression does not contain ``dOmega()``.
    """
    if isinstance(fn, Expr):
        expr = fn
    else:
        test = test_ref()
        expr = _call_user(fn, test, params=param_ref())

    has_measure = _expr_contains(expr, "volume_measure")
    if not has_measure:
        raise ValueError("Volume linear form must include dOmega().")

    def _form(ctx, params):
        node = _as_expr(expr)
        return node.eval(ctx, params)

    _form._includes_measure = has_measure
    return _form
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def _expr_contains(expr: Expr, op: str) -> bool:
    """Return True if ``expr`` or any nested ``Expr`` operand has operator ``op``.

    Non-``Expr`` inputs and non-``Expr`` operands are ignored.
    """
    if not isinstance(expr, Expr):
        return False
    pending = [expr]
    while pending:
        node = pending.pop()
        if node.op == op:
            return True
        pending.extend(child for child in node.args if isinstance(child, Expr))
    return False
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
def compile_surface_linear(fn):
    """Compile a surface linear form into a kernel ``(ctx, params) -> ndarray``.

    ``fn`` is either a ready-made ``Expr`` or a callable taking
    ``(v, params)`` or ``(v)``. A callable that fails both call shapes with
    ``TypeError``, or returns something other than an ``Expr``, is rejected.

    Raises:
        ValueError: no ``Expr`` was produced, or it does not contain ``ds()``.
    """
    if isinstance(fn, Expr):
        expr = fn
    else:
        test = test_ref()
        handle = param_ref()
        try:
            expr = _call_user(fn, test, params=handle)
        except TypeError:
            expr = None

    if not isinstance(expr, Expr):
        raise ValueError("Surface linear form must return an Expr; use ds() in the expression.")

    has_measure = _expr_contains(expr, "surface_measure")
    if not has_measure:
        raise ValueError("Surface linear form must include ds().")

    def _form(ctx, params):
        node = _as_expr(expr)
        return node.eval(ctx, params)

    _form._includes_measure = has_measure  # type: ignore[attr-defined]
    return _form
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
class LinearForm:
    """Holds a linear form function together with its kind ("volume" or "surface")."""

    def __init__(self, fn, *, kind: str):
        self.fn = fn
        self.kind = kind

    @classmethod
    def volume(cls, fn):
        """Wrap ``fn`` as a volume linear form."""
        form = cls(fn, kind="volume")
        return form

    @classmethod
    def surface(cls, fn):
        """Wrap ``fn`` as a surface linear form."""
        form = cls(fn, kind="surface")
        return form

    def get_compiled(self, *, ctx_kind: str | None = None):
        """Compile the form; ``ctx_kind`` overrides the stored kind when given.

        Raises:
            ValueError: the effective kind is neither "volume" nor "surface".
        """
        kind = self.kind if ctx_kind is None else ctx_kind
        if kind == "volume":
            compiler = compile_linear
        elif kind == "surface":
            compiler = compile_surface_linear
        else:
            raise ValueError(f"Unknown linear form kind: {kind}")
        return compiler(self.fn)
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
class BilinearForm:
    """Holds a bilinear form function; only the volume backend exists."""

    def __init__(self, fn):
        self.fn = fn

    @classmethod
    def volume(cls, fn):
        """Wrap ``fn`` as a volume bilinear form."""
        form = cls(fn)
        return form

    def get_compiled(self):
        """Compile the stored function with ``compile_bilinear``."""
        kernel = compile_bilinear(self.fn)
        return kernel
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
class ResidualForm:
    """Holds a residual form function; only the volume backend exists."""

    def __init__(self, fn):
        self.fn = fn

    @classmethod
    def volume(cls, fn):
        """Wrap ``fn`` as a volume residual form."""
        form = cls(fn)
        return form

    def get_compiled(self):
        """Compile the stored function with ``compile_residual``."""
        kernel = compile_residual(self.fn)
        return kernel
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
def compile_residual(fn):
    """Compile a residual weak form into a kernel ``(ctx, u_elem, params) -> value``.

    ``fn`` is either a ready-made ``Expr`` or a callable taking
    ``(v, u, params)`` or ``(v, u)``, where ``v`` is the test reference and
    ``u`` the unknown reference. The returned kernel carries an
    ``_includes_measure`` attribute.

    Raises:
        ValueError: the expression does not contain ``dOmega()``.
    """
    if isinstance(fn, Expr):
        expr = fn
    else:
        test = test_ref()
        unknown = unknown_ref()
        expr = _call_user(fn, test, unknown, params=param_ref())

    has_measure = _expr_contains(expr, "volume_measure")
    if not has_measure:
        raise ValueError("Volume residual form must include dOmega().")

    def _form(ctx, u_elem, params):
        node = _as_expr(expr)
        return node.eval(ctx, params, u_elem=u_elem)

    _form._includes_measure = has_measure
    return _form
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
def compile_mixed_residual(residuals: dict[str, Callable]):
    """Compile per-field residuals into one kernel ``(ctx, u_elem, params) -> dict``.

    Each entry maps a field name to an ``Expr`` or to a callable taking
    ``(v, u, params)`` or ``(v, u)``; the references handed to the callable
    are named after the field. The kernel returns one evaluated value per
    field name, and its ``_includes_measure`` attribute is a dict keyed the
    same way.

    Raises:
        ValueError: a residual does not contain ``dOmega()``.
    """
    compiled = {}
    includes_measure = {}
    for field_name, fn in residuals.items():
        if isinstance(fn, Expr):
            produced = fn
        else:
            test = test_ref(field_name)
            unknown = unknown_ref(field_name)
            produced = _call_user(fn, test, unknown, params=param_ref())
        node = _as_expr(produced)
        compiled[field_name] = node
        found = _expr_contains(node, "volume_measure")
        includes_measure[field_name] = found
        if not found:
            raise ValueError(f"Mixed residual '{field_name}' must include dOmega().")

    def _form(ctx, u_elem, params):
        results = {}
        for key, node in compiled.items():
            results[key] = node.eval(ctx, params, u_elem=u_elem)
        return results

    _form._includes_measure = includes_measure
    return _form
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
class MixedWeakForm:
    """Holds the residuals of a mixed weak form, keyed by field name."""

    def __init__(self, *, residuals: dict[str, Callable]):
        self.residuals = residuals

    def get_compiled(self):
        """Compile all residuals with ``compile_mixed_residual``.

        Raises:
            ValueError: no residuals were provided.
        """
        if self.residuals:
            return compile_mixed_residual(self.residuals)
        raise ValueError("residuals are not defined")
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
def _eval_expr(expr: Expr, ctx, params, u_elem=None):
    """Evaluate one expression node against a form context.

    Dispatches on ``expr.op``; operands are evaluated recursively through
    ``_eval_value`` (values) or ``_eval_field`` (field objects).

    Args:
        expr: node to evaluate.
        ctx: form context providing fields and quadrature data (``trial``,
            ``test``, ``w``, ``detJ``/``test.detJ``, ``normal``, ... as the
            individual operators require).
        params: object that "param" nodes resolve to.
        u_elem: element unknowns (array, or dict keyed by field name),
            required by operators that read an "unknown" field.

    Raises:
        ValueError: unknown operator or role, missing context data, or an
            operand whose rank an operator does not accept.
        TypeError: "outer" applied to non-field operands.
    """
    op = expr.op
    args = expr.args

    def value(node):
        # Evaluate an operand with the same context, params and unknowns.
        return _eval_value(node, ctx, params, u_elem=u_elem)

    if op == "lit":
        return args[0]
    if op == "param":
        return params
    if op == "getattr":
        base = value(args[0])
        name = args[1]
        if isinstance(base, dict):
            return base[name]
        return getattr(base, name)
    if op == "field":
        role, name = args
        if name is not None:
            if role == "trial" and getattr(ctx, "trial_fields", None) is not None:
                if name in ctx.trial_fields:
                    return ctx.trial_fields[name]
            if role == "test" and getattr(ctx, "test_fields", None) is not None:
                if name in ctx.test_fields:
                    return ctx.test_fields[name]
            if role == "unknown" and getattr(ctx, "unknown_fields", None) is not None:
                if name in ctx.unknown_fields:
                    return ctx.unknown_fields[name]
            fields = getattr(ctx, "fields", None)
            if fields is not None and name in fields:
                group = fields[name]
                if isinstance(group, dict):
                    if role in group:
                        return group[role]
                    if "field" in group:
                        return group["field"]
                return group
        if role == "trial":
            return ctx.trial
        if role == "test":
            return ctx.test
        if role == "unknown":
            return getattr(ctx, "unknown", ctx.trial)
        raise ValueError(f"Unknown field role: {role}")
    if op == "value":
        field = _eval_field(args[0], ctx, params)
        if isinstance(args[0], FieldRef) and args[0].role == "unknown":
            return _eval_unknown_value(args[0], field, u_elem)
        return field.N
    if op == "grad":
        field = _eval_field(args[0], ctx, params)
        if isinstance(args[0], FieldRef) and args[0].role == "unknown":
            return _eval_unknown_grad(args[0], field, u_elem)
        return field.gradN
    if op == "pow":
        base = value(args[0])
        exp = value(args[1])
        return base**exp
    if op == "eye":
        return jnp.eye(int(args[0]))
    if op == "det":
        return jnp.linalg.det(value(args[0]))
    if op == "inv":
        return jnp.linalg.inv(value(args[0]))
    if op == "transpose":
        return jnp.swapaxes(value(args[0]), -1, -2)
    if op == "log":
        return jnp.log(value(args[0]))
    if op == "surface_normal":
        normal = getattr(ctx, "normal", None)
        if normal is None:
            raise ValueError("surface normal is not available in context")
        return normal
    if op == "surface_measure":
        if not hasattr(ctx, "w") or not hasattr(ctx, "detJ"):
            raise ValueError("surface measure requires surface context with w and detJ.")
        return ctx.w * ctx.detJ
    if op == "volume_measure":
        if not hasattr(ctx, "w") or not hasattr(ctx, "test"):
            raise ValueError("volume measure requires FormContext with w and test.detJ.")
        return ctx.w * ctx.test.detJ
    if op == "sym_grad":
        field = _eval_field(args[0], ctx, params)
        if isinstance(args[0], FieldRef) and args[0].role == "unknown":
            if u_elem is None:
                raise ValueError("u_elem is required to evaluate unknown sym_grad.")
            u_local = _extract_unknown_elem(args[0], u_elem)
            return _ops.sym_grad_u(field, u_local)
        return _ops.sym_grad(field)
    if op == "outer":
        a, b = args
        if not isinstance(a, FieldRef) or not isinstance(b, FieldRef):
            raise TypeError("outer expects FieldRef operands.")
        if a.role == b.role:
            raise ValueError("outer requires one trial and one test field.")
        # Whichever operand has role "test" provides the first index of the result.
        test = a if a.role == "test" else b
        trial = b if a.role == "test" else a
        v_field = _eval_field(test, ctx, params)
        u_field = _eval_field(trial, ctx, params)
        if getattr(v_field, "value_dim", 1) != 1 or getattr(u_field, "value_dim", 1) != 1:
            raise ValueError("u*v is only defined for scalar fields; use dot/inner for vectors.")
        vN = v_field.N
        uN = u_field.N
        return jnp.einsum("qi,qj->qij", vN, uN)
    if op == "add":
        return value(args[0]) + value(args[1])
    if op == "sub":
        return value(args[0]) - value(args[1])
    if op == "mul":
        a = value(args[0])
        b = value(args[1])
        if hasattr(a, "ndim") and hasattr(b, "ndim"):
            # A rank-1 operand whose length matches the other operand's
            # leading axis is aligned with that axis instead of the last one.
            if a.ndim == 1 and b.ndim == 2 and a.shape[0] == b.shape[0]:
                a = a[:, None]
            elif b.ndim == 1 and a.ndim == 2 and b.shape[0] == a.shape[0]:
                b = b[:, None]
            elif a.ndim >= 2 and b.ndim == 1 and a.shape[0] == b.shape[0]:
                b = b.reshape((b.shape[0],) + (1,) * (a.ndim - 1))
            elif b.ndim >= 2 and a.ndim == 1 and b.shape[0] == a.shape[0]:
                a = a.reshape((a.shape[0],) + (1,) * (b.ndim - 1))
        return a * b
    if op == "matmul":
        a = value(args[0])
        b = value(args[1])
        if (
            hasattr(a, "ndim")
            and hasattr(b, "ndim")
            and a.ndim == 3
            and b.ndim == 3
            and a.shape[0] == b.shape[0]
            and a.shape[-1] == b.shape[-1]
        ):
            # Two rank-3 operands sharing first and last axes: contract the last axis.
            return jnp.einsum("qia,qja->qij", a, b)
        return a @ b
    if op == "matmul_std":
        a = value(args[0])
        b = value(args[1])
        return jnp.matmul(a, b)
    if op == "neg":
        return -value(args[0])
    if op in ("dot", "sdot"):
        # "dot" and "sdot" are evaluated identically.
        if isinstance(args[0], FieldRef):
            return _ops.dot(_eval_field(args[0], ctx, params), value(args[1]))
        a = value(args[0])
        b = value(args[1])
        if hasattr(a, "ndim") and hasattr(b, "ndim") and a.ndim == 3 and b.ndim == 3 and a.shape[-1] == b.shape[-1]:
            return jnp.einsum("qia,qja->qij", a, b)
        return jnp.matmul(a, b)
    if op == "ddot":
        if len(args) == 2:
            a = value(args[0])
            b = value(args[1])
            if (
                hasattr(a, "ndim")
                and hasattr(b, "ndim")
                and a.ndim == 3
                and b.ndim == 3
                and a.shape[0] == b.shape[0]
                and a.shape[1] == b.shape[1]
            ):
                return jnp.einsum("qik,qim->qkm", a, b)
            return _ops.ddot(a, b)
        return _ops.ddot(
            value(args[0]),
            value(args[1]),
            value(args[2]),
        )
    if op == "inner":
        a = value(args[0])
        b = value(args[1])
        return jnp.einsum("...i,...i->...", a, b)
    if op == "action":
        if isinstance(args[1], FieldRef):
            raise ValueError("action expects a scalar expression; use u.val for unknowns.")
        v_field = _eval_field(args[0], ctx, params)
        s = value(args[1])
        value_dim = int(getattr(v_field, "value_dim", 1))
        if value_dim == 1:
            if v_field.N.ndim != 2:
                raise ValueError("action expects scalar test field with N shape (q, ndofs).")
            if hasattr(s, "ndim") and s.ndim not in (0, 1):
                raise ValueError("action expects scalar s with shape (q,) or scalar.")
            if getattr(s, "ndim", 0) == 1:
                # s holds one value per quadrature point, so it must line up
                # with axis 0 of N; plain broadcasting would pair it with the
                # dof axis instead.
                s = s[:, None]
            return v_field.N * s
        if hasattr(s, "ndim") and s.ndim not in (1, 2):
            raise ValueError("action expects vector s with shape (q, dim) or (dim,).")
        return _ops.dot(v_field, s)
    if op == "gaction":
        v_field = _eval_field(args[0], ctx, params)
        q = value(args[1])
        if v_field.gradN.ndim != 3:
            raise ValueError("gaction expects test gradient with shape (q, ndofs, dim).")
        if not hasattr(q, "ndim"):
            raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
        if q.ndim == 2:
            return jnp.einsum("qaj,qj->qa", v_field.gradN, q)
        if q.ndim == 3:
            if int(getattr(v_field, "value_dim", 1)) == 1:
                raise ValueError("gaction tensor flux requires vector test field.")
            # Flatten the (node, component) axes into a single dof axis.
            return jnp.einsum("qij,qaj->qai", q, v_field.gradN).reshape(q.shape[0], -1)
        raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
    if op == "transpose_last2":
        return _ops.transpose_last2(value(args[0]))
    if op == "einsum":
        subscripts = args[0]
        operands = [value(arg) for arg in args[1:]]
        return jnp.einsum(subscripts, *operands)

    raise ValueError(f"Unknown Expr op: {op}")
|
|
796
|
+
|
|
797
|
+
|
|
798
|
+
# Public API of this module. Every name listed here is defined above.
__all__ = [
    "Expr",
    "FieldRef",
    "ParamRef",
    "trial_ref",
    "test_ref",
    "unknown_ref",
    "param_ref",
    "Params",
    "MixedWeakForm",
    "ResidualForm",
    "LinearForm",
    "BilinearForm",
    "compile_bilinear",
    "compile_linear",
    "compile_surface_linear",
    "compile_residual",
    "compile_mixed_residual",
    "grad",
    "sym_grad",
    "dot",
    "sdot",
    "ddot",
    "inner",
    "action",
    "gaction",
    "normal",
    "ds",
    "dOmega",
    "I",
    "det",
    "inv",
    "transpose",
    "log",
    "transpose_last2",
    "matmul",
    "einsum",
]
|