fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +69 -13
- fluxfem/core/__init__.py +140 -53
- fluxfem/core/assembly.py +691 -97
- fluxfem/core/basis.py +75 -54
- fluxfem/core/context_types.py +36 -12
- fluxfem/core/dtypes.py +9 -1
- fluxfem/core/forms.py +10 -0
- fluxfem/core/mixed_assembly.py +263 -0
- fluxfem/core/mixed_space.py +382 -0
- fluxfem/core/mixed_weakform.py +97 -0
- fluxfem/core/solver.py +2 -0
- fluxfem/core/space.py +315 -30
- fluxfem/core/weakform.py +821 -42
- fluxfem/helpers_wf.py +49 -0
- fluxfem/mesh/__init__.py +54 -2
- fluxfem/mesh/base.py +318 -9
- fluxfem/mesh/contact.py +841 -0
- fluxfem/mesh/dtypes.py +12 -0
- fluxfem/mesh/hex.py +17 -16
- fluxfem/mesh/io.py +9 -6
- fluxfem/mesh/mortar.py +3970 -0
- fluxfem/mesh/supermesh.py +318 -0
- fluxfem/mesh/surface.py +104 -26
- fluxfem/mesh/tet.py +16 -7
- fluxfem/physics/diffusion.py +3 -0
- fluxfem/physics/elasticity/hyperelastic.py +35 -3
- fluxfem/physics/elasticity/linear.py +22 -4
- fluxfem/physics/elasticity/stress.py +9 -5
- fluxfem/physics/operators.py +12 -5
- fluxfem/physics/postprocess.py +29 -3
- fluxfem/solver/__init__.py +47 -2
- fluxfem/solver/bc.py +38 -2
- fluxfem/solver/block_matrix.py +284 -0
- fluxfem/solver/block_system.py +477 -0
- fluxfem/solver/cg.py +150 -55
- fluxfem/solver/dirichlet.py +358 -5
- fluxfem/solver/history.py +15 -3
- fluxfem/solver/newton.py +260 -70
- fluxfem/solver/petsc.py +445 -0
- fluxfem/solver/preconditioner.py +109 -0
- fluxfem/solver/result.py +18 -0
- fluxfem/solver/solve_runner.py +208 -23
- fluxfem/solver/solver.py +35 -12
- fluxfem/solver/sparse.py +149 -15
- fluxfem/tools/jit.py +19 -7
- fluxfem/tools/timer.py +14 -12
- fluxfem/tools/visualizer.py +16 -4
- fluxfem-0.2.1.dist-info/METADATA +314 -0
- fluxfem-0.2.1.dist-info/RECORD +59 -0
- fluxfem-0.1.4.dist-info/METADATA +0 -127
- fluxfem-0.1.4.dist-info/RECORD +0 -48
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/core/weakform.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
|
-
from typing import Any, Callable, Iterator, Literal, get_args
|
|
4
|
+
from typing import Any, Callable, Iterator, Literal, Mapping, TypeAlias, cast, get_args
|
|
5
5
|
import inspect
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from functools import update_wrapper
|
|
6
8
|
|
|
7
9
|
import numpy as np
|
|
8
10
|
|
|
@@ -10,7 +12,7 @@ import jax.numpy as jnp
|
|
|
10
12
|
import jax
|
|
11
13
|
|
|
12
14
|
from ..physics import operators as _ops
|
|
13
|
-
from .context_types import FormFieldLike, ParamsLike, SurfaceContext, UElement, VolumeContext
|
|
15
|
+
from .context_types import ArrayLike, FormFieldLike, ParamsLike, SurfaceContext, UElement, VolumeContext, WeakFormContext
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
OpName = Literal[
|
|
@@ -45,6 +47,7 @@ OpName = Literal[
|
|
|
45
47
|
"einsum",
|
|
46
48
|
]
|
|
47
49
|
|
|
50
|
+
|
|
48
51
|
# Use OpName as the single source of truth for valid ops.
|
|
49
52
|
_OP_NAMES: frozenset[str] = frozenset(get_args(OpName))
|
|
50
53
|
|
|
@@ -396,6 +399,28 @@ class Params:
|
|
|
396
399
|
return cls(**dict(zip(keys, values)))
|
|
397
400
|
|
|
398
401
|
|
|
402
|
+
class _ZeroField:
|
|
403
|
+
"""Field-like object that evaluates to zeros with the same shape."""
|
|
404
|
+
|
|
405
|
+
def __init__(self, base):
|
|
406
|
+
self.N = jnp.zeros_like(base.N)
|
|
407
|
+
self.gradN = None if getattr(base, "gradN", None) is None else jnp.zeros_like(base.gradN)
|
|
408
|
+
self.detJ = getattr(base, "detJ", None)
|
|
409
|
+
self.value_dim = int(getattr(base, "value_dim", 1))
|
|
410
|
+
self.basis = getattr(base, "basis", None)
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
class _ZeroFieldNp:
|
|
414
|
+
"""Numpy variant of a zero-valued field (for numpy backend evaluation)."""
|
|
415
|
+
|
|
416
|
+
def __init__(self, base):
|
|
417
|
+
self.N = np.zeros_like(base.N)
|
|
418
|
+
self.gradN = None if getattr(base, "gradN", None) is None else np.zeros_like(base.gradN)
|
|
419
|
+
self.detJ = getattr(base, "detJ", None)
|
|
420
|
+
self.value_dim = int(getattr(base, "value_dim", 1))
|
|
421
|
+
self.basis = getattr(base, "basis", None)
|
|
422
|
+
|
|
423
|
+
|
|
399
424
|
def trial_ref(name: str | None = "u") -> FieldRef:
|
|
400
425
|
"""Create a symbolic trial field reference."""
|
|
401
426
|
return FieldRef(role="trial", name=name)
|
|
@@ -416,12 +441,105 @@ def param_ref() -> ParamRef:
|
|
|
416
441
|
return ParamRef()
|
|
417
442
|
|
|
418
443
|
|
|
444
|
+
def zero_ref(name: str) -> FieldRef:
|
|
445
|
+
"""Create a zero-valued field reference (shape derived from context)."""
|
|
446
|
+
return FieldRef("zero", name)
|
|
447
|
+
|
|
448
|
+
|
|
419
449
|
def _eval_field(
|
|
420
|
-
obj:
|
|
421
|
-
ctx:
|
|
450
|
+
obj: FieldRef,
|
|
451
|
+
ctx: WeakFormContext,
|
|
422
452
|
params: ParamsLike,
|
|
423
453
|
) -> FormFieldLike:
|
|
424
454
|
if isinstance(obj, FieldRef):
|
|
455
|
+
if obj.role == "zero":
|
|
456
|
+
if obj.name is None:
|
|
457
|
+
raise ValueError("zero_ref requires a named field.")
|
|
458
|
+
base = None
|
|
459
|
+
test_fields = getattr(ctx, "test_fields", None)
|
|
460
|
+
if test_fields is not None and obj.name in test_fields:
|
|
461
|
+
base = test_fields[obj.name]
|
|
462
|
+
trial_fields = getattr(ctx, "trial_fields", None)
|
|
463
|
+
if base is None and trial_fields is not None and obj.name in trial_fields:
|
|
464
|
+
base = trial_fields[obj.name]
|
|
465
|
+
fields = getattr(ctx, "fields", None)
|
|
466
|
+
if base is None and fields is not None and obj.name in fields:
|
|
467
|
+
group = fields[obj.name]
|
|
468
|
+
if hasattr(group, "test"):
|
|
469
|
+
base = group.test
|
|
470
|
+
elif hasattr(group, "trial"):
|
|
471
|
+
base = group.trial
|
|
472
|
+
if base is None:
|
|
473
|
+
raise ValueError(f"zero_ref could not resolve field '{obj.name}'.")
|
|
474
|
+
return _ZeroField(base)
|
|
475
|
+
if obj.name is not None:
|
|
476
|
+
mixed_fields = getattr(ctx, "fields", None)
|
|
477
|
+
if mixed_fields is not None and obj.name in mixed_fields:
|
|
478
|
+
group = mixed_fields[obj.name]
|
|
479
|
+
if hasattr(group, "trial") and obj.role == "trial":
|
|
480
|
+
return group.trial
|
|
481
|
+
if hasattr(group, "test") and obj.role == "test":
|
|
482
|
+
return group.test
|
|
483
|
+
if hasattr(group, "unknown") and obj.role == "unknown":
|
|
484
|
+
return group.unknown if group.unknown is not None else group.trial
|
|
485
|
+
trial_fields = getattr(ctx, "trial_fields", None)
|
|
486
|
+
if obj.role == "trial" and trial_fields is not None and obj.name in trial_fields:
|
|
487
|
+
return trial_fields[obj.name]
|
|
488
|
+
test_fields = getattr(ctx, "test_fields", None)
|
|
489
|
+
if obj.role == "test" and test_fields is not None and obj.name in test_fields:
|
|
490
|
+
return test_fields[obj.name]
|
|
491
|
+
unknown_fields = getattr(ctx, "unknown_fields", None)
|
|
492
|
+
if obj.role == "unknown" and unknown_fields is not None and obj.name in unknown_fields:
|
|
493
|
+
return unknown_fields[obj.name]
|
|
494
|
+
fields = getattr(ctx, "fields", None)
|
|
495
|
+
if fields is not None and obj.name in fields:
|
|
496
|
+
group = fields[obj.name]
|
|
497
|
+
if isinstance(group, dict):
|
|
498
|
+
if obj.role in group:
|
|
499
|
+
return group[obj.role]
|
|
500
|
+
if "field" in group:
|
|
501
|
+
return group["field"]
|
|
502
|
+
return group
|
|
503
|
+
if obj.role == "trial":
|
|
504
|
+
return ctx.trial
|
|
505
|
+
if obj.role == "test":
|
|
506
|
+
if hasattr(ctx, "test"):
|
|
507
|
+
return ctx.test
|
|
508
|
+
if hasattr(ctx, "v"):
|
|
509
|
+
return ctx.v
|
|
510
|
+
raise ValueError("Surface context is missing test field.")
|
|
511
|
+
if obj.role == "unknown":
|
|
512
|
+
return getattr(ctx, "unknown", ctx.trial)
|
|
513
|
+
raise ValueError(f"Unknown field role: {obj.role}")
|
|
514
|
+
raise TypeError("Expected a field reference for this operator.")
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
def _eval_field_np(
|
|
518
|
+
obj: FieldRef,
|
|
519
|
+
ctx: WeakFormContext,
|
|
520
|
+
params: ParamsLike,
|
|
521
|
+
) -> FormFieldLike:
|
|
522
|
+
if isinstance(obj, FieldRef):
|
|
523
|
+
if obj.role == "zero":
|
|
524
|
+
if obj.name is None:
|
|
525
|
+
raise ValueError("zero_ref requires a named field.")
|
|
526
|
+
base = None
|
|
527
|
+
test_fields = getattr(ctx, "test_fields", None)
|
|
528
|
+
if test_fields is not None and obj.name in test_fields:
|
|
529
|
+
base = test_fields[obj.name]
|
|
530
|
+
trial_fields = getattr(ctx, "trial_fields", None)
|
|
531
|
+
if base is None and trial_fields is not None and obj.name in trial_fields:
|
|
532
|
+
base = trial_fields[obj.name]
|
|
533
|
+
fields = getattr(ctx, "fields", None)
|
|
534
|
+
if base is None and fields is not None and obj.name in fields:
|
|
535
|
+
group = fields[obj.name]
|
|
536
|
+
if hasattr(group, "test"):
|
|
537
|
+
base = group.test
|
|
538
|
+
elif hasattr(group, "trial"):
|
|
539
|
+
base = group.trial
|
|
540
|
+
if base is None:
|
|
541
|
+
raise ValueError(f"zero_ref could not resolve field '{obj.name}'.")
|
|
542
|
+
return _ZeroFieldNp(base)
|
|
425
543
|
if obj.name is not None:
|
|
426
544
|
mixed_fields = getattr(ctx, "fields", None)
|
|
427
545
|
if mixed_fields is not None and obj.name in mixed_fields:
|
|
@@ -432,15 +550,15 @@ def _eval_field(
|
|
|
432
550
|
return group.test
|
|
433
551
|
if hasattr(group, "unknown") and obj.role == "unknown":
|
|
434
552
|
return group.unknown if group.unknown is not None else group.trial
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
553
|
+
trial_fields = getattr(ctx, "trial_fields", None)
|
|
554
|
+
if obj.role == "trial" and trial_fields is not None and obj.name in trial_fields:
|
|
555
|
+
return trial_fields[obj.name]
|
|
556
|
+
test_fields = getattr(ctx, "test_fields", None)
|
|
557
|
+
if obj.role == "test" and test_fields is not None and obj.name in test_fields:
|
|
558
|
+
return test_fields[obj.name]
|
|
559
|
+
unknown_fields = getattr(ctx, "unknown_fields", None)
|
|
560
|
+
if obj.role == "unknown" and unknown_fields is not None and obj.name in unknown_fields:
|
|
561
|
+
return unknown_fields[obj.name]
|
|
444
562
|
fields = getattr(ctx, "fields", None)
|
|
445
563
|
if fields is not None and obj.name in fields:
|
|
446
564
|
group = fields[obj.name]
|
|
@@ -477,7 +595,7 @@ def _eval_field(
|
|
|
477
595
|
# return obj
|
|
478
596
|
|
|
479
597
|
|
|
480
|
-
def _extract_unknown_elem(field_ref: FieldRef, u_elem: UElement):
|
|
598
|
+
def _extract_unknown_elem(field_ref: FieldRef, u_elem: UElement) -> ArrayLike:
|
|
481
599
|
if u_elem is None:
|
|
482
600
|
raise ValueError("u_elem is required to evaluate unknown field value.")
|
|
483
601
|
if isinstance(u_elem, dict):
|
|
@@ -489,8 +607,9 @@ def _extract_unknown_elem(field_ref: FieldRef, u_elem: UElement):
|
|
|
489
607
|
|
|
490
608
|
|
|
491
609
|
def _basis_outer(test: FieldRef, trial: FieldRef, ctx, params):
|
|
492
|
-
|
|
493
|
-
|
|
610
|
+
ctx_w = cast(WeakFormContext, ctx)
|
|
611
|
+
v_field = _eval_field(test, ctx_w, params)
|
|
612
|
+
u_field = _eval_field(trial, ctx_w, params)
|
|
494
613
|
if getattr(v_field, "value_dim", 1) != 1 or getattr(u_field, "value_dim", 1) != 1:
|
|
495
614
|
raise ValueError(
|
|
496
615
|
"inner/outer is only defined for scalar fields; use dot/action/einsum for vector/tensor cases."
|
|
@@ -498,6 +617,17 @@ def _basis_outer(test: FieldRef, trial: FieldRef, ctx, params):
|
|
|
498
617
|
return jnp.einsum("qi,qj->qij", v_field.N, u_field.N)
|
|
499
618
|
|
|
500
619
|
|
|
620
|
+
def _basis_outer_np(test: FieldRef, trial: FieldRef, ctx, params):
|
|
621
|
+
ctx_w = cast(WeakFormContext, ctx)
|
|
622
|
+
v_field = _eval_field_np(test, ctx_w, params)
|
|
623
|
+
u_field = _eval_field_np(trial, ctx_w, params)
|
|
624
|
+
if getattr(v_field, "value_dim", 1) != 1 or getattr(u_field, "value_dim", 1) != 1:
|
|
625
|
+
raise ValueError(
|
|
626
|
+
"inner/outer is only defined for scalar fields; use dot/action/einsum for vector/tensor cases."
|
|
627
|
+
)
|
|
628
|
+
return np.einsum("qi,qj->qij", v_field.N, u_field.N)
|
|
629
|
+
|
|
630
|
+
|
|
501
631
|
def _eval_unknown_value(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
|
|
502
632
|
u_local = _extract_unknown_elem(field_ref, u_elem)
|
|
503
633
|
value_dim = int(getattr(field, "value_dim", 1))
|
|
@@ -507,6 +637,21 @@ def _eval_unknown_value(field_ref: FieldRef, field: FormFieldLike, u_elem: UElem
|
|
|
507
637
|
return jnp.einsum("qa,ai->qi", field.N, u_nodes)
|
|
508
638
|
|
|
509
639
|
|
|
640
|
+
def _eval_unknown_value_np(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
|
|
641
|
+
u_local = _extract_unknown_elem(field_ref, u_elem)
|
|
642
|
+
value_dim = int(getattr(field, "value_dim", 1))
|
|
643
|
+
u_arr = np.asarray(u_local)
|
|
644
|
+
if value_dim == 1:
|
|
645
|
+
if u_arr.ndim == 2:
|
|
646
|
+
return np.einsum("qa,ab->qb", field.N, u_arr)
|
|
647
|
+
return np.einsum("qa,a->q", field.N, u_arr)
|
|
648
|
+
if u_arr.ndim == 2:
|
|
649
|
+
u_nodes = u_arr.reshape((-1, value_dim, u_arr.shape[1]))
|
|
650
|
+
return np.einsum("qa,aib->qib", field.N, u_nodes)
|
|
651
|
+
u_nodes = u_arr.reshape((-1, value_dim))
|
|
652
|
+
return np.einsum("qa,ai->qi", field.N, u_nodes)
|
|
653
|
+
|
|
654
|
+
|
|
510
655
|
def _eval_unknown_grad(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
|
|
511
656
|
u_local = _extract_unknown_elem(field_ref, u_elem)
|
|
512
657
|
if u_local is None:
|
|
@@ -518,6 +663,88 @@ def _eval_unknown_grad(field_ref: FieldRef, field: FormFieldLike, u_elem: UEleme
|
|
|
518
663
|
return jnp.einsum("qaj,ai->qij", field.gradN, u_nodes)
|
|
519
664
|
|
|
520
665
|
|
|
666
|
+
def _eval_unknown_grad_np(field_ref: FieldRef, field: FormFieldLike, u_elem: UElement):
|
|
667
|
+
u_local = _extract_unknown_elem(field_ref, u_elem)
|
|
668
|
+
if u_local is None:
|
|
669
|
+
raise ValueError("u_elem is required to evaluate unknown field gradient.")
|
|
670
|
+
value_dim = int(getattr(field, "value_dim", 1))
|
|
671
|
+
u_arr = np.asarray(u_local)
|
|
672
|
+
if value_dim == 1:
|
|
673
|
+
if u_arr.ndim == 2:
|
|
674
|
+
return np.einsum("qaj,ab->qjb", field.gradN, u_arr)
|
|
675
|
+
return np.einsum("qaj,a->qj", field.gradN, u_arr)
|
|
676
|
+
if u_arr.ndim == 2:
|
|
677
|
+
u_nodes = u_arr.reshape((-1, value_dim, u_arr.shape[1]))
|
|
678
|
+
return np.einsum("qaj,aib->qijb", field.gradN, u_nodes)
|
|
679
|
+
u_nodes = u_arr.reshape((-1, value_dim))
|
|
680
|
+
return np.einsum("qaj,ai->qij", field.gradN, u_nodes)
|
|
681
|
+
|
|
682
|
+
|
|
683
|
+
def _vector_load_form_np(field: FormFieldLike, load_vec: ArrayLike) -> np.ndarray:
|
|
684
|
+
lv = np.asarray(load_vec)
|
|
685
|
+
if lv.ndim == 1:
|
|
686
|
+
lv = lv[None, :]
|
|
687
|
+
elif lv.ndim not in (2, 3):
|
|
688
|
+
raise ValueError("load_vec must be shape (dim,), (n_q, dim), or (n_q, dim, batch)")
|
|
689
|
+
if lv.shape[0] == 1:
|
|
690
|
+
lv = np.broadcast_to(lv, (field.N.shape[0], lv.shape[1]))
|
|
691
|
+
elif lv.shape[0] != field.N.shape[0]:
|
|
692
|
+
raise ValueError("load_vec must be shape (dim,) or (n_q, dim)")
|
|
693
|
+
if lv.ndim == 3:
|
|
694
|
+
load = field.N[..., None, None] * lv[:, None, :, :]
|
|
695
|
+
return load.reshape(load.shape[0], -1, load.shape[-1])
|
|
696
|
+
load = field.N[..., None] * lv[:, None, :]
|
|
697
|
+
return load.reshape(load.shape[0], -1)
|
|
698
|
+
|
|
699
|
+
|
|
700
|
+
def _sym_grad_np(field) -> np.ndarray:
|
|
701
|
+
gradN = np.asarray(field.gradN)
|
|
702
|
+
dofs = int(getattr(field.basis, "dofs_per_node", 3))
|
|
703
|
+
n_q, n_nodes, _ = gradN.shape
|
|
704
|
+
n_dofs = dofs * n_nodes
|
|
705
|
+
B = np.zeros((n_q, 6, n_dofs), dtype=gradN.dtype)
|
|
706
|
+
for a in range(n_nodes):
|
|
707
|
+
col = dofs * a
|
|
708
|
+
dNdx = gradN[:, a, 0]
|
|
709
|
+
dNdy = gradN[:, a, 1]
|
|
710
|
+
dNdz = gradN[:, a, 2]
|
|
711
|
+
B[:, 0, col + 0] = dNdx
|
|
712
|
+
B[:, 1, col + 1] = dNdy
|
|
713
|
+
B[:, 2, col + 2] = dNdz
|
|
714
|
+
B[:, 3, col + 0] = dNdy
|
|
715
|
+
B[:, 3, col + 1] = dNdx
|
|
716
|
+
B[:, 4, col + 1] = dNdz
|
|
717
|
+
B[:, 4, col + 2] = dNdy
|
|
718
|
+
B[:, 5, col + 0] = dNdz
|
|
719
|
+
B[:, 5, col + 2] = dNdx
|
|
720
|
+
return B
|
|
721
|
+
|
|
722
|
+
|
|
723
|
+
def _sym_grad_u_np(field: FormFieldLike, u_elem: ArrayLike) -> np.ndarray:
|
|
724
|
+
B = _sym_grad_np(field)
|
|
725
|
+
u_arr = np.asarray(u_elem)
|
|
726
|
+
if u_arr.ndim == 2:
|
|
727
|
+
return np.einsum("qik,kb->qib", B, u_arr)
|
|
728
|
+
return np.einsum("qik,k->qi", B, u_arr)
|
|
729
|
+
|
|
730
|
+
|
|
731
|
+
def _ddot_np(a: ArrayLike, b: ArrayLike, c: ArrayLike | None = None) -> np.ndarray:
|
|
732
|
+
if c is None:
|
|
733
|
+
return np.einsum("...ij,...ij->...", a, b)
|
|
734
|
+
a_t = np.swapaxes(a, -1, -2)
|
|
735
|
+
return np.einsum("...ik,kl,...lm->...im", a_t, b, c)
|
|
736
|
+
|
|
737
|
+
|
|
738
|
+
def _dot_np(a: FormFieldLike | ArrayLike, b: ArrayLike) -> np.ndarray:
|
|
739
|
+
if hasattr(a, "N") and getattr(a, "value_dim", None) is not None:
|
|
740
|
+
return _vector_load_form_np(cast(FormFieldLike, a), b)
|
|
741
|
+
return np.matmul(a, b)
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
def _transpose_last2_np(a: ArrayLike) -> np.ndarray:
|
|
745
|
+
return np.swapaxes(a, -1, -2)
|
|
746
|
+
|
|
747
|
+
|
|
521
748
|
def grad(field) -> Expr:
|
|
522
749
|
"""Return basis gradients for a scalar or vector FormField."""
|
|
523
750
|
return Expr("grad", _as_expr(field))
|
|
@@ -649,6 +876,62 @@ def _call_user(fn, *args, params):
|
|
|
649
876
|
return fn(*args)
|
|
650
877
|
|
|
651
878
|
|
|
879
|
+
@dataclass(frozen=True)
|
|
880
|
+
class KernelSpec:
|
|
881
|
+
kind: str
|
|
882
|
+
domain: str
|
|
883
|
+
|
|
884
|
+
|
|
885
|
+
class TaggedKernel:
|
|
886
|
+
def __init__(self, fn, spec: KernelSpec):
|
|
887
|
+
self._fn = fn
|
|
888
|
+
self._ff_spec = spec
|
|
889
|
+
self._ff_kind = spec.kind
|
|
890
|
+
self._ff_domain = spec.domain
|
|
891
|
+
update_wrapper(self, fn)
|
|
892
|
+
self.__wrapped__ = fn
|
|
893
|
+
|
|
894
|
+
def __call__(self, *args, **kwargs):
|
|
895
|
+
return self._fn(*args, **kwargs)
|
|
896
|
+
|
|
897
|
+
def __repr__(self) -> str:
|
|
898
|
+
return f"TaggedKernel(kind={self._ff_kind!r}, domain={self._ff_domain!r})"
|
|
899
|
+
|
|
900
|
+
@property
|
|
901
|
+
def spec(self) -> KernelSpec:
|
|
902
|
+
return self._ff_spec
|
|
903
|
+
|
|
904
|
+
@property
|
|
905
|
+
def kind(self) -> str:
|
|
906
|
+
return self._ff_kind
|
|
907
|
+
|
|
908
|
+
@property
|
|
909
|
+
def domain(self) -> str:
|
|
910
|
+
return self._ff_domain
|
|
911
|
+
|
|
912
|
+
def __hash__(self) -> int:
|
|
913
|
+
return hash(self._fn)
|
|
914
|
+
|
|
915
|
+
|
|
916
|
+
def _tag_form(fn, *, kind: str, domain: str):
|
|
917
|
+
spec = KernelSpec(kind=kind, domain=domain)
|
|
918
|
+
fn._ff_spec = spec
|
|
919
|
+
fn._ff_kind = kind
|
|
920
|
+
fn._ff_domain = domain
|
|
921
|
+
return fn
|
|
922
|
+
|
|
923
|
+
|
|
924
|
+
def kernel(*, kind: str, domain: str = "volume"):
|
|
925
|
+
"""
|
|
926
|
+
Decorator to tag raw kernels with kind/domain metadata for assembly inference.
|
|
927
|
+
"""
|
|
928
|
+
def _deco(fn):
|
|
929
|
+
spec = KernelSpec(kind=kind, domain=domain)
|
|
930
|
+
return TaggedKernel(fn, spec)
|
|
931
|
+
|
|
932
|
+
return _deco
|
|
933
|
+
|
|
934
|
+
|
|
652
935
|
def compile_bilinear(fn):
|
|
653
936
|
"""get_compiled a bilinear weak form (u, v, params) -> Expr into a kernel."""
|
|
654
937
|
if isinstance(fn, Expr):
|
|
@@ -658,9 +941,10 @@ def compile_bilinear(fn):
|
|
|
658
941
|
v = test_ref()
|
|
659
942
|
p = param_ref()
|
|
660
943
|
expr = _call_user(fn, u, v, params=p)
|
|
661
|
-
|
|
662
|
-
if not isinstance(
|
|
944
|
+
expr_raw = _as_expr(expr)
|
|
945
|
+
if not isinstance(expr_raw, Expr):
|
|
663
946
|
raise TypeError("Bilinear form must return an Expr.")
|
|
947
|
+
expr = cast(Expr, expr_raw)
|
|
664
948
|
|
|
665
949
|
volume_count = _count_op(expr, "volume_measure")
|
|
666
950
|
surface_count = _count_op(expr, "surface_measure")
|
|
@@ -676,8 +960,8 @@ def compile_bilinear(fn):
|
|
|
676
960
|
def _form(ctx, params):
|
|
677
961
|
return eval_with_plan(plan, ctx, params)
|
|
678
962
|
|
|
679
|
-
_form._includes_measure = True
|
|
680
|
-
return _form
|
|
963
|
+
_form._includes_measure = True # type: ignore[attr-defined]
|
|
964
|
+
return _tag_form(_form, kind="bilinear", domain="volume")
|
|
681
965
|
|
|
682
966
|
|
|
683
967
|
def compile_linear(fn):
|
|
@@ -688,9 +972,10 @@ def compile_linear(fn):
|
|
|
688
972
|
v = test_ref()
|
|
689
973
|
p = param_ref()
|
|
690
974
|
expr = _call_user(fn, v, params=p)
|
|
691
|
-
|
|
692
|
-
if not isinstance(
|
|
975
|
+
expr_raw = _as_expr(expr)
|
|
976
|
+
if not isinstance(expr_raw, Expr):
|
|
693
977
|
raise TypeError("Linear form must return an Expr.")
|
|
978
|
+
expr = cast(Expr, expr_raw)
|
|
694
979
|
|
|
695
980
|
volume_count = _count_op(expr, "volume_measure")
|
|
696
981
|
surface_count = _count_op(expr, "surface_measure")
|
|
@@ -706,8 +991,8 @@ def compile_linear(fn):
|
|
|
706
991
|
def _form(ctx, params):
|
|
707
992
|
return eval_with_plan(plan, ctx, params)
|
|
708
993
|
|
|
709
|
-
_form._includes_measure = True
|
|
710
|
-
return _form
|
|
994
|
+
_form._includes_measure = True # type: ignore[attr-defined]
|
|
995
|
+
return _tag_form(_form, kind="linear", domain="volume")
|
|
711
996
|
|
|
712
997
|
|
|
713
998
|
def _expr_contains(expr: Expr, op: str) -> bool:
|
|
@@ -795,6 +1080,7 @@ def eval_with_plan(
|
|
|
795
1080
|
nodes = plan.nodes
|
|
796
1081
|
index = plan.index
|
|
797
1082
|
vals: list[Any] = [None] * len(nodes)
|
|
1083
|
+
ctx_w = cast(WeakFormContext, ctx)
|
|
798
1084
|
|
|
799
1085
|
def get(obj):
|
|
800
1086
|
if isinstance(obj, Expr):
|
|
@@ -825,7 +1111,7 @@ def eval_with_plan(
|
|
|
825
1111
|
if op == "value":
|
|
826
1112
|
ref = args[0]
|
|
827
1113
|
assert isinstance(ref, FieldRef)
|
|
828
|
-
field = _eval_field(ref,
|
|
1114
|
+
field = _eval_field(ref, ctx_w, params)
|
|
829
1115
|
if ref.role == "unknown":
|
|
830
1116
|
vals[i] = _eval_unknown_value(ref, field, u_elem)
|
|
831
1117
|
else:
|
|
@@ -834,7 +1120,7 @@ def eval_with_plan(
|
|
|
834
1120
|
if op == "grad":
|
|
835
1121
|
ref = args[0]
|
|
836
1122
|
assert isinstance(ref, FieldRef)
|
|
837
|
-
field = _eval_field(ref,
|
|
1123
|
+
field = _eval_field(ref, ctx_w, params)
|
|
838
1124
|
if ref.role == "unknown":
|
|
839
1125
|
vals[i] = _eval_unknown_grad(ref, field, u_elem)
|
|
840
1126
|
else:
|
|
@@ -879,7 +1165,7 @@ def eval_with_plan(
|
|
|
879
1165
|
if op == "sym_grad":
|
|
880
1166
|
ref = args[0]
|
|
881
1167
|
assert isinstance(ref, FieldRef)
|
|
882
|
-
field = _eval_field(ref,
|
|
1168
|
+
field = _eval_field(ref, ctx_w, params)
|
|
883
1169
|
if ref.role == "unknown":
|
|
884
1170
|
if u_elem is None:
|
|
885
1171
|
raise ValueError("u_elem is required to evaluate unknown sym_grad.")
|
|
@@ -943,7 +1229,7 @@ def eval_with_plan(
|
|
|
943
1229
|
if op == "dot":
|
|
944
1230
|
ref = args[0]
|
|
945
1231
|
if isinstance(ref, FieldRef):
|
|
946
|
-
vals[i] = _ops.dot(_eval_field(ref,
|
|
1232
|
+
vals[i] = _ops.dot(_eval_field(ref, ctx_w, params), get(args[1]))
|
|
947
1233
|
else:
|
|
948
1234
|
a = get(args[0])
|
|
949
1235
|
b = get(args[1])
|
|
@@ -961,7 +1247,7 @@ def eval_with_plan(
|
|
|
961
1247
|
if op == "sdot":
|
|
962
1248
|
ref = args[0]
|
|
963
1249
|
if isinstance(ref, FieldRef):
|
|
964
|
-
vals[i] = _ops.dot(_eval_field(ref,
|
|
1250
|
+
vals[i] = _ops.dot(_eval_field(ref, ctx_w, params), get(args[1]))
|
|
965
1251
|
else:
|
|
966
1252
|
a = get(args[0])
|
|
967
1253
|
b = get(args[1])
|
|
@@ -1004,7 +1290,7 @@ def eval_with_plan(
|
|
|
1004
1290
|
assert isinstance(ref, FieldRef)
|
|
1005
1291
|
if isinstance(args[1], FieldRef):
|
|
1006
1292
|
raise ValueError("action expects a scalar expression; use u.val for unknowns.")
|
|
1007
|
-
v_field = _eval_field(ref,
|
|
1293
|
+
v_field = _eval_field(ref, ctx_w, params)
|
|
1008
1294
|
s = get(args[1])
|
|
1009
1295
|
value_dim = int(getattr(v_field, "value_dim", 1))
|
|
1010
1296
|
# action maps a test field with a scalar/vector expression into nodal space.
|
|
@@ -1022,7 +1308,7 @@ def eval_with_plan(
|
|
|
1022
1308
|
if op == "gaction":
|
|
1023
1309
|
ref = args[0]
|
|
1024
1310
|
assert isinstance(ref, FieldRef)
|
|
1025
|
-
v_field = _eval_field(ref,
|
|
1311
|
+
v_field = _eval_field(ref, ctx_w, params)
|
|
1026
1312
|
q = get(args[1])
|
|
1027
1313
|
# gaction maps a flux-like expression to nodal space via test gradients.
|
|
1028
1314
|
if v_field.gradN.ndim != 3:
|
|
@@ -1043,7 +1329,10 @@ def eval_with_plan(
|
|
|
1043
1329
|
continue
|
|
1044
1330
|
if op == "einsum":
|
|
1045
1331
|
subscripts = args[0]
|
|
1046
|
-
operands = [
|
|
1332
|
+
operands = [
|
|
1333
|
+
(jnp.asarray(arg) if isinstance(arg, tuple) else arg)
|
|
1334
|
+
for arg in (get(arg) for arg in args[1:])
|
|
1335
|
+
]
|
|
1047
1336
|
vals[i] = jnp.einsum(subscripts, *operands)
|
|
1048
1337
|
continue
|
|
1049
1338
|
|
|
@@ -1052,6 +1341,294 @@ def eval_with_plan(
|
|
|
1052
1341
|
return vals[index[id(plan.expr)]]
|
|
1053
1342
|
|
|
1054
1343
|
|
|
1344
|
+
def eval_with_plan_numpy(
|
|
1345
|
+
plan: EvalPlan,
|
|
1346
|
+
ctx: VolumeContext | SurfaceContext,
|
|
1347
|
+
params: ParamsLike,
|
|
1348
|
+
u_elem: UElement | None = None,
|
|
1349
|
+
):
|
|
1350
|
+
nodes = plan.nodes
|
|
1351
|
+
index = plan.index
|
|
1352
|
+
vals: list[Any] = [None] * len(nodes)
|
|
1353
|
+
ctx_w = cast(WeakFormContext, ctx)
|
|
1354
|
+
|
|
1355
|
+
def get(obj):
|
|
1356
|
+
if isinstance(obj, Expr):
|
|
1357
|
+
return vals[index[id(obj)]]
|
|
1358
|
+
if isinstance(obj, FieldRef):
|
|
1359
|
+
raise TypeError(
|
|
1360
|
+
"FieldRef must be wrapped with .val/.grad/.sym_grad or used as the first arg of dot/action."
|
|
1361
|
+
)
|
|
1362
|
+
if isinstance(obj, ParamRef):
|
|
1363
|
+
return params
|
|
1364
|
+
return obj
|
|
1365
|
+
|
|
1366
|
+
for i, node in enumerate(nodes):
|
|
1367
|
+
op = node.op
|
|
1368
|
+
args = node.args
|
|
1369
|
+
|
|
1370
|
+
if op == "lit":
|
|
1371
|
+
vals[i] = args[0]
|
|
1372
|
+
continue
|
|
1373
|
+
if op == "getattr":
|
|
1374
|
+
base = get(args[0])
|
|
1375
|
+
name = args[1]
|
|
1376
|
+
if isinstance(base, dict):
|
|
1377
|
+
vals[i] = base[name]
|
|
1378
|
+
else:
|
|
1379
|
+
vals[i] = getattr(base, name)
|
|
1380
|
+
continue
|
|
1381
|
+
if op == "value":
|
|
1382
|
+
ref = args[0]
|
|
1383
|
+
assert isinstance(ref, FieldRef)
|
|
1384
|
+
field = _eval_field_np(ref, ctx_w, params)
|
|
1385
|
+
if ref.role == "unknown":
|
|
1386
|
+
vals[i] = _eval_unknown_value_np(ref, field, u_elem)
|
|
1387
|
+
else:
|
|
1388
|
+
vals[i] = field.N
|
|
1389
|
+
continue
|
|
1390
|
+
if op == "grad":
|
|
1391
|
+
ref = args[0]
|
|
1392
|
+
assert isinstance(ref, FieldRef)
|
|
1393
|
+
field = _eval_field_np(ref, ctx_w, params)
|
|
1394
|
+
if ref.role == "unknown":
|
|
1395
|
+
vals[i] = _eval_unknown_grad_np(ref, field, u_elem)
|
|
1396
|
+
else:
|
|
1397
|
+
vals[i] = field.gradN
|
|
1398
|
+
continue
|
|
1399
|
+
if op == "pow":
|
|
1400
|
+
base = get(args[0])
|
|
1401
|
+
exp = get(args[1])
|
|
1402
|
+
vals[i] = base**exp
|
|
1403
|
+
continue
|
|
1404
|
+
if op == "eye":
|
|
1405
|
+
vals[i] = np.eye(int(args[0]))
|
|
1406
|
+
continue
|
|
1407
|
+
if op == "det":
|
|
1408
|
+
vals[i] = np.linalg.det(get(args[0]))
|
|
1409
|
+
continue
|
|
1410
|
+
if op == "inv":
|
|
1411
|
+
vals[i] = np.linalg.inv(get(args[0]))
|
|
1412
|
+
continue
|
|
1413
|
+
if op == "transpose":
|
|
1414
|
+
vals[i] = np.swapaxes(get(args[0]), -1, -2)
|
|
1415
|
+
continue
|
|
1416
|
+
if op == "log":
|
|
1417
|
+
vals[i] = np.log(get(args[0]))
|
|
1418
|
+
continue
|
|
1419
|
+
if op == "surface_normal":
|
|
1420
|
+
normal = getattr(ctx, "normal", None)
|
|
1421
|
+
if normal is None:
|
|
1422
|
+
raise ValueError("surface normal is not available in context")
|
|
1423
|
+
vals[i] = normal
|
|
1424
|
+
continue
|
|
1425
|
+
if op == "surface_measure":
|
|
1426
|
+
if not hasattr(ctx, "w") or not hasattr(ctx, "detJ"):
|
|
1427
|
+
raise TypeError("surface measure requires SurfaceContext.")
|
|
1428
|
+
vals[i] = ctx.w * ctx.detJ
|
|
1429
|
+
continue
|
|
1430
|
+
if op == "volume_measure":
|
|
1431
|
+
if not hasattr(ctx, "w") or not hasattr(ctx, "test"):
|
|
1432
|
+
raise TypeError("volume measure requires VolumeContext.")
|
|
1433
|
+
vals[i] = ctx.w * ctx.test.detJ
|
|
1434
|
+
continue
|
|
1435
|
+
if op == "sym_grad":
|
|
1436
|
+
ref = args[0]
|
|
1437
|
+
assert isinstance(ref, FieldRef)
|
|
1438
|
+
field = _eval_field_np(ref, ctx_w, params)
|
|
1439
|
+
if ref.role == "unknown":
|
|
1440
|
+
if u_elem is None:
|
|
1441
|
+
raise ValueError("u_elem is required to evaluate unknown sym_grad.")
|
|
1442
|
+
u_local = _extract_unknown_elem(ref, u_elem)
|
|
1443
|
+
vals[i] = _sym_grad_u_np(field, u_local)
|
|
1444
|
+
else:
|
|
1445
|
+
vals[i] = _sym_grad_np(field)
|
|
1446
|
+
continue
|
|
1447
|
+
if op == "outer":
|
|
1448
|
+
a, b = args
|
|
1449
|
+
if not isinstance(a, FieldRef) or not isinstance(b, FieldRef):
|
|
1450
|
+
raise TypeError("outer expects FieldRef operands.")
|
|
1451
|
+
test, trial = a, b
|
|
1452
|
+
vals[i] = _basis_outer_np(test, trial, ctx, params)
|
|
1453
|
+
continue
|
|
1454
|
+
if op == "add":
|
|
1455
|
+
vals[i] = get(args[0]) + get(args[1])
|
|
1456
|
+
continue
|
|
1457
|
+
if op == "sub":
|
|
1458
|
+
vals[i] = get(args[0]) - get(args[1])
|
|
1459
|
+
continue
|
|
1460
|
+
if op == "mul":
|
|
1461
|
+
a = get(args[0])
|
|
1462
|
+
b = get(args[1])
|
|
1463
|
+
if hasattr(a, "ndim") and hasattr(b, "ndim"):
|
|
1464
|
+
if a.ndim == 1 and b.ndim == 2 and a.shape[0] == b.shape[0]:
|
|
1465
|
+
a = a[:, None]
|
|
1466
|
+
elif b.ndim == 1 and a.ndim == 2 and b.shape[0] == a.shape[0]:
|
|
1467
|
+
b = b[:, None]
|
|
1468
|
+
elif a.ndim >= 2 and b.ndim == 1 and a.shape[0] == b.shape[0]:
|
|
1469
|
+
b = b.reshape((b.shape[0],) + (1,) * (a.ndim - 1))
|
|
1470
|
+
elif b.ndim >= 2 and a.ndim == 1 and b.shape[0] == a.shape[0]:
|
|
1471
|
+
a = a.reshape((a.shape[0],) + (1,) * (b.ndim - 1))
|
|
1472
|
+
vals[i] = a * b
|
|
1473
|
+
continue
|
|
1474
|
+
if op == "matmul":
|
|
1475
|
+
a = get(args[0])
|
|
1476
|
+
b = get(args[1])
|
|
1477
|
+
if (
|
|
1478
|
+
hasattr(a, "ndim")
|
|
1479
|
+
and hasattr(b, "ndim")
|
|
1480
|
+
and a.ndim == 3
|
|
1481
|
+
and b.ndim == 3
|
|
1482
|
+
and a.shape[0] == b.shape[0]
|
|
1483
|
+
and a.shape[-1] == b.shape[-1]
|
|
1484
|
+
):
|
|
1485
|
+
vals[i] = np.einsum("qia,qja->qij", a, b)
|
|
1486
|
+
else:
|
|
1487
|
+
raise TypeError(
|
|
1488
|
+
"Expr '@' (matmul) is FEM-specific; use matmul_std(a, b) for standard matmul."
|
|
1489
|
+
)
|
|
1490
|
+
continue
|
|
1491
|
+
if op == "matmul_std":
|
|
1492
|
+
a = get(args[0])
|
|
1493
|
+
b = get(args[1])
|
|
1494
|
+
vals[i] = np.matmul(a, b)
|
|
1495
|
+
continue
|
|
1496
|
+
if op == "neg":
|
|
1497
|
+
vals[i] = -get(args[0])
|
|
1498
|
+
continue
|
|
1499
|
+
if op == "dot":
|
|
1500
|
+
ref = args[0]
|
|
1501
|
+
if isinstance(ref, FieldRef):
|
|
1502
|
+
vals[i] = _dot_np(_eval_field_np(ref, ctx_w, params), get(args[1]))
|
|
1503
|
+
else:
|
|
1504
|
+
a = get(args[0])
|
|
1505
|
+
b = get(args[1])
|
|
1506
|
+
if (
|
|
1507
|
+
hasattr(a, "ndim")
|
|
1508
|
+
and hasattr(b, "ndim")
|
|
1509
|
+
and a.ndim >= 3
|
|
1510
|
+
and b.ndim >= 3
|
|
1511
|
+
and a.shape[0] == b.shape[0]
|
|
1512
|
+
and a.shape[1] == b.shape[1]
|
|
1513
|
+
):
|
|
1514
|
+
vals[i] = np.einsum("qi...,qj...->qij...", a, b)
|
|
1515
|
+
else:
|
|
1516
|
+
vals[i] = np.matmul(a, b)
|
|
1517
|
+
continue
|
|
1518
|
+
if op == "sdot":
|
|
1519
|
+
ref = args[0]
|
|
1520
|
+
if isinstance(ref, FieldRef):
|
|
1521
|
+
vals[i] = _dot_np(_eval_field_np(ref, ctx_w, params), get(args[1]))
|
|
1522
|
+
else:
|
|
1523
|
+
a = get(args[0])
|
|
1524
|
+
b = get(args[1])
|
|
1525
|
+
if (
|
|
1526
|
+
hasattr(a, "ndim")
|
|
1527
|
+
and hasattr(b, "ndim")
|
|
1528
|
+
and a.ndim >= 3
|
|
1529
|
+
and b.ndim >= 3
|
|
1530
|
+
and a.shape[0] == b.shape[0]
|
|
1531
|
+
and a.shape[1] == b.shape[1]
|
|
1532
|
+
):
|
|
1533
|
+
vals[i] = np.einsum("qi...,qj...->qij...", a, b)
|
|
1534
|
+
else:
|
|
1535
|
+
vals[i] = np.matmul(a, b)
|
|
1536
|
+
continue
|
|
1537
|
+
if op == "ddot":
|
|
1538
|
+
if len(args) == 2:
|
|
1539
|
+
a = get(args[0])
|
|
1540
|
+
b = get(args[1])
|
|
1541
|
+
if (
|
|
1542
|
+
hasattr(a, "ndim")
|
|
1543
|
+
and hasattr(b, "ndim")
|
|
1544
|
+
and a.ndim == 3
|
|
1545
|
+
and b.ndim == 3
|
|
1546
|
+
and a.shape[0] == b.shape[0]
|
|
1547
|
+
and a.shape[1] == b.shape[1]
|
|
1548
|
+
):
|
|
1549
|
+
vals[i] = np.einsum("qik,qim->qkm", a, b)
|
|
1550
|
+
else:
|
|
1551
|
+
vals[i] = _ddot_np(a, b)
|
|
1552
|
+
else:
|
|
1553
|
+
vals[i] = _ddot_np(get(args[0]), get(args[1]), get(args[2]))
|
|
1554
|
+
continue
|
|
1555
|
+
if op == "inner":
|
|
1556
|
+
a = get(args[0])
|
|
1557
|
+
b = get(args[1])
|
|
1558
|
+
vals[i] = np.einsum("...i,...i->...", a, b)
|
|
1559
|
+
continue
|
|
1560
|
+
if op == "action":
|
|
1561
|
+
ref = args[0]
|
|
1562
|
+
assert isinstance(ref, FieldRef)
|
|
1563
|
+
if isinstance(args[1], FieldRef):
|
|
1564
|
+
raise ValueError("action expects a scalar expression; use u.val for unknowns.")
|
|
1565
|
+
v_field = _eval_field_np(ref, ctx_w, params)
|
|
1566
|
+
s = get(args[1])
|
|
1567
|
+
value_dim = int(getattr(v_field, "value_dim", 1))
|
|
1568
|
+
if value_dim == 1:
|
|
1569
|
+
if v_field.N.ndim != 2:
|
|
1570
|
+
raise ValueError("action expects scalar test field with N shape (q, ndofs).")
|
|
1571
|
+
if hasattr(s, "ndim") and s.ndim not in (0, 1):
|
|
1572
|
+
raise ValueError("action expects scalar s with shape (q,) or scalar.")
|
|
1573
|
+
vals[i] = v_field.N * s
|
|
1574
|
+
else:
|
|
1575
|
+
if hasattr(s, "ndim") and s.ndim not in (1, 2):
|
|
1576
|
+
raise ValueError("action expects vector s with shape (q, dim) or (dim,).")
|
|
1577
|
+
vals[i] = _dot_np(v_field, s)
|
|
1578
|
+
continue
|
|
1579
|
+
if op == "gaction":
|
|
1580
|
+
ref = args[0]
|
|
1581
|
+
assert isinstance(ref, FieldRef)
|
|
1582
|
+
v_field = _eval_field_np(ref, ctx_w, params)
|
|
1583
|
+
q = get(args[1])
|
|
1584
|
+
if v_field.gradN.ndim != 3:
|
|
1585
|
+
raise ValueError("gaction expects test gradient with shape (q, ndofs, dim).")
|
|
1586
|
+
if not hasattr(q, "ndim"):
|
|
1587
|
+
raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
|
|
1588
|
+
if q.ndim == 2:
|
|
1589
|
+
vals[i] = np.einsum("qaj,qj->qa", v_field.gradN, q)
|
|
1590
|
+
elif q.ndim == 3:
|
|
1591
|
+
if int(getattr(v_field, "value_dim", 1)) == 1:
|
|
1592
|
+
raise ValueError("gaction tensor flux requires vector test field.")
|
|
1593
|
+
vals[i] = np.einsum("qij,qaj->qai", q, v_field.gradN).reshape(q.shape[0], -1)
|
|
1594
|
+
else:
|
|
1595
|
+
raise ValueError("gaction expects q with shape (q, dim) or (q, dim, dim).")
|
|
1596
|
+
continue
|
|
1597
|
+
if op == "transpose_last2":
|
|
1598
|
+
vals[i] = _transpose_last2_np(get(args[0]))
|
|
1599
|
+
continue
|
|
1600
|
+
if op == "einsum":
|
|
1601
|
+
subscripts = args[0]
|
|
1602
|
+
operands = [
|
|
1603
|
+
(np.asarray(arg) if isinstance(arg, tuple) else arg)
|
|
1604
|
+
for arg in (get(arg) for arg in args[1:])
|
|
1605
|
+
]
|
|
1606
|
+
if "..." not in subscripts:
|
|
1607
|
+
has_extra = False
|
|
1608
|
+
parts = subscripts.split("->")
|
|
1609
|
+
in_terms = parts[0].split(",")
|
|
1610
|
+
out_term = parts[1] if len(parts) > 1 else None
|
|
1611
|
+
updated_terms = []
|
|
1612
|
+
for term, opnd in zip(in_terms, operands):
|
|
1613
|
+
if hasattr(opnd, "ndim") and opnd.ndim > len(term):
|
|
1614
|
+
has_extra = True
|
|
1615
|
+
updated_terms.append(term + "...")
|
|
1616
|
+
else:
|
|
1617
|
+
updated_terms.append(term)
|
|
1618
|
+
if has_extra:
|
|
1619
|
+
if out_term is not None:
|
|
1620
|
+
out_term = out_term + "..."
|
|
1621
|
+
subscripts = ",".join(updated_terms) + "->" + out_term
|
|
1622
|
+
else:
|
|
1623
|
+
subscripts = ",".join(updated_terms)
|
|
1624
|
+
vals[i] = np.einsum(subscripts, *operands)
|
|
1625
|
+
continue
|
|
1626
|
+
|
|
1627
|
+
raise ValueError(f"Unknown Expr op: {op}")
|
|
1628
|
+
|
|
1629
|
+
return vals[index[id(plan.expr)]]
|
|
1630
|
+
|
|
1631
|
+
|
|
1055
1632
|
def compile_surface_linear(fn):
|
|
1056
1633
|
"""get_compiled a surface linear form into a kernel (ctx, params) -> ndarray."""
|
|
1057
1634
|
if isinstance(fn, Expr):
|
|
@@ -1061,9 +1638,10 @@ def compile_surface_linear(fn):
|
|
|
1061
1638
|
p = param_ref()
|
|
1062
1639
|
expr = _call_user(fn, v, params=p)
|
|
1063
1640
|
|
|
1064
|
-
|
|
1065
|
-
if not isinstance(
|
|
1641
|
+
expr_raw = _as_expr(expr)
|
|
1642
|
+
if not isinstance(expr_raw, Expr):
|
|
1066
1643
|
raise ValueError("Surface linear form must return an Expr; use ds() in the expression.")
|
|
1644
|
+
expr = cast(Expr, expr_raw)
|
|
1067
1645
|
|
|
1068
1646
|
surface_count = _count_op(expr, "surface_measure")
|
|
1069
1647
|
volume_count = _count_op(expr, "volume_measure")
|
|
@@ -1080,7 +1658,40 @@ def compile_surface_linear(fn):
|
|
|
1080
1658
|
return eval_with_plan(plan, ctx, params)
|
|
1081
1659
|
|
|
1082
1660
|
_form._includes_measure = True # type: ignore[attr-defined]
|
|
1083
|
-
return _form
|
|
1661
|
+
return _tag_form(_form, kind="linear", domain="surface")
|
|
1662
|
+
|
|
1663
|
+
|
|
1664
|
+
def compile_surface_bilinear(fn):
    """Compile a surface bilinear form into a kernel (ctx, params) -> ndarray.

    ``fn`` is either an already-built ``Expr`` or a user callable. A callable
    is invoked as ``fn(u, v, params=p)`` with symbolic trial, test and
    parameter references, and its result is coerced with ``_as_expr``.

    The expression must contain the surface measure ``ds()`` exactly once and
    must not contain the volume measure ``dOmega()``.

    Returns:
        A closure ``_form(ctx, params)`` that evaluates the precomputed plan.
        It carries ``_includes_measure = True`` and is tagged through
        ``_tag_form`` as kind ``"bilinear"`` on domain ``"surface"``.

    Raises:
        ValueError: if the result is not an ``Expr``, or if the measure
            requirements above are violated.
    """
    if isinstance(fn, Expr):
        candidate = fn
    else:
        # Build the symbolic placeholders in the same order as the other
        # compile_* helpers: test, trial, params.
        v = test_ref()
        u = trial_ref()
        p = param_ref()
        candidate = _call_user(fn, u, v, params=p)

    candidate = _as_expr(candidate)
    if not isinstance(candidate, Expr):
        raise ValueError("Surface bilinear form must return an Expr; use ds() in the expression.")
    expr = cast(Expr, candidate)

    n_surface = _count_op(expr, "surface_measure")
    n_volume = _count_op(expr, "volume_measure")
    if n_surface == 0:
        raise ValueError("Surface bilinear form must include ds().")
    elif n_surface > 1:
        raise ValueError("Surface bilinear form must include ds() exactly once.")
    if n_volume > 0:
        raise ValueError("Surface bilinear form must not include dOmega().")

    # The plan is built once here and reused by every kernel call.
    plan = make_eval_plan(expr)

    def _form(ctx, params):
        return eval_with_plan(plan, ctx, params)

    # The measure is part of the expression, so the kernel already includes it.
    _form._includes_measure = True  # type: ignore[attr-defined]
    return _tag_form(_form, kind="bilinear", domain="surface")
|
|
1084
1695
|
|
|
1085
1696
|
|
|
1086
1697
|
class LinearForm:
|
|
@@ -1144,9 +1755,10 @@ def compile_residual(fn):
|
|
|
1144
1755
|
u = unknown_ref()
|
|
1145
1756
|
p = param_ref()
|
|
1146
1757
|
expr = _call_user(fn, v, u, params=p)
|
|
1147
|
-
|
|
1148
|
-
if not isinstance(
|
|
1758
|
+
expr_raw = _as_expr(expr)
|
|
1759
|
+
if not isinstance(expr_raw, Expr):
|
|
1149
1760
|
raise TypeError("Residual form must return an Expr.")
|
|
1761
|
+
expr = cast(Expr, expr_raw)
|
|
1150
1762
|
|
|
1151
1763
|
volume_count = _count_op(expr, "volume_measure")
|
|
1152
1764
|
surface_count = _count_op(expr, "surface_measure")
|
|
@@ -1162,8 +1774,8 @@ def compile_residual(fn):
|
|
|
1162
1774
|
def _form(ctx, u_elem, params):
|
|
1163
1775
|
return eval_with_plan(plan, ctx, params, u_elem=u_elem)
|
|
1164
1776
|
|
|
1165
|
-
_form._includes_measure = True
|
|
1166
|
-
return _form
|
|
1777
|
+
_form._includes_measure = True # type: ignore[attr-defined]
|
|
1778
|
+
return _tag_form(_form, kind="residual", domain="volume")
|
|
1167
1779
|
|
|
1168
1780
|
|
|
1169
1781
|
def compile_mixed_residual(residuals: dict[str, Callable]):
|
|
@@ -1194,11 +1806,158 @@ def compile_mixed_residual(residuals: dict[str, Callable]):
|
|
|
1194
1806
|
if surface_count > 0:
|
|
1195
1807
|
raise ValueError(f"Mixed residual '{name}' must not include ds().")
|
|
1196
1808
|
|
|
1809
|
+
class _MixedContextView:
|
|
1810
|
+
def __init__(self, ctx, field_name: str):
|
|
1811
|
+
self._ctx = ctx
|
|
1812
|
+
self.fields = ctx.fields
|
|
1813
|
+
self.x_q = ctx.x_q
|
|
1814
|
+
self.w = ctx.w
|
|
1815
|
+
self.elem_id = ctx.elem_id
|
|
1816
|
+
self.trial_fields = ctx.trial_fields
|
|
1817
|
+
self.test_fields = ctx.test_fields
|
|
1818
|
+
self.unknown_fields = ctx.unknown_fields
|
|
1819
|
+
self.unknown = ctx.unknown
|
|
1820
|
+
|
|
1821
|
+
pair = ctx.fields[field_name]
|
|
1822
|
+
self.test = pair.test
|
|
1823
|
+
self.trial = pair.trial
|
|
1824
|
+
self.v = pair.test
|
|
1825
|
+
self.u = pair.trial
|
|
1826
|
+
|
|
1827
|
+
if hasattr(ctx, "normal"):
|
|
1828
|
+
self.normal = ctx.normal
|
|
1829
|
+
|
|
1830
|
+
def __getattr__(self, name: str):
|
|
1831
|
+
return getattr(self._ctx, name)
|
|
1832
|
+
|
|
1833
|
+
def _form(ctx, u_elem, params):
|
|
1834
|
+
return {
|
|
1835
|
+
name: eval_with_plan(plan, _MixedContextView(ctx, name), params, u_elem=u_elem)
|
|
1836
|
+
for name, plan in plans.items()
|
|
1837
|
+
}
|
|
1838
|
+
|
|
1839
|
+
_form._includes_measure = includes_measure # type: ignore[attr-defined]
|
|
1840
|
+
return _tag_form(_form, kind="residual", domain="volume")
|
|
1841
|
+
|
|
1842
|
+
|
|
1843
|
+
def compile_mixed_surface_residual(residuals: dict[str, Callable]):
    """Compile mixed surface residuals keyed by field name.

    Each entry of ``residuals`` is either an ``Expr`` or a callable invoked as
    ``fn(v, u, params=p)`` with the test/unknown references of that field.
    Every resulting expression must contain ``ds()`` exactly once and must not
    contain ``dOmega()``.

    Returns:
        A closure ``_form(ctx, u_elem, params)`` returning a dict that maps
        each field name to its evaluated residual. The closure carries a
        per-field ``_includes_measure`` dict and is tagged through
        ``_tag_form`` as kind ``"residual"`` on domain ``"surface"``.

    Raises:
        TypeError: if a residual does not produce an ``Expr``.
        ValueError: if the measure requirements are violated.
    """
    plans = {}
    includes_measure = {}
    for name, fn in residuals.items():
        if isinstance(fn, Expr):
            built = fn
        else:
            v = test_ref(name)
            u = unknown_ref(name)
            p = param_ref()
            built = _call_user(fn, v, u, params=p)
        built = _as_expr(built)
        if not isinstance(built, Expr):
            raise TypeError(f"Mixed surface residual '{name}' must return an Expr.")
        plans[name] = make_eval_plan(built)

        n_volume = _count_op(built, "volume_measure")
        n_surface = _count_op(built, "surface_measure")
        includes_measure[name] = n_surface == 1
        if n_surface == 0:
            raise ValueError(f"Mixed surface residual '{name}' must include ds().")
        elif n_surface > 1:
            raise ValueError(f"Mixed surface residual '{name}' must include ds() exactly once.")
        if n_volume > 0:
            raise ValueError(f"Mixed surface residual '{name}' must not include dOmega().")

    # (attribute, required) pairs copied from the context onto the view, in
    # this order; optional attributes default to None when absent.
    mirrored = (
        ("fields", True),
        ("x_q", True),
        ("w", True),
        ("detJ", True),
        ("normal", False),
        ("trial_fields", True),
        ("test_fields", True),
        ("unknown_fields", True),
        ("unknown", False),
    )

    class _MixedContextView:
        """Per-field view of a mixed context exposing test/trial (and v/u) aliases."""

        def __init__(self, ctx, field_name: str):
            self._ctx = ctx
            for attr, required in mirrored:
                value = getattr(ctx, attr) if required else getattr(ctx, attr, None)
                setattr(self, attr, value)

            pair = ctx.fields[field_name]
            self.test = self.v = pair.test
            self.trial = self.u = pair.trial

        def __getattr__(self, name: str):
            # Only reached for attributes not set on the view itself.
            return getattr(self._ctx, name)

    def _form(ctx, u_elem, params):
        out = {}
        for name, plan in plans.items():
            view = _MixedContextView(ctx, name)
            out[name] = eval_with_plan(plan, view, params, u_elem=u_elem)
        return out

    _form._includes_measure = includes_measure  # type: ignore[attr-defined]
    return _tag_form(_form, kind="residual", domain="surface")
|
|
1901
|
+
|
|
1902
|
+
|
|
1903
|
+
def compile_mixed_surface_residual_numpy(residuals: dict[str, Callable]):
    """Compile mixed surface residuals for numpy evaluation.

    Accepts the same input as the non-numpy mixed surface compiler: each value
    of ``residuals`` is an ``Expr`` or a callable invoked as
    ``fn(v, u, params=p)`` with that field's test/unknown references. Each
    expression must contain ``ds()`` exactly once and no ``dOmega()``.

    Returns:
        A closure ``_form(ctx, u_elem, params)`` that evaluates every plan
        with ``eval_with_plan_numpy`` and returns a dict keyed by field name.
        It carries a per-field ``_includes_measure`` dict and is tagged
        through ``_tag_form`` as kind ``"residual"`` on domain ``"surface"``.

    Raises:
        TypeError: if a residual does not produce an ``Expr``.
        ValueError: if the measure requirements are violated.
    """
    plans = {}
    includes_measure = {}
    for name, fn in residuals.items():
        if isinstance(fn, Expr):
            built = fn
        else:
            v = test_ref(name)
            u = unknown_ref(name)
            p = param_ref()
            built = _call_user(fn, v, u, params=p)
        built = _as_expr(built)
        if not isinstance(built, Expr):
            raise TypeError(f"Mixed surface residual '{name}' must return an Expr.")
        plans[name] = make_eval_plan(built)

        n_volume = _count_op(built, "volume_measure")
        n_surface = _count_op(built, "surface_measure")
        includes_measure[name] = n_surface == 1
        if n_surface == 0:
            raise ValueError(f"Mixed surface residual '{name}' must include ds().")
        elif n_surface > 1:
            raise ValueError(f"Mixed surface residual '{name}' must include ds() exactly once.")
        if n_volume > 0:
            raise ValueError(f"Mixed surface residual '{name}' must not include dOmega().")

    # (attribute, required) pairs copied from the context onto the view, in
    # this order; optional attributes default to None when absent.
    mirrored = (
        ("fields", True),
        ("x_q", True),
        ("w", True),
        ("detJ", True),
        ("normal", False),
        ("trial_fields", True),
        ("test_fields", True),
        ("unknown_fields", True),
        ("unknown", False),
    )

    class _MixedContextView:
        """Per-field view of a mixed context exposing test/trial (and v/u) aliases."""

        def __init__(self, ctx, field_name: str):
            self._ctx = ctx
            for attr, required in mirrored:
                value = getattr(ctx, attr) if required else getattr(ctx, attr, None)
                setattr(self, attr, value)

            pair = ctx.fields[field_name]
            self.test = self.v = pair.test
            self.trial = self.u = pair.trial

        def __getattr__(self, name: str):
            # Only reached for attributes not set on the view itself.
            return getattr(self._ctx, name)

    def _form(ctx, u_elem, params):
        out = {}
        for name, plan in plans.items():
            view = _MixedContextView(ctx, name)
            out[name] = eval_with_plan_numpy(plan, view, params, u_elem=u_elem)
        return out

    _form._includes_measure = includes_measure  # type: ignore[attr-defined]
    return _tag_form(_form, kind="residual", domain="surface")
|
|
1202
1961
|
|
|
1203
1962
|
|
|
1204
1963
|
class MixedWeakForm:
|
|
@@ -1213,6 +1972,20 @@ class MixedWeakForm:
|
|
|
1213
1972
|
return compile_mixed_residual(self.residuals)
|
|
1214
1973
|
|
|
1215
1974
|
|
|
1975
|
+
def make_mixed_residuals(residuals: dict[str, Callable] | None = None, **kwargs) -> dict[str, Callable]:
|
|
1976
|
+
"""
|
|
1977
|
+
Helper to build mixed residual dictionaries.
|
|
1978
|
+
|
|
1979
|
+
Example:
|
|
1980
|
+
res = make_mixed_residuals(u=res_u, p=res_p)
|
|
1981
|
+
"""
|
|
1982
|
+
if residuals is not None and kwargs:
|
|
1983
|
+
raise ValueError("Pass either residuals dict or keyword residuals, not both.")
|
|
1984
|
+
if residuals is None:
|
|
1985
|
+
return dict(kwargs)
|
|
1986
|
+
return dict(residuals)
|
|
1987
|
+
|
|
1988
|
+
|
|
1216
1989
|
def _eval_expr(
|
|
1217
1990
|
expr: Expr,
|
|
1218
1991
|
ctx: VolumeContext | SurfaceContext,
|
|
@@ -1230,13 +2003,19 @@ __all__ = [
|
|
|
1230
2003
|
"trial_ref",
|
|
1231
2004
|
"test_ref",
|
|
1232
2005
|
"unknown_ref",
|
|
2006
|
+
"zero_ref",
|
|
1233
2007
|
"param_ref",
|
|
1234
2008
|
"Params",
|
|
1235
2009
|
"MixedWeakForm",
|
|
2010
|
+
"make_mixed_residuals",
|
|
2011
|
+
"kernel",
|
|
1236
2012
|
"ResidualForm",
|
|
1237
2013
|
"compile_bilinear",
|
|
1238
2014
|
"compile_linear",
|
|
1239
2015
|
"compile_residual",
|
|
2016
|
+
"compile_surface_bilinear",
|
|
2017
|
+
"compile_mixed_surface_residual",
|
|
2018
|
+
"compile_mixed_surface_residual_numpy",
|
|
1240
2019
|
"compile_mixed_residual",
|
|
1241
2020
|
"grad",
|
|
1242
2021
|
"sym_grad",
|