fluxfem-0.1.1a0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. fluxfem/__init__.py +343 -0
  2. fluxfem/core/__init__.py +316 -0
  3. fluxfem/core/assembly.py +788 -0
  4. fluxfem/core/basis.py +996 -0
  5. fluxfem/core/data.py +64 -0
  6. fluxfem/core/dtypes.py +4 -0
  7. fluxfem/core/forms.py +234 -0
  8. fluxfem/core/interp.py +55 -0
  9. fluxfem/core/solver.py +113 -0
  10. fluxfem/core/space.py +419 -0
  11. fluxfem/core/weakform.py +818 -0
  12. fluxfem/helpers_num.py +11 -0
  13. fluxfem/helpers_wf.py +42 -0
  14. fluxfem/mesh/__init__.py +29 -0
  15. fluxfem/mesh/base.py +244 -0
  16. fluxfem/mesh/hex.py +327 -0
  17. fluxfem/mesh/io.py +87 -0
  18. fluxfem/mesh/predicate.py +45 -0
  19. fluxfem/mesh/surface.py +257 -0
  20. fluxfem/mesh/tet.py +246 -0
  21. fluxfem/physics/__init__.py +53 -0
  22. fluxfem/physics/diffusion.py +18 -0
  23. fluxfem/physics/elasticity/__init__.py +39 -0
  24. fluxfem/physics/elasticity/hyperelastic.py +99 -0
  25. fluxfem/physics/elasticity/linear.py +58 -0
  26. fluxfem/physics/elasticity/materials.py +32 -0
  27. fluxfem/physics/elasticity/stress.py +46 -0
  28. fluxfem/physics/operators.py +109 -0
  29. fluxfem/physics/postprocess.py +113 -0
  30. fluxfem/solver/__init__.py +47 -0
  31. fluxfem/solver/bc.py +439 -0
  32. fluxfem/solver/cg.py +326 -0
  33. fluxfem/solver/dirichlet.py +126 -0
  34. fluxfem/solver/history.py +31 -0
  35. fluxfem/solver/newton.py +400 -0
  36. fluxfem/solver/result.py +62 -0
  37. fluxfem/solver/solve_runner.py +534 -0
  38. fluxfem/solver/solver.py +148 -0
  39. fluxfem/solver/sparse.py +188 -0
  40. fluxfem/tools/__init__.py +7 -0
  41. fluxfem/tools/jit.py +51 -0
  42. fluxfem/tools/timer.py +659 -0
  43. fluxfem/tools/visualizer.py +101 -0
  44. fluxfem-0.1.1a0.dist-info/METADATA +111 -0
  45. fluxfem-0.1.1a0.dist-info/RECORD +47 -0
  46. fluxfem-0.1.1a0.dist-info/WHEEL +4 -0
  47. fluxfem-0.1.1a0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,818 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable
5
+
6
+ import jax.numpy as jnp
7
+ import jax
8
+
9
+ from ..physics import operators as _ops
10
+
11
+
12
class Expr:
    """Expression tree node evaluated against a FormContext.

    Each node stores an operator tag ``op`` and a tuple of operands ``args``.
    The Python operators below only build new nodes; nothing is computed
    until :meth:`eval` walks the tree.
    """

    def __init__(self, op: str, *args):
        # object.__setattr__ lets frozen dataclass subclasses (FieldRef,
        # ParamRef) initialise these attributes as well.
        object.__setattr__(self, "op", op)
        object.__setattr__(self, "args", args)

    def __repr__(self) -> str:
        # Readable tree dump for debugging weak forms.
        if not self.args:
            return f"Expr({self.op!r})"
        rendered = ", ".join(repr(arg) for arg in self.args)
        return f"Expr({self.op!r}, {rendered})"

    def eval(self, ctx, params=None, u_elem=None):
        """Evaluate this node against ``ctx`` (see ``_eval_expr``)."""
        return _eval_expr(self, ctx, params, u_elem=u_elem)

    def _binop(self, other, op):
        return Expr(op, self, _as_expr(other))

    def __add__(self, other):
        return self._binop(other, "add")

    def __radd__(self, other):
        return _as_expr(other)._binop(self, "add")

    def __sub__(self, other):
        return self._binop(other, "sub")

    def __rsub__(self, other):
        return _as_expr(other)._binop(self, "sub")

    def __mul__(self, other):
        return self._binop(other, "mul")

    def __rmul__(self, other):
        return _as_expr(other)._binop(self, "mul")

    def __matmul__(self, other):
        return self._binop(other, "matmul")

    def __rmatmul__(self, other):
        return _as_expr(other)._binop(self, "matmul")

    def __or__(self, other):
        return self._binop(other, "inner")

    def __ror__(self, other):
        return _as_expr(other)._binop(self, "inner")

    def __pow__(self, power, modulo=None):
        if modulo is not None:
            raise ValueError("modulo is not supported for Expr exponentiation.")
        return Expr("pow", self, _as_expr(power))

    def __rpow__(self, other):
        # Enables ``2.0 ** expr``; operand order mirrors __pow__.
        return Expr("pow", _as_expr(other), self)

    def __neg__(self):
        return Expr("neg", self)

    @property
    def T(self):
        """Transpose of the last two axes."""
        return Expr("transpose", self)
66
+
67
+
68
@dataclass(frozen=True)
class FieldRef(Expr):
    """Symbolic handle to a trial, test or unknown field.

    ``role`` selects which field of the form context is used ("trial",
    "test" or "unknown"); ``name`` optionally selects one field of a
    multi-field context. Arithmetic on a bare reference acts on its value
    node, except that ``ref * other_ref`` and ``ref | other_ref`` build
    ``outer`` / ``inner`` nodes between the two references.
    """

    role: str
    name: str | None = None

    def __init__(self, role: str, name: str | None = None):
        # Frozen dataclass: write the fields through object.__setattr__.
        for attr, value in (("role", role), ("name", name)):
            object.__setattr__(self, attr, value)
        super().__init__("field", role, name)

    def _value(self):
        """Return the ``value`` node of this reference (same as ``.val``)."""
        return Expr("value", self)

    @property
    def val(self):
        return self._value()

    @property
    def grad(self):
        return Expr("grad", self)

    @property
    def sym_grad(self):
        return Expr("sym_grad", self)

    def __mul__(self, other):
        if not isinstance(other, FieldRef):
            return Expr("mul", self._value(), _as_expr(other))
        return Expr("outer", self, other)

    def __rmul__(self, other):
        if not isinstance(other, FieldRef):
            return Expr("mul", _as_expr(other), self._value())
        return Expr("outer", other, self)

    def __add__(self, other):
        return Expr("add", self._value(), _as_expr(other))

    def __radd__(self, other):
        return Expr("add", _as_expr(other), self._value())

    def __sub__(self, other):
        return Expr("sub", self._value(), _as_expr(other))

    def __rsub__(self, other):
        return Expr("sub", _as_expr(other), self._value())

    def __or__(self, other):
        if not isinstance(other, FieldRef):
            return Expr("sdot", self, _as_expr(other))
        return Expr("inner", self, other)

    def __ror__(self, other):
        if not isinstance(other, FieldRef):
            return Expr("sdot", _as_expr(other), self)
        return Expr("inner", other, self)
123
+
124
+
125
@dataclass(frozen=True)
class ParamRef(Expr):
    """Symbolic reference to the params passed into a compiled kernel.

    Attribute access (``p.E``) records a lazy ``getattr`` node. That node is
    resolved against the real params object when the expression is
    evaluated.
    """

    def __init__(self):
        super().__init__("param")

    def __getattr__(self, name: str):
        # __getattr__ only runs for attributes not found normally. Dunder
        # names are protocol probes (copy/pickle hooks such as __deepcopy__
        # or __setstate__, array-interop hooks such as __array__). Answering
        # them with an Expr makes those callers try to call it, so report
        # them as missing instead.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return Expr("getattr", self, name)
134
+
135
+
136
@jax.tree_util.register_pytree_node_class
class Params:
    """Simple params container with attribute access (JAX pytree).

    Values are kept in a dict and exposed as attributes (``p.E``) and items
    (``p["E"]``). Flattening orders leaves by sorted key, so the pytree
    structure is deterministic.
    """

    def __init__(self, **kwargs):
        self._data = dict(kwargs)

    def __getattr__(self, name: str):
        # Read _data via __dict__, never via ``self._data``. Instances created
        # without __init__ (copy.copy, copy.deepcopy, pickle) have no _data
        # yet, and ``self._data`` would re-enter __getattr__ forever.
        try:
            return self.__dict__["_data"][name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __getitem__(self, key: str):
        return self._data[key]

    def __repr__(self) -> str:
        items = ", ".join(f"{key}={value!r}" for key, value in self._data.items())
        return f"Params({items})"

    def tree_flatten(self):
        keys = tuple(sorted(self._data.keys()))
        values = tuple(self._data[k] for k in keys)
        return values, keys

    @classmethod
    def tree_unflatten(cls, keys, values):
        return cls(**dict(zip(keys, values)))
160
+
161
+
162
def trial_ref(name: str | None = "u") -> FieldRef:
    """Return a symbolic reference to the trial field ``name``."""
    return FieldRef("trial", name)


def test_ref(name: str | None = "v") -> FieldRef:
    """Return a symbolic reference to the test field ``name``."""
    return FieldRef("test", name)


def unknown_ref(name: str | None = "u") -> FieldRef:
    """Return a symbolic reference to the unknown (current solution) field ``name``."""
    return FieldRef("unknown", name)


def param_ref() -> ParamRef:
    """Return a symbolic reference to the kernel params."""
    return ParamRef()
180
+
181
+
182
def _as_expr(obj) -> Expr:
    """Return ``obj`` unchanged if it is an ``Expr``, else wrap it as a literal node."""
    return obj if isinstance(obj, Expr) else Expr("lit", obj)
186
+
187
+
188
def _eval_field(obj: Any, ctx, params):
    """Resolve ``obj`` to the concrete field object it refers to in ``ctx``.

    For a named :class:`FieldRef`, the lookup order is:

    1. ``ctx.fields[name]`` when that group object has an attribute for the
       requested role. For ``unknown``, a ``None`` value falls back to the
       group's ``trial``.
    2. ``ctx.trial_fields`` / ``ctx.test_fields`` / ``ctx.unknown_fields``
       keyed by name.
    3. ``ctx.fields[name]`` as a dict (keyed by role, or ``"field"``) or as a
       bare field object.

    Unnamed or unresolved refs then use the context defaults: ``ctx.trial``;
    ``ctx.test`` or ``ctx.v``; and ``ctx.unknown``, falling back to
    ``ctx.trial`` when absent or ``None``.

    Any other ``Expr`` is evaluated (without ``u_elem``). It is accepted
    if the result has an ``N`` attribute.

    Raises:
        ValueError: unknown role, or no test field is available.
        TypeError: ``obj`` does not resolve to a field.
    """
    if isinstance(obj, FieldRef):
        role = obj.role
        name = obj.name
        if name is not None:
            fields = getattr(ctx, "fields", None)
            has_group = fields is not None and name in fields
            # 1) Mixed-field group object exposing per-role attributes.
            if has_group:
                group = fields[name]
                if hasattr(group, "trial") and role == "trial":
                    return group.trial
                if hasattr(group, "test") and role == "test":
                    return group.test
                if hasattr(group, "unknown") and role == "unknown":
                    return group.unknown if group.unknown is not None else group.trial
            # 2) Role-specific name -> field maps on the context.
            if role in ("trial", "test", "unknown"):
                role_fields = getattr(ctx, f"{role}_fields", None)
                if role_fields is not None and name in role_fields:
                    return role_fields[name]
            # 3) ctx.fields entry as a role-keyed dict or a bare field.
            if has_group:
                group = fields[name]
                if isinstance(group, dict):
                    if role in group:
                        return group[role]
                    if "field" in group:
                        return group["field"]
                return group
        if role == "trial":
            return ctx.trial
        if role == "test":
            if hasattr(ctx, "test"):
                return ctx.test
            if hasattr(ctx, "v"):
                return ctx.v
            raise ValueError("Surface context is missing test field.")
        if role == "unknown":
            # Not getattr(ctx, "unknown", ctx.trial): that default is evaluated
            # eagerly and raises on contexts without ``trial`` even when
            # ``unknown`` exists. Fall back to trial only if no unknown is set.
            unknown = getattr(ctx, "unknown", None)
            return unknown if unknown is not None else ctx.trial
        raise ValueError(f"Unknown field role: {role}")
    if isinstance(obj, Expr):
        val = obj.eval(ctx, params)
        if hasattr(val, "N"):
            return val
    raise TypeError("Expected a field reference for this operator.")
234
+
235
+
236
def _eval_value(obj: Any, ctx, params, u_elem=None):
    """Evaluate an operand to a numeric value.

    Field references yield their shape-function values ``N`` (trial/test),
    or the interpolated solution (unknown). Other ``Expr`` nodes are
    evaluated recursively; anything else is returned unchanged.
    """
    if not isinstance(obj, Expr):
        return obj
    if not isinstance(obj, FieldRef):
        return obj.eval(ctx, params, u_elem=u_elem)
    field = _eval_field(obj, ctx, params)
    return _eval_unknown_value(obj, field, u_elem) if obj.role == "unknown" else field.N
245
+
246
+
247
def _extract_unknown_elem(field_ref: FieldRef, u_elem):
    """Return the element-local DOF vector for ``field_ref``.

    ``u_elem`` is either a single array, or a dict keyed by field name. A
    ref without a name uses the key ``"u"``.

    Raises:
        ValueError: ``u_elem`` is None, or the dict lacks the field's key.
    """
    if u_elem is None:
        raise ValueError("u_elem is required to evaluate unknown field value.")
    if isinstance(u_elem, dict):
        name = field_ref.name or "u"
        if name not in u_elem:
            raise ValueError(f"u_elem is missing key '{name}'.")
        return u_elem[name]
    return u_elem


def _eval_unknown_value(field_ref: FieldRef, field, u_elem):
    """Interpolate the unknown at quadrature points (``N @ u``).

    Returns shape ``(q,)`` for scalar fields and ``(q, value_dim)``
    otherwise. Vector DOFs are read node-major (``reshape(-1, value_dim)``).
    """
    u_local = _extract_unknown_elem(field_ref, u_elem)
    if u_local is None:
        # A dict entry may itself be None; fail clearly, as the grad path does.
        raise ValueError("u_elem is required to evaluate unknown field value.")
    # Accept lists/tuples/NumPy arrays; jnp.einsum and .reshape need an array.
    u_local = jnp.asarray(u_local)
    value_dim = int(getattr(field, "value_dim", 1))
    if value_dim == 1:
        return jnp.einsum("qa,a->q", field.N, u_local)
    u_nodes = u_local.reshape((-1, value_dim))
    return jnp.einsum("qa,ai->qi", field.N, u_nodes)


def _eval_unknown_grad(field_ref: FieldRef, field, u_elem):
    """Interpolate the unknown's gradient at quadrature points.

    Returns shape ``(q, dim)`` for scalar fields and ``(q, value_dim, dim)``
    otherwise, contracting ``field.gradN`` (``q, a, dim``) with the nodal
    DOFs.
    """
    u_local = _extract_unknown_elem(field_ref, u_elem)
    if u_local is None:
        raise ValueError("u_elem is required to evaluate unknown field gradient.")
    u_local = jnp.asarray(u_local)
    value_dim = int(getattr(field, "value_dim", 1))
    if value_dim == 1:
        return jnp.einsum("qaj,a->qj", field.gradN, u_local)
    u_nodes = u_local.reshape((-1, value_dim))
    return jnp.einsum("qaj,ai->qij", field.gradN, u_nodes)
276
+
277
+
278
def grad(field) -> Expr:
    """Gradient node: basis gradients for trial/test fields, interpolated gradient for unknowns."""
    node = _as_expr(field)
    return Expr("grad", node)


def sym_grad(field) -> Expr:
    """Symmetric-gradient node (B-matrix for a vector trial/test FormField)."""
    node = _as_expr(field)
    return Expr("sym_grad", node)


def dot(a, b) -> Expr:
    """Dot product, or vector-load helper when ``a`` is a field reference."""
    lhs, rhs = _as_expr(a), _as_expr(b)
    return Expr("dot", lhs, rhs)


def sdot(a, b) -> Expr:
    """Surface dot product, or vector-load helper when ``a`` is a field reference."""
    lhs, rhs = _as_expr(a), _as_expr(b)
    return Expr("sdot", lhs, rhs)


def ddot(a, b, c=None) -> Expr:
    """Double contraction ``a : b``, or ``a^T b c`` when ``c`` is given."""
    operands = [a, b] if c is None else [a, b, c]
    return Expr("ddot", *[_as_expr(x) for x in operands])


def inner(a, b) -> Expr:
    """Inner product contracting the last axis of ``a`` and ``b``."""
    lhs, rhs = _as_expr(a), _as_expr(b)
    return Expr("inner", lhs, rhs)
308
+
309
+
310
def action(v, s) -> Expr:
    """Test-function action ``v.val * s``, evaluating to ``(q, n_ldofs)``."""
    return Expr("action", *map(_as_expr, (v, s)))


def gaction(v, q) -> Expr:
    """Gradient action: contract ``v.grad`` with the flux ``q``, giving ``(q, n_ldofs)``."""
    return Expr("gaction", *map(_as_expr, (v, q)))


def normal() -> Expr:
    """Surface normal vector, read from the surface form context."""
    node = Expr("surface_normal")
    return node


def ds() -> Expr:
    """Surface quadrature measure ``w * detJ``."""
    node = Expr("surface_measure")
    return node


def dOmega() -> Expr:
    """Volume quadrature measure ``w * test.detJ``."""
    node = Expr("volume_measure")
    return node
333
+
334
+
335
def I(dim: int) -> Expr:
    """Identity matrix of size ``dim``."""
    node = Expr("eye", dim)
    return node


def det(a) -> Expr:
    """Determinant of a (possibly batched) square matrix."""
    operand = _as_expr(a)
    return Expr("det", operand)


def inv(a) -> Expr:
    """Inverse of a (possibly batched) square matrix."""
    operand = _as_expr(a)
    return Expr("inv", operand)


def transpose(a) -> Expr:
    """Swap the last two axes of ``a``."""
    operand = _as_expr(a)
    return Expr("transpose", operand)


def log(a) -> Expr:
    """Elementwise natural logarithm."""
    operand = _as_expr(a)
    return Expr("log", operand)


def transpose_last2(a) -> Expr:
    """Swap the last two axes (delegates to ``operators.transpose_last2``)."""
    operand = _as_expr(a)
    return Expr("transpose_last2", operand)


def einsum(subscripts: str, *args) -> Expr:
    """Einsum wrapper whose operands may be ``Expr`` nodes or plain values."""
    operands = tuple(map(_as_expr, args))
    return Expr("einsum", subscripts, *operands)
368
+
369
+
370
def _call_user(fn, *args, params):
    """Call a user weak-form callable, passing ``params`` only when it fits.

    Returns ``fn(*args, params)`` if ``fn``'s signature accepts the extra
    trailing argument, otherwise ``fn(*args)``.

    Arity is decided from the signature instead of calling and retrying on
    ``TypeError``. A ``TypeError`` raised *inside* ``fn`` therefore
    propagates unchanged, and ``fn`` is never invoked twice.
    """
    import inspect

    try:
        signature = inspect.signature(fn)
    except (TypeError, ValueError):
        # No introspectable signature (some builtins / extension callables):
        # fall back to trial and error.
        try:
            return fn(*args, params)
        except TypeError:
            return fn(*args)
    try:
        signature.bind(*args, params)
    except TypeError:
        return fn(*args)
    return fn(*args, params)
375
+
376
+
377
def compile_bilinear(fn):
    """Compile a bilinear weak form into a kernel ``(ctx, params) -> ndarray``.

    ``fn`` is either an ``Expr``, or a callable ``(u, v, params) -> Expr`` /
    ``(u, v) -> Expr`` receiving symbolic trial, test and param references.

    Raises:
        ValueError: the form is not an ``Expr`` or does not include ``dOmega()``.
    """
    if isinstance(fn, Expr):
        expr = fn
    else:
        # _call_user tries the (u, v, params) form and otherwise uses (u, v).
        expr = _call_user(fn, trial_ref(), test_ref(), params=param_ref())

    if not isinstance(expr, Expr):
        raise ValueError("Volume bilinear form must return an Expr; use dOmega() in the expression.")
    includes_measure = _expr_contains(expr, "volume_measure")
    if not includes_measure:
        raise ValueError("Volume bilinear form must include dOmega().")

    def _form(ctx, params):
        return expr.eval(ctx, params)

    _form._includes_measure = includes_measure
    return _form
399
+
400
+
401
def compile_linear(fn):
    """Compile a volume linear weak form into a kernel ``(ctx, params) -> ndarray``.

    ``fn`` is either an ``Expr``, or a callable ``(v, params) -> Expr`` /
    ``(v) -> Expr`` receiving symbolic test and param references.

    Raises:
        ValueError: the form is not an ``Expr`` or does not include ``dOmega()``.
    """
    if isinstance(fn, Expr):
        expr = fn
    else:
        # _call_user tries the (v, params) form and otherwise uses (v).
        expr = _call_user(fn, test_ref(), params=param_ref())

    if not isinstance(expr, Expr):
        raise ValueError("Volume linear form must return an Expr; use dOmega() in the expression.")
    includes_measure = _expr_contains(expr, "volume_measure")
    if not includes_measure:
        raise ValueError("Volume linear form must include dOmega().")

    def _form(ctx, params):
        return expr.eval(ctx, params)

    _form._includes_measure = includes_measure
    return _form
422
+
423
+
424
def _expr_contains(expr: Expr, op: str) -> bool:
    """Report whether any node in the tree rooted at ``expr`` has operator ``op``.

    A non-``Expr`` root, and non-``Expr`` operands (literals, strings), are
    ignored. The tree is walked with an explicit stack instead of recursion.
    """
    pending = [expr]
    while pending:
        node = pending.pop()
        if not isinstance(node, Expr):
            continue
        if node.op == op:
            return True
        pending.extend(node.args)
    return False
430
+
431
+
432
def compile_surface_linear(fn):
    """Compile a surface linear form into a kernel ``(ctx, params) -> ndarray``.

    ``fn`` is either an ``Expr``, or a callable ``(v, params) -> Expr`` /
    ``(v) -> Expr`` receiving symbolic test and param references.

    Raises:
        ValueError: calling ``fn`` raises ``TypeError`` (chained as the
            cause), ``fn`` does not return an ``Expr``, or the expression
            does not include ``ds()``.
    """
    message = "Surface linear form must return an Expr; use ds() in the expression."
    if isinstance(fn, Expr):
        expr = fn
    else:
        try:
            expr = _call_user(fn, test_ref(), params=param_ref())
        except TypeError as exc:
            # Keep the historical ValueError, but chain the real cause instead
            # of discarding it.
            raise ValueError(message) from exc

    if not isinstance(expr, Expr):
        raise ValueError(message)

    includes_measure = _expr_contains(expr, "surface_measure")
    if not includes_measure:
        raise ValueError("Surface linear form must include ds().")

    def _form(ctx, params):
        return expr.eval(ctx, params)

    _form._includes_measure = includes_measure  # type: ignore[attr-defined]
    return _form
460
+
461
+
462
class LinearForm:
    """Linear form wrapper compiled against a volume or surface backend."""

    def __init__(self, fn, *, kind: str):
        self.fn, self.kind = fn, kind

    @classmethod
    def volume(cls, fn):
        """Build a linear form integrated over the volume."""
        return cls(fn, kind="volume")

    @classmethod
    def surface(cls, fn):
        """Build a linear form integrated over a surface."""
        return cls(fn, kind="surface")

    def get_compiled(self, *, ctx_kind: str | None = None):
        """Compile for ``ctx_kind`` (defaults to this form's own kind)."""
        kind = ctx_kind if ctx_kind is not None else self.kind
        if kind not in ("volume", "surface"):
            raise ValueError(f"Unknown linear form kind: {kind}")
        if kind == "surface":
            return compile_surface_linear(self.fn)
        return compile_linear(self.fn)
484
+
485
+
486
class BilinearForm:
    """Wraps a bilinear weak form; only volume integration is supported so far."""

    def __init__(self, fn):
        self.fn = fn

    @classmethod
    def volume(cls, fn):
        """Build a bilinear form integrated over the volume."""
        form = cls(fn)
        return form

    def get_compiled(self):
        """Compile the wrapped form with ``compile_bilinear``."""
        fn = self.fn
        return compile_bilinear(fn)
498
+
499
+
500
class ResidualForm:
    """Wraps a residual weak form; only volume integration is supported so far."""

    def __init__(self, fn):
        self.fn = fn

    @classmethod
    def volume(cls, fn):
        """Build a residual form integrated over the volume."""
        form = cls(fn)
        return form

    def get_compiled(self):
        """Compile the wrapped form with ``compile_residual``."""
        fn = self.fn
        return compile_residual(fn)
512
+
513
+
514
def compile_residual(fn):
    """Compile a residual weak form into a kernel ``(ctx, u_elem, params) -> ndarray``.

    ``fn`` is either an ``Expr``, or a callable ``(v, u, params) -> Expr`` /
    ``(v, u) -> Expr`` receiving symbolic test, unknown and param
    references. ``u_elem`` supplies the element DOFs that unknown
    references read.

    Raises:
        ValueError: the form is not an ``Expr`` or does not include ``dOmega()``.
    """
    if isinstance(fn, Expr):
        expr = fn
    else:
        # _call_user tries the (v, u, params) form and otherwise uses (v, u).
        expr = _call_user(fn, test_ref(), unknown_ref(), params=param_ref())

    if not isinstance(expr, Expr):
        raise ValueError("Volume residual form must return an Expr; use dOmega() in the expression.")
    includes_measure = _expr_contains(expr, "volume_measure")
    if not includes_measure:
        raise ValueError("Volume residual form must include dOmega().")

    def _form(ctx, u_elem, params):
        return expr.eval(ctx, params, u_elem=u_elem)

    _form._includes_measure = includes_measure
    return _form
536
+
537
+
538
def compile_mixed_residual(residuals: dict[str, Callable]):
    """Compile mixed residuals keyed by field name into one kernel.

    Each value is an ``Expr``, or a callable ``(v, u, params) -> Expr`` /
    ``(v, u) -> Expr`` that receives test and unknown references bound to
    that field's name.

    Returns:
        ``_form(ctx, u_elem, params) -> dict[name, ndarray]``. Its
        ``_includes_measure`` attribute maps each field name to a bool.

    Raises:
        ValueError: a residual does not include ``dOmega()``.
    """
    compiled: dict[str, Expr] = {}
    includes_measure: dict[str, bool] = {}
    for name, fn in residuals.items():
        if isinstance(fn, Expr):
            expr = fn
        else:
            # _call_user tries the (v, u, params) form and otherwise uses (v, u).
            expr = _call_user(fn, test_ref(name), unknown_ref(name), params=param_ref())
        compiled[name] = _as_expr(expr)
        includes_measure[name] = _expr_contains(compiled[name], "volume_measure")
        if not includes_measure[name]:
            raise ValueError(f"Mixed residual '{name}' must include dOmega().")

    def _form(ctx, u_elem, params):
        return {
            field_name: form.eval(ctx, params, u_elem=u_elem)
            for field_name, form in compiled.items()
        }

    _form._includes_measure = includes_measure
    return _form
563
+
564
+
565
class MixedWeakForm:
    """Holds per-field residual forms (keyed by field name) for a mixed problem."""

    def __init__(self, *, residuals: dict[str, Callable]):
        self.residuals = residuals

    def get_compiled(self):
        """Compile the residuals with ``compile_mixed_residual``; they must be non-empty."""
        if self.residuals:
            return compile_mixed_residual(self.residuals)
        raise ValueError("residuals are not defined")
575
+
576
+
577
def _eval_expr(expr: Expr, ctx, params, u_elem=None):
    """Evaluate an expression-tree node against a form context.

    Dispatches on ``expr.op`` and evaluates operands recursively through
    ``_eval_value`` / ``_eval_field``.

    Args:
        expr: node to evaluate.
        ctx: form context. Depending on the ops used, it must provide field
            objects (``trial``/``test``/``unknown`` or the ``*_fields`` /
            ``fields`` maps), quadrature weights ``w``, and, for surface
            forms, ``normal`` and ``detJ``.
        params: object returned for ``param_ref()`` nodes.
        u_elem: element-local unknown DOFs (an array, or a dict keyed by
            field name). Only needed when the tree reads an unknown field.

    Raises:
        ValueError: unknown op, required data missing from ``ctx`` or
            ``u_elem``, or an operand shape rejected by an op.
        TypeError: ``outer`` applied to non-FieldRef operands, or a field
            operator applied to something that is not a field.
    """
    op = expr.op
    args = expr.args

    def ev(node):
        # Evaluate an operand with the same ctx / params / u_elem.
        return _eval_value(node, ctx, params, u_elem=u_elem)

    if op == "lit":
        return args[0]
    if op == "param":
        return params
    if op == "getattr":
        base = ev(args[0])
        attr = args[1]
        if isinstance(base, dict):
            return base[attr]
        return getattr(base, attr)
    if op == "field":
        # Resolve through _eval_field, so that evaluating a FieldRef directly
        # follows the same lookup rules as using it as an operand (mixed
        # groups, ctx.v fallback for surface tests, ...).
        ref = expr if isinstance(expr, FieldRef) else FieldRef(*args)
        return _eval_field(ref, ctx, params)
    if op == "value":
        ref = args[0]
        field = _eval_field(ref, ctx, params)
        if isinstance(ref, FieldRef) and ref.role == "unknown":
            return _eval_unknown_value(ref, field, u_elem)
        return field.N
    if op == "grad":
        ref = args[0]
        field = _eval_field(ref, ctx, params)
        if isinstance(ref, FieldRef) and ref.role == "unknown":
            return _eval_unknown_grad(ref, field, u_elem)
        return field.gradN
    if op == "pow":
        base = ev(args[0])
        exponent = ev(args[1])
        return base**exponent
    if op == "eye":
        return jnp.eye(int(args[0]))
    if op == "det":
        return jnp.linalg.det(ev(args[0]))
    if op == "inv":
        return jnp.linalg.inv(ev(args[0]))
    if op == "transpose":
        return jnp.swapaxes(ev(args[0]), -1, -2)
    if op == "log":
        return jnp.log(ev(args[0]))
    if op == "surface_normal":
        normal = getattr(ctx, "normal", None)
        if normal is None:
            raise ValueError("surface normal is not available in context")
        return normal
    if op == "surface_measure":
        if not hasattr(ctx, "w") or not hasattr(ctx, "detJ"):
            raise ValueError("surface measure requires surface context with w and detJ.")
        return ctx.w * ctx.detJ
    if op == "volume_measure":
        if not hasattr(ctx, "w") or not hasattr(ctx, "test"):
            raise ValueError("volume measure requires FormContext with w and test.detJ.")
        return ctx.w * ctx.test.detJ
    if op == "sym_grad":
        ref = args[0]
        field = _eval_field(ref, ctx, params)
        if isinstance(ref, FieldRef) and ref.role == "unknown":
            if u_elem is None:
                raise ValueError("u_elem is required to evaluate unknown sym_grad.")
            return _ops.sym_grad_u(field, _extract_unknown_elem(ref, u_elem))
        return _ops.sym_grad(field)
    if op == "outer":
        a, b = args
        if not isinstance(a, FieldRef) or not isinstance(b, FieldRef):
            raise TypeError("outer expects FieldRef operands.")
        # Exactly one operand must be the test field. Checking only
        # ``a.role == b.role`` let e.g. trial*unknown through, silently
        # treating one of them as the test function.
        if (a.role == "test") == (b.role == "test"):
            raise ValueError("outer requires one trial and one test field.")
        test, trial = (a, b) if a.role == "test" else (b, a)
        v_field = _eval_field(test, ctx, params)
        u_field = _eval_field(trial, ctx, params)
        if getattr(v_field, "value_dim", 1) != 1 or getattr(u_field, "value_dim", 1) != 1:
            raise ValueError("u*v is only defined for scalar fields; use dot/inner for vectors.")
        return jnp.einsum("qi,qj->qij", v_field.N, u_field.N)
    if op == "add":
        return ev(args[0]) + ev(args[1])
    if op == "sub":
        return ev(args[0]) - ev(args[1])
    if op == "mul":
        a = ev(args[0])
        b = ev(args[1])
        if hasattr(a, "ndim") and hasattr(b, "ndim"):
            # Broadcast a per-quadrature-point vector (q,) against a (q, ...)
            # operand by appending singleton axes, so the shared axis is q.
            if a.ndim == 1 and b.ndim >= 2 and a.shape[0] == b.shape[0]:
                a = a.reshape((a.shape[0],) + (1,) * (b.ndim - 1))
            elif b.ndim == 1 and a.ndim >= 2 and b.shape[0] == a.shape[0]:
                b = b.reshape((b.shape[0],) + (1,) * (a.ndim - 1))
        return a * b
    if op == "matmul":
        a = ev(args[0])
        b = ev(args[1])
        if (
            hasattr(a, "ndim")
            and hasattr(b, "ndim")
            and a.ndim == 3
            and b.ndim == 3
            and a.shape[0] == b.shape[0]
            and a.shape[-1] == b.shape[-1]
        ):
            # Batched (q, i, a) x (q, j, a) -> (q, i, j) over the shared last axis.
            return jnp.einsum("qia,qja->qij", a, b)
        return a @ b
    if op == "neg":
        return -ev(args[0])
    if op in ("dot", "sdot"):
        # dot and sdot evaluate identically; they differ only in intent.
        if isinstance(args[0], FieldRef):
            return _ops.dot(_eval_field(args[0], ctx, params), ev(args[1]))
        a = ev(args[0])
        b = ev(args[1])
        if hasattr(a, "ndim") and hasattr(b, "ndim") and a.ndim == 3 and b.ndim == 3 and a.shape[-1] == b.shape[-1]:
            return jnp.einsum("qia,qja->qij", a, b)
        return jnp.matmul(a, b)
    if op == "ddot":
        if len(args) == 2:
            a = ev(args[0])
            b = ev(args[1])
            if (
                hasattr(a, "ndim")
                and hasattr(b, "ndim")
                and a.ndim == 3
                and b.ndim == 3
                and a.shape[0] == b.shape[0]
                and a.shape[1] == b.shape[1]
            ):
                return jnp.einsum("qik,qim->qkm", a, b)
            return _ops.ddot(a, b)
        return _ops.ddot(ev(args[0]), ev(args[1]), ev(args[2]))
    if op == "inner":
        return jnp.einsum("...i,...i->...", ev(args[0]), ev(args[1]))
    if op == "action":
        if isinstance(args[1], FieldRef):
            raise ValueError("action expects a scalar expression; use u.val for unknowns.")
        v_field = _eval_field(args[0], ctx, params)
        s = ev(args[1])
        value_dim = int(getattr(v_field, "value_dim", 1))
        if value_dim == 1:
            if v_field.N.ndim != 2:
                raise ValueError("action expects scalar test field with N shape (q, ndofs).")
            if hasattr(s, "ndim") and s.ndim not in (0, 1):
                raise ValueError("action expects scalar s with shape (q,) or scalar.")
            if hasattr(s, "ndim") and s.ndim == 1:
                # s holds one value per quadrature point: align it with N's
                # leading q axis. A plain ``N * s`` broadcasts s against the
                # trailing dof axis instead, which is silently wrong whenever
                # q == ndofs.
                s = s[:, None]
            return v_field.N * s
        if hasattr(s, "ndim") and s.ndim not in (1, 2):
            raise ValueError("action expects vector s with shape (q, dim) or (dim,).")
        return _ops.dot(v_field, s)
    if op == "gaction":
        v_field = _eval_field(args[0], ctx, params)
        q = ev(args[1])
        if v_field.gradN.ndim != 3:
            raise ValueError("gaction expects test gradient with shape (q, ndofs, dim).")
        if not hasattr(q, "ndim"):
            raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
        if q.ndim == 2:
            return jnp.einsum("qaj,qj->qa", v_field.gradN, q)
        if q.ndim == 3:
            if int(getattr(v_field, "value_dim", 1)) == 1:
                raise ValueError("gaction tensor flux requires vector test field.")
            # (q, node, component) flattened node-major into (q, n_ldofs).
            return jnp.einsum("qij,qaj->qai", q, v_field.gradN).reshape(q.shape[0], -1)
        raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
    if op == "transpose_last2":
        return _ops.transpose_last2(ev(args[0]))
    if op == "einsum":
        subscripts = args[0]
        operands = [ev(arg) for arg in args[1:]]
        return jnp.einsum(subscripts, *operands)

    raise ValueError(f"Unknown Expr op: {op}")
787
+
788
+
789
__all__ = [
    "Expr",
    "FieldRef",
    "ParamRef",
    "trial_ref",
    "test_ref",
    "unknown_ref",
    "param_ref",
    "Params",
    "MixedWeakForm",
    "ResidualForm",
    "compile_bilinear",
    "compile_linear",
    "compile_residual",
    "compile_mixed_residual",
    "grad",
    "sym_grad",
    "dot",
    "ddot",
    "inner",
    "action",
    "gaction",
    "I",
    "det",
    "inv",
    "transpose",
    "log",
    "transpose_last2",
    "einsum",
    # Public helpers defined in this module that were missing from the
    # star-export list. Every volume/surface form must contain dOmega()/ds().
    "sdot",
    "normal",
    "ds",
    "dOmega",
    "compile_surface_linear",
    "LinearForm",
    "BilinearForm",
]