fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +69 -13
- fluxfem/core/__init__.py +140 -53
- fluxfem/core/assembly.py +691 -97
- fluxfem/core/basis.py +75 -54
- fluxfem/core/context_types.py +36 -12
- fluxfem/core/dtypes.py +9 -1
- fluxfem/core/forms.py +10 -0
- fluxfem/core/mixed_assembly.py +263 -0
- fluxfem/core/mixed_space.py +382 -0
- fluxfem/core/mixed_weakform.py +97 -0
- fluxfem/core/solver.py +2 -0
- fluxfem/core/space.py +315 -30
- fluxfem/core/weakform.py +821 -42
- fluxfem/helpers_wf.py +49 -0
- fluxfem/mesh/__init__.py +54 -2
- fluxfem/mesh/base.py +318 -9
- fluxfem/mesh/contact.py +841 -0
- fluxfem/mesh/dtypes.py +12 -0
- fluxfem/mesh/hex.py +17 -16
- fluxfem/mesh/io.py +9 -6
- fluxfem/mesh/mortar.py +3970 -0
- fluxfem/mesh/supermesh.py +318 -0
- fluxfem/mesh/surface.py +104 -26
- fluxfem/mesh/tet.py +16 -7
- fluxfem/physics/diffusion.py +3 -0
- fluxfem/physics/elasticity/hyperelastic.py +35 -3
- fluxfem/physics/elasticity/linear.py +22 -4
- fluxfem/physics/elasticity/stress.py +9 -5
- fluxfem/physics/operators.py +12 -5
- fluxfem/physics/postprocess.py +29 -3
- fluxfem/solver/__init__.py +47 -2
- fluxfem/solver/bc.py +38 -2
- fluxfem/solver/block_matrix.py +284 -0
- fluxfem/solver/block_system.py +477 -0
- fluxfem/solver/cg.py +150 -55
- fluxfem/solver/dirichlet.py +358 -5
- fluxfem/solver/history.py +15 -3
- fluxfem/solver/newton.py +260 -70
- fluxfem/solver/petsc.py +445 -0
- fluxfem/solver/preconditioner.py +109 -0
- fluxfem/solver/result.py +18 -0
- fluxfem/solver/solve_runner.py +208 -23
- fluxfem/solver/solver.py +35 -12
- fluxfem/solver/sparse.py +149 -15
- fluxfem/tools/jit.py +19 -7
- fluxfem/tools/timer.py +14 -12
- fluxfem/tools/visualizer.py +16 -4
- fluxfem-0.2.1.dist-info/METADATA +314 -0
- fluxfem-0.2.1.dist-info/RECORD +59 -0
- fluxfem-0.1.4.dist-info/METADATA +0 -127
- fluxfem-0.1.4.dist-info/RECORD +0 -48
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/core/assembly.py
CHANGED
|
@@ -1,21 +1,362 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
from typing import Callable, Protocol, TypeVar,
|
|
2
|
+
from typing import Any, Callable, Literal, Mapping, Optional, Protocol, TYPE_CHECKING, TypeAlias, TypeVar, Union, cast
|
|
3
3
|
import numpy as np
|
|
4
4
|
import jax
|
|
5
5
|
import jax.numpy as jnp
|
|
6
6
|
|
|
7
7
|
from ..mesh import HexMesh, StructuredHexBox
|
|
8
|
+
from .dtypes import INDEX_DTYPE
|
|
8
9
|
from .forms import FormContext
|
|
9
10
|
from .space import FESpaceBase
|
|
10
11
|
|
|
11
12
|
# Shared call signatures for kernels/forms
|
|
12
|
-
Array = jnp.ndarray
|
|
13
|
+
Array: TypeAlias = jnp.ndarray
|
|
13
14
|
P = TypeVar("P")
|
|
14
15
|
|
|
15
|
-
|
|
16
|
+
FormKernel: TypeAlias = Callable[[FormContext, P], Array]
|
|
17
|
+
# Form kernels return integrands; element kernels return integrated element arrays.
|
|
18
|
+
Kernel: TypeAlias = Callable[[FormContext, P], Array]
|
|
19
|
+
ResidualInput: TypeAlias = Array | Mapping[str, Array]
|
|
20
|
+
ResidualValue: TypeAlias = Array | Mapping[str, Array]
|
|
16
21
|
ResidualForm = Callable[[FormContext, Array, P], Array]
|
|
22
|
+
ResidualFormLike = Callable[[FormContext, ResidualInput, P], ResidualValue]
|
|
17
23
|
ElementDofMapper = Callable[[Array], Array]
|
|
18
24
|
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from ..solver import FluxSparseMatrix, SparsityPattern
|
|
27
|
+
else:
|
|
28
|
+
FluxSparseMatrix = Any
|
|
29
|
+
SparsityPattern = Any
|
|
30
|
+
|
|
31
|
+
SparseCOO: TypeAlias = tuple[Array, Array, Array, int]
|
|
32
|
+
LinearCOO: TypeAlias = tuple[Array, Array, int]
|
|
33
|
+
JacobianReturn: TypeAlias = Union[Array, FluxSparseMatrix, SparseCOO]
|
|
34
|
+
BilinearReturn: TypeAlias = Union[Array, FluxSparseMatrix, SparseCOO]
|
|
35
|
+
LinearReturn: TypeAlias = Union[Array, LinearCOO]
|
|
36
|
+
MassReturn: TypeAlias = Union[FluxSparseMatrix, Array]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class ElementBilinearKernel(Protocol):
    """Callable taking one element's form context and returning its element matrix.

    The batched assembler documents the result shape as (n_ldofs, n_ldofs).
    """

    def __call__(self, ctx: FormContext) -> Array:
        """Evaluate the kernel for a single element context."""
        ...
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class ElementLinearKernel(Protocol):
    """Callable taking one element's form context and returning its element vector.

    The batched assembler documents the result shape as (n_ldofs,).
    """

    def __call__(self, ctx: FormContext) -> Array:
        """Evaluate the kernel for a single element context."""
        ...
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class ElementResidualKernel(Protocol):
    """Callable taking ``(ctx, u_elem)`` and returning the element residual vector.

    The batched assembler documents the result shape as (n_ldofs,).
    """

    def __call__(self, ctx: FormContext, u_elem: Array) -> Array:
        """Evaluate the residual for one element context and its local dof values."""
        ...
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class ElementJacobianKernel(Protocol):
    """Callable taking ``(u_elem, ctx)`` and returning the element Jacobian matrix.

    Note the argument order: local dof values come first, the context second.
    The batched assembler documents the result shape as (n_ldofs, n_ldofs).
    """

    def __call__(self, u_elem: Array, ctx: FormContext) -> Array:
        """Evaluate the Jacobian for one element's local dof values and context."""
        ...
|
|
53
|
+
|
|
54
|
+
ElementKernel: TypeAlias = (
|
|
55
|
+
ElementBilinearKernel
|
|
56
|
+
| ElementLinearKernel
|
|
57
|
+
| ElementResidualKernel
|
|
58
|
+
| ElementJacobianKernel
|
|
59
|
+
)
|
|
60
|
+
def _get_pattern(space: SpaceLike, *, with_idx: bool) -> SparsityPattern | None:
    """Return ``space.get_sparsity_pattern(with_idx=with_idx)`` when the space offers it.

    Spaces without a ``get_sparsity_pattern`` attribute yield ``None`` so that
    callers can fall back to building row/column indices themselves.
    """
    if not hasattr(space, "get_sparsity_pattern"):
        return None
    return space.get_sparsity_pattern(with_idx=with_idx)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _get_elem_rows(space: SpaceLike) -> Array:
    """Return the flat per-element row indices used to scatter element vectors.

    Uses ``space.get_elem_rows()`` when the space provides it; otherwise
    flattens ``space.elem_dofs`` with ``reshape(-1)``.
    """
    if not hasattr(space, "get_elem_rows"):
        return space.elem_dofs.reshape(-1)
    return space.get_elem_rows()
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def chunk_pad_stats(n_elems: int, n_chunks: Optional[int]) -> dict[str, int | float | None]:
    """
    Compute padding overhead for chunked assembly.

    Args:
        n_elems: number of elements to assemble (coerced with ``int``).
        n_chunks: requested number of chunks, or ``None`` for no chunking.
            Values larger than ``n_elems`` are clamped to ``n_elems``.

    Returns:
        dict with ``chunk_size`` (elements per chunk, ``None`` when not
        chunking), ``pad`` (elements added to fill the last chunk), ``n_pad``
        (``n_elems + pad``) and ``pad_ratio`` (``pad / n_elems``).

    Raises:
        ValueError: if chunking is requested with a non-positive ``n_chunks``.
            This matches the check the assembly functions perform themselves;
            previously 0 raised ZeroDivisionError and negative values produced
            a negative chunk size.
    """
    n_elems = int(n_elems)
    # Nothing to chunk: keep the historical "no padding" answer, even for n_elems <= 0.
    if n_chunks is None or n_elems <= 0:
        return {"chunk_size": None, "pad": 0, "n_pad": n_elems, "pad_ratio": 0.0}

    n_chunks = int(n_chunks)
    if n_chunks <= 0:
        raise ValueError("n_chunks must be a positive integer.")

    # Never use more chunks than elements; each chunk holds ceil(n_elems / n_chunks).
    n_chunks = min(n_chunks, n_elems)
    chunk_size = (n_elems + n_chunks - 1) // n_chunks
    pad = (-n_elems) % chunk_size
    n_pad = n_elems + pad
    pad_ratio = float(pad) / float(n_elems)  # n_elems > 0 is guaranteed here
    return {"chunk_size": int(chunk_size), "pad": int(pad), "n_pad": int(n_pad), "pad_ratio": pad_ratio}
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _maybe_trace_pad(
    stats: dict[str, int | float | None], *, n_chunks: Optional[int], pad_trace: bool
) -> None:
    """Print a one-line ``[pad]`` summary of chunk padding, if requested.

    Output is produced only when ``pad_trace`` is truthy,
    ``jax.core.trace_ctx.is_top_level()`` returns true, and ``n_chunks`` is
    not ``None``. ``stats`` must provide ``chunk_size``, ``pad`` and a numeric
    ``pad_ratio``.
    """
    # NOTE(review): the top-level check presumably suppresses printing while
    # JAX is tracing (jit/vmap) -- confirm against the JAX version in use.
    should_report = pad_trace and jax.core.trace_ctx.is_top_level()
    if should_report and n_chunks is not None:
        fields = (
            f"n_chunks={int(n_chunks)}",
            f"chunk_size={stats['chunk_size']}",
            f"pad={stats['pad']}",
            f"pad_ratio={stats['pad_ratio']:.4f}",
        )
        print("[pad]", *fields, flush=True)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class BatchedAssembler:
    """
    Element-batched assembler bound to one space and one set of element data.

    Element contexts and the element-to-dof connectivity are stored once; each
    ``assemble_*`` method maps a per-element kernel over them with ``jax.vmap``.
    An optional per-element ``mask`` multiplies the element results, so padded
    elements can be zeroed while every input keeps a fixed shape.
    """

    def __init__(
        self,
        space: SpaceLike,
        elem_data: FormContext,
        elem_dofs: Array,
        *,
        pattern: SparsityPattern | None = None,
    ) -> None:
        """
        Store the space, the batched element contexts and the connectivity.

        ``n_elems`` is taken from ``elem_dofs.shape[0]``; ``n_ldofs`` and
        ``n_dofs`` are read from ``space``. When ``pattern`` is given it
        supplies the row/column indices used for matrix assembly.
        """
        self.space = space
        self.elem_data = elem_data
        self.elem_dofs = elem_dofs
        self.pattern = pattern
        self.n_elems = int(elem_dofs.shape[0])
        self.n_ldofs = int(space.n_ldofs)
        self.n_dofs = int(space.n_dofs)
        # COO indices built on first use; only needed when no pattern is set.
        self._rows: Array | None = None
        self._cols: Array | None = None

    @classmethod
    def from_space(
        cls,
        space: SpaceLike,
        *,
        dep: jnp.ndarray | None = None,
        pattern: SparsityPattern | None = None,
    ) -> "BatchedAssembler":
        """Build an assembler from ``space.build_form_contexts(dep=dep)`` and ``space.elem_dofs``."""
        return cls(space, space.build_form_contexts(dep=dep), space.elem_dofs, pattern=pattern)

    # ------------------------------------------------------------------ helpers

    def _clamp_active(self, n_active: int) -> int:
        """Clamp ``n_active`` into ``[0, n_elems]``."""
        return max(0, min(int(n_active), self.n_elems))

    def _contexts_for(self, dep: jnp.ndarray | None) -> FormContext:
        """Return the stored contexts, or freshly built ones when ``dep`` is given."""
        if dep is None:
            return self.elem_data
        return self.space.build_form_contexts(dep=dep)

    @staticmethod
    def _masked(values: Array, mask: Array | None, trailing_axes: int) -> Array:
        """Multiply ``values`` by ``mask`` expanded with ``trailing_axes`` new axes."""
        if mask is None:
            return values
        expand = (slice(None),) + (None,) * trailing_axes
        return values * jnp.asarray(mask)[expand]

    def _scatter_to_dofs(self, elem_values: Array) -> Array:
        """Sum flattened per-element values into a global vector of length ``n_dofs``."""
        return jax.ops.segment_sum(
            elem_values.reshape(-1), self.elem_dofs.reshape(-1), self.n_dofs
        )

    # ---------------------------------------------------------------- utilities

    def make_mask(self, n_active: int) -> Array:
        """Return a float mask of length ``n_elems``: 1.0 for the first ``n_active`` entries, else 0.0."""
        active = self._clamp_active(n_active)
        flags: np.ndarray = (np.arange(self.n_elems) < active).astype(float)
        return jnp.asarray(flags)

    def slice(self, n_active: int) -> "BatchedAssembler":
        """Return a new assembler over the first ``n_active`` elements (no pattern attached)."""
        active = self._clamp_active(n_active)
        head_data = jax.tree_util.tree_map(lambda leaf: leaf[:active], self.elem_data)
        return BatchedAssembler(self.space, head_data, self.elem_dofs[:active], pattern=None)

    def _rows_cols(self) -> tuple[Array, Array]:
        """Return flat COO (rows, cols): from the pattern if set, else built once and cached."""
        pattern = self.pattern
        if pattern is not None:
            return pattern.rows, pattern.cols
        if self._rows is None or self._cols is None:
            dofs = self.elem_dofs
            width = int(dofs.shape[1])
            # Row index repeats each dof `width` times; column index tiles the dof list.
            self._rows = jnp.repeat(dofs, width, axis=1).reshape(-1)
            self._cols = jnp.tile(dofs, (1, width)).reshape(-1)
        return self._rows, self._cols

    # ----------------------------------------------------------------- bilinear

    def assemble_bilinear_with_kernel(
        self, kernel: ElementBilinearKernel, *, mask: Array | None = None
    ) -> FluxSparseMatrix:
        """
        Assemble a sparse matrix from an element kernel.

        kernel(ctx) -> (n_ldofs, n_ldofs)
        """
        from ..solver import FluxSparseMatrix

        elem_mats = self._masked(jax.vmap(kernel)(self.elem_data), mask, 2)
        data = elem_mats.reshape(-1)
        if self.pattern is None:
            rows, cols = self._rows_cols()
            return FluxSparseMatrix(rows, cols, data, n_dofs=self.n_dofs)
        return FluxSparseMatrix(self.pattern, data)

    def assemble_bilinear(
        self,
        form: FormKernel[P],
        params: P,
        *,
        mask: Array | None = None,
        kernel: ElementBilinearKernel | None = None,
        jit: bool = True,
    ) -> FluxSparseMatrix:
        """Assemble a bilinear form; builds the element kernel from ``form`` unless ``kernel`` is given."""
        element_kernel = (
            kernel if kernel is not None else make_element_bilinear_kernel(form, params, jit=jit)
        )
        return self.assemble_bilinear_with_kernel(element_kernel, mask=mask)

    # ------------------------------------------------------------------- linear

    def assemble_linear_with_kernel(
        self,
        kernel: ElementLinearKernel,
        *,
        mask: Array | None = None,
        dep: jnp.ndarray | None = None,
    ) -> Array:
        """
        Assemble a global vector from an element kernel.

        kernel(ctx) -> (n_ldofs,)
        """
        contexts = self._contexts_for(dep)
        elem_vecs = self._masked(jax.vmap(kernel)(contexts), mask, 1)
        return self._scatter_to_dofs(elem_vecs)

    def assemble_linear(
        self,
        form: FormKernel[P],
        params: P,
        *,
        mask: Array | None = None,
        dep: jnp.ndarray | None = None,
        kernel: ElementLinearKernel | None = None,
    ) -> Array:
        """
        Assemble a linear form into a global vector.

        If ``kernel`` is given it is used directly. Otherwise ``form(ctx, params)``
        is summed over its leading axis, weighted by ``ctx.w * ctx.test.detJ``
        unless ``form._includes_measure`` is truthy.
        """
        if kernel is not None:
            return self.assemble_linear_with_kernel(kernel, mask=mask, dep=dep)

        contexts = self._contexts_for(dep)
        premultiplied = getattr(form, "_includes_measure", False)

        def integrate(ctx: FormContext):
            values = form(ctx, params)
            if not premultiplied:
                values = values * (ctx.w * ctx.test.detJ)[:, None]
            return values.sum(axis=0)

        elem_vecs = self._masked(jax.vmap(integrate)(contexts), mask, 1)
        return self._scatter_to_dofs(elem_vecs)

    # --------------------------------------------------------------------- mass

    def assemble_mass_matrix(
        self, *, mask: Array | None = None, lumped: bool = False
    ) -> MassReturn:
        """
        Assemble the mass matrix from ``ctx.test.N`` and ``ctx.w * ctx.test.detJ``.

        With ``lumped=True`` the entries are summed per row and returned as a
        vector of length ``n_dofs``; otherwise a FluxSparseMatrix is returned.
        """
        from ..solver import FluxSparseMatrix

        n_ldofs = self.n_ldofs

        def element_mass(ctx: FormContext):
            shape_vals = ctx.test.N
            block = jnp.einsum("qa,qb->qab", shape_vals, shape_vals)
            if hasattr(ctx.test, "value_dim"):
                dim = int(ctx.test.value_dim)
                eye = jnp.eye(dim, dtype=shape_vals.dtype)
                block = block[:, :, :, None, None] * eye[None, None, None, :, :]
                # NOTE(review): the (q, a, b, i, j) array is reshaped without a
                # transpose -- confirm this matches the space's local dof ordering.
                block = block.reshape(block.shape[0], n_ldofs, n_ldofs)
            weights = ctx.w * ctx.test.detJ
            return jnp.einsum("qab,q->ab", block, weights)

        elem_mats = self._masked(jax.vmap(element_mass)(self.elem_data), mask, 2)
        data = elem_mats.reshape(-1)
        rows, cols = self._rows_cols()

        if not lumped:
            return FluxSparseMatrix(rows, cols, data, n_dofs=self.n_dofs)
        # Row-sum lumping: accumulate every entry onto its row index.
        return jnp.zeros((self.n_dofs,), dtype=data.dtype).at[rows].add(data)

    # ----------------------------------------------------------------- residual

    def assemble_residual_with_kernel(
        self, kernel: ElementResidualKernel, u: Array, *, mask: Array | None = None
    ) -> Array:
        """
        Assemble a global residual vector from an element kernel.

        kernel(ctx, u_elem) -> (n_ldofs,)
        """
        local_u = jnp.asarray(u)[self.elem_dofs]
        elem_vecs = self._masked(jax.vmap(kernel)(self.elem_data, local_u), mask, 1)
        return self._scatter_to_dofs(elem_vecs)

    def assemble_residual(
        self,
        res_form: ResidualForm[P],
        u: Array,
        params: P,
        *,
        mask: Array | None = None,
        kernel: ElementResidualKernel | None = None,
    ) -> Array:
        """Assemble a residual; builds the element kernel from ``res_form`` unless ``kernel`` is given."""
        element_kernel = (
            kernel if kernel is not None else make_element_residual_kernel(res_form, params)
        )
        return self.assemble_residual_with_kernel(element_kernel, u, mask=mask)

    # ----------------------------------------------------------------- jacobian

    def assemble_jacobian_with_kernel(
        self,
        kernel: ElementJacobianKernel,
        u: Array,
        *,
        mask: Array | None = None,
        sparse: bool = True,
        return_flux_matrix: bool = False,
    ) -> JacobianReturn:
        """
        Assemble the Jacobian from an element kernel.

        kernel(u_elem, ctx) -> (n_ldofs, n_ldofs)

        Returns a FluxSparseMatrix when ``sparse`` and ``return_flux_matrix``,
        a ``(rows, cols, data, n_dofs)`` tuple when only ``sparse``, and a
        dense ``(n_dofs, n_dofs)`` array otherwise.
        """
        from ..solver import FluxSparseMatrix  # local import to avoid circular

        local_u = jnp.asarray(u)[self.elem_dofs]
        elem_jacs = self._masked(jax.vmap(kernel)(local_u, self.elem_data), mask, 2)
        data = elem_jacs.reshape(-1)

        if sparse:
            if return_flux_matrix and self.pattern is not None:
                return FluxSparseMatrix(self.pattern, data)
            # Yields the pattern's rows/cols when a pattern is set.
            rows, cols = self._rows_cols()
            if return_flux_matrix:
                return FluxSparseMatrix(rows, cols, data, n_dofs=self.n_dofs)
            return rows, cols, data, self.n_dofs

        # Dense path: scatter-add entries into a flattened (n_dofs * n_dofs) buffer.
        rows, cols = self._rows_cols()
        flat_idx = (
            rows.astype(jnp.int64) * int(self.n_dofs) + cols.astype(jnp.int64)
        ).astype(INDEX_DTYPE)
        scatter_dims = jax.lax.ScatterDimensionNumbers(
            update_window_dims=(),
            inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,),
        )
        dense_flat = jnp.zeros(self.n_dofs * self.n_dofs, dtype=data.dtype)
        dense_flat = jax.lax.scatter_add(dense_flat, flat_idx[:, None], data, scatter_dims)
        return dense_flat.reshape(self.n_dofs, self.n_dofs)

    def assemble_jacobian(
        self,
        res_form: ResidualForm[P],
        u: Array,
        params: P,
        *,
        mask: Array | None = None,
        kernel: ElementJacobianKernel | None = None,
        sparse: bool = True,
        return_flux_matrix: bool = False,
    ) -> JacobianReturn:
        """Assemble a Jacobian; builds the element kernel from ``res_form`` unless ``kernel`` is given."""
        element_kernel = (
            kernel if kernel is not None else make_element_jacobian_kernel(res_form, params)
        )
        return self.assemble_jacobian_with_kernel(
            element_kernel,
            u,
            mask=mask,
            sparse=sparse,
            return_flux_matrix=return_flux_matrix,
        )
|
|
19
360
|
|
|
20
361
|
class SpaceLike(FESpaceBase, Protocol):
    """Structural type used to annotate the ``space`` argument of the assembly routines.

    Declares no members of its own beyond those of ``FESpaceBase``.
    """
|
|
@@ -23,12 +364,12 @@ class SpaceLike(FESpaceBase, Protocol):
|
|
|
23
364
|
|
|
24
365
|
def assemble_bilinear_dense(
|
|
25
366
|
space: SpaceLike,
|
|
26
|
-
kernel:
|
|
367
|
+
kernel: FormKernel[P],
|
|
27
368
|
params: P,
|
|
28
369
|
*,
|
|
29
370
|
sparse: bool = False,
|
|
30
371
|
return_flux_matrix: bool = False,
|
|
31
|
-
):
|
|
372
|
+
) -> BilinearReturn:
|
|
32
373
|
"""
|
|
33
374
|
Similar to scikit-fem's asm(biform, basis).
|
|
34
375
|
kernel: FormContext, params -> (n_ldofs, n_ldofs)
|
|
@@ -47,11 +388,15 @@ def assemble_bilinear_dense(
|
|
|
47
388
|
|
|
48
389
|
# ---- scatter into COO format ----
|
|
49
390
|
# row/col indices (n_elems, n_ldofs, n_ldofs)
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
391
|
+
pat = _get_pattern(space, with_idx=False)
|
|
392
|
+
if pat is None:
|
|
393
|
+
rows = jnp.repeat(elem_dofs, n_ldofs, axis=1) # (n_elems, n_ldofs*n_ldofs)
|
|
394
|
+
cols = jnp.tile(elem_dofs, (1, n_ldofs)) # (n_elems, n_ldofs*n_ldofs)
|
|
395
|
+
rows = rows.reshape(-1)
|
|
396
|
+
cols = cols.reshape(-1)
|
|
397
|
+
else:
|
|
398
|
+
rows = pat.rows
|
|
399
|
+
cols = pat.cols
|
|
55
400
|
data = K_e_all.reshape(-1)
|
|
56
401
|
|
|
57
402
|
# Flatten indices for segment_sum via (row * n_dofs + col)
|
|
@@ -71,18 +416,22 @@ def assemble_bilinear_dense(
|
|
|
71
416
|
|
|
72
417
|
|
|
73
418
|
def assemble_bilinear_form(
|
|
74
|
-
space,
|
|
75
|
-
form,
|
|
76
|
-
params,
|
|
419
|
+
space: SpaceLike,
|
|
420
|
+
form: FormKernel[P],
|
|
421
|
+
params: P,
|
|
77
422
|
*,
|
|
78
|
-
pattern=None,
|
|
79
|
-
|
|
423
|
+
pattern: SparsityPattern | None = None,
|
|
424
|
+
n_chunks: Optional[int] = None, # None -> no chunking
|
|
80
425
|
dep: jnp.ndarray | None = None,
|
|
81
|
-
|
|
426
|
+
kernel: ElementBilinearKernel | None = None,
|
|
427
|
+
jit: bool = True,
|
|
428
|
+
pad_trace: bool = False,
|
|
429
|
+
) -> FluxSparseMatrix:
|
|
82
430
|
"""
|
|
83
431
|
Assemble a sparse bilinear form into a FluxSparseMatrix.
|
|
84
432
|
|
|
85
433
|
Expects form(ctx, params) -> (n_q, n_ldofs, n_ldofs).
|
|
434
|
+
If kernel is provided: kernel(ctx) -> (n_ldofs, n_ldofs).
|
|
86
435
|
"""
|
|
87
436
|
from ..solver import FluxSparseMatrix
|
|
88
437
|
|
|
@@ -104,18 +453,27 @@ def assemble_bilinear_form(
|
|
|
104
453
|
wJ = ctx.w * ctx.test.detJ # (n_q,)
|
|
105
454
|
return (integrand * wJ[:, None, None]).sum(axis=0) # (m, m)
|
|
106
455
|
|
|
456
|
+
if kernel is None:
|
|
457
|
+
kernel = make_element_bilinear_kernel(form, params, jit=jit)
|
|
458
|
+
|
|
107
459
|
# --- no-chunk path: evaluate every element in a single vmap ---
|
|
108
|
-
if
|
|
109
|
-
K_e_all = jax.vmap(
|
|
460
|
+
if n_chunks is None:
|
|
461
|
+
K_e_all = jax.vmap(kernel)(elem_data) # (n_elems, m, m)
|
|
110
462
|
data = K_e_all.reshape(-1)
|
|
111
463
|
return FluxSparseMatrix(pat, data)
|
|
112
464
|
|
|
113
465
|
# --- chunked path ---
|
|
114
466
|
n_elems = space.elem_dofs.shape[0]
|
|
467
|
+
if n_chunks <= 0:
|
|
468
|
+
raise ValueError("n_chunks must be a positive integer.")
|
|
469
|
+
n_chunks = min(int(n_chunks), int(n_elems))
|
|
470
|
+
chunk_size = (n_elems + n_chunks - 1) // n_chunks
|
|
471
|
+
stats = chunk_pad_stats(n_elems, n_chunks)
|
|
472
|
+
_maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
|
|
115
473
|
# Ideally get m from pat (otherwise infer from one element).
|
|
116
474
|
m = getattr(pat, "n_ldofs", None)
|
|
117
475
|
if m is None:
|
|
118
|
-
m =
|
|
476
|
+
m = kernel(jax.tree_util.tree_map(lambda x: x[0], elem_data)).shape[0]
|
|
119
477
|
|
|
120
478
|
# Pad to fixed-size chunks for JIT stability.
|
|
121
479
|
pad = (-n_elems) % chunk_size
|
|
@@ -141,7 +499,7 @@ def assemble_bilinear_form(
|
|
|
141
499
|
lambda x: _slice_first_dim(x, start, chunk_size),
|
|
142
500
|
elem_data_pad,
|
|
143
501
|
)
|
|
144
|
-
Ke = jax.vmap(
|
|
502
|
+
Ke = jax.vmap(kernel)(ctx_chunk) # (chunk, m, m)
|
|
145
503
|
return Ke.reshape(-1) # (chunk*m*m,)
|
|
146
504
|
|
|
147
505
|
data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
|
|
@@ -149,7 +507,13 @@ def assemble_bilinear_form(
|
|
|
149
507
|
return FluxSparseMatrix(pat, data)
|
|
150
508
|
|
|
151
509
|
|
|
152
|
-
def assemble_mass_matrix(
|
|
510
|
+
def assemble_mass_matrix(
|
|
511
|
+
space: SpaceLike,
|
|
512
|
+
*,
|
|
513
|
+
lumped: bool = False,
|
|
514
|
+
n_chunks: Optional[int] = None,
|
|
515
|
+
pad_trace: bool = False,
|
|
516
|
+
) -> MassReturn:
|
|
153
517
|
"""
|
|
154
518
|
Assemble mass matrix M_ij = ∫ N_i N_j dΩ.
|
|
155
519
|
Supports scalar and vector spaces. If lumped=True, rows are summed to diagonal.
|
|
@@ -170,11 +534,17 @@ def assemble_mass_matrix(space: SpaceLike, *, lumped: bool = False, chunk_size:
|
|
|
170
534
|
wJ = ctx.w * ctx.test.detJ
|
|
171
535
|
return jnp.einsum("qab,q->ab", base, wJ)
|
|
172
536
|
|
|
173
|
-
if
|
|
537
|
+
if n_chunks is None:
|
|
174
538
|
M_e_all = jax.vmap(per_element)(ctxs) # (n_elems, n_ldofs, n_ldofs)
|
|
175
539
|
data = M_e_all.reshape(-1)
|
|
176
540
|
else:
|
|
177
541
|
n_elems = space.elem_dofs.shape[0]
|
|
542
|
+
if n_chunks <= 0:
|
|
543
|
+
raise ValueError("n_chunks must be a positive integer.")
|
|
544
|
+
n_chunks = min(int(n_chunks), int(n_elems))
|
|
545
|
+
chunk_size = (n_elems + n_chunks - 1) // n_chunks
|
|
546
|
+
stats = chunk_pad_stats(n_elems, n_chunks)
|
|
547
|
+
_maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
|
|
178
548
|
pad = (-n_elems) % chunk_size
|
|
179
549
|
if pad:
|
|
180
550
|
ctxs_pad = jax.tree_util.tree_map(
|
|
@@ -205,8 +575,13 @@ def assemble_mass_matrix(space: SpaceLike, *, lumped: bool = False, chunk_size:
|
|
|
205
575
|
data = data_chunks.reshape(-1)[: n_elems * n_ldofs * n_ldofs]
|
|
206
576
|
|
|
207
577
|
elem_dofs = space.elem_dofs
|
|
208
|
-
|
|
209
|
-
|
|
578
|
+
pat = _get_pattern(space, with_idx=False)
|
|
579
|
+
if pat is None:
|
|
580
|
+
rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
|
|
581
|
+
cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
|
|
582
|
+
else:
|
|
583
|
+
rows = pat.rows
|
|
584
|
+
cols = pat.cols
|
|
210
585
|
|
|
211
586
|
if lumped:
|
|
212
587
|
n_dofs = space.n_dofs
|
|
@@ -219,15 +594,18 @@ def assemble_mass_matrix(space: SpaceLike, *, lumped: bool = False, chunk_size:
|
|
|
219
594
|
|
|
220
595
|
def assemble_linear_form(
|
|
221
596
|
space: SpaceLike,
|
|
222
|
-
form:
|
|
597
|
+
form: FormKernel[P],
|
|
223
598
|
params: P,
|
|
224
599
|
*,
|
|
600
|
+
kernel: ElementLinearKernel | None = None,
|
|
225
601
|
sparse: bool = False,
|
|
226
|
-
|
|
602
|
+
n_chunks: Optional[int] = None,
|
|
227
603
|
dep: jnp.ndarray | None = None,
|
|
228
|
-
|
|
604
|
+
pad_trace: bool = False,
|
|
605
|
+
) -> LinearReturn:
|
|
229
606
|
"""
|
|
230
607
|
Expects form(ctx, params) -> (n_q, n_ldofs) and integrates Σ_q form * wJ for RHS.
|
|
608
|
+
If kernel is provided: kernel(ctx) -> (n_ldofs,).
|
|
231
609
|
"""
|
|
232
610
|
elem_dofs = space.elem_dofs
|
|
233
611
|
n_dofs = space.n_dofs
|
|
@@ -237,19 +615,28 @@ def assemble_linear_form(
|
|
|
237
615
|
|
|
238
616
|
includes_measure = getattr(form, "_includes_measure", False)
|
|
239
617
|
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
618
|
+
if kernel is None:
|
|
619
|
+
def per_element(ctx: FormContext):
|
|
620
|
+
integrand = form(ctx, params) # (n_q, m)
|
|
621
|
+
if includes_measure:
|
|
622
|
+
return integrand.sum(axis=0)
|
|
623
|
+
wJ = ctx.w * ctx.test.detJ # (n_q,)
|
|
624
|
+
return (integrand * wJ[:, None]).sum(axis=0) # (m,)
|
|
625
|
+
else:
|
|
626
|
+
per_element = kernel
|
|
246
627
|
|
|
247
|
-
if
|
|
628
|
+
if n_chunks is None:
|
|
248
629
|
F_e_all = jax.vmap(per_element)(elem_data) # (n_elems, m)
|
|
249
630
|
data = F_e_all.reshape(-1)
|
|
250
631
|
else:
|
|
251
632
|
n_elems = space.elem_dofs.shape[0]
|
|
252
633
|
m = n_ldofs
|
|
634
|
+
if n_chunks <= 0:
|
|
635
|
+
raise ValueError("n_chunks must be a positive integer.")
|
|
636
|
+
n_chunks = min(int(n_chunks), int(n_elems))
|
|
637
|
+
chunk_size = (n_elems + n_chunks - 1) // n_chunks
|
|
638
|
+
stats = chunk_pad_stats(n_elems, n_chunks)
|
|
639
|
+
_maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
|
|
253
640
|
pad = (-n_elems) % chunk_size
|
|
254
641
|
if pad:
|
|
255
642
|
elem_data_pad = jax.tree_util.tree_map(
|
|
@@ -279,7 +666,7 @@ def assemble_linear_form(
|
|
|
279
666
|
data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
|
|
280
667
|
data = data_chunks.reshape(-1)[: n_elems * m]
|
|
281
668
|
|
|
282
|
-
rows =
|
|
669
|
+
rows = _get_elem_rows(space)
|
|
283
670
|
|
|
284
671
|
if sparse:
|
|
285
672
|
return rows, data, n_dofs
|
|
@@ -288,7 +675,7 @@ def assemble_linear_form(
|
|
|
288
675
|
return F
|
|
289
676
|
|
|
290
677
|
|
|
291
|
-
def assemble_functional(space: SpaceLike, form:
|
|
678
|
+
def assemble_functional(space: SpaceLike, form: FormKernel[P], params: P) -> jnp.ndarray:
|
|
292
679
|
"""
|
|
293
680
|
Assemble scalar functional J = ∫ form(ctx, params) dΩ.
|
|
294
681
|
Expects form(ctx, params) -> (n_q,) or (n_q, 1).
|
|
@@ -318,7 +705,7 @@ def assemble_jacobian_global(
|
|
|
318
705
|
*,
|
|
319
706
|
sparse: bool = False,
|
|
320
707
|
return_flux_matrix: bool = False,
|
|
321
|
-
):
|
|
708
|
+
) -> JacobianReturn:
|
|
322
709
|
"""
|
|
323
710
|
Assemble Jacobian (dR/du) from element residual res_form.
|
|
324
711
|
res_form(ctx, u_elem, params) -> (n_q, n_ldofs)
|
|
@@ -339,11 +726,16 @@ def assemble_jacobian_global(
|
|
|
339
726
|
jac_fun = jax.jacrev(fe_fun, argnums=0)
|
|
340
727
|
|
|
341
728
|
u_elems = u[elem_dofs] # (n_elems, n_ldofs)
|
|
342
|
-
elem_ids = jnp.arange(elem_dofs.shape[0], dtype=
|
|
729
|
+
elem_ids = jnp.arange(elem_dofs.shape[0], dtype=INDEX_DTYPE)
|
|
343
730
|
J_e_all = jax.vmap(jac_fun)(u_elems, elem_data, elem_ids) # (n_elems, m, m)
|
|
344
731
|
|
|
345
|
-
|
|
346
|
-
|
|
732
|
+
pat = _get_pattern(space, with_idx=False)
|
|
733
|
+
if pat is None:
|
|
734
|
+
rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
|
|
735
|
+
cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
|
|
736
|
+
else:
|
|
737
|
+
rows = pat.rows
|
|
738
|
+
cols = pat.cols
|
|
347
739
|
data = J_e_all.reshape(-1)
|
|
348
740
|
|
|
349
741
|
if sparse:
|
|
@@ -358,7 +750,7 @@ def assemble_jacobian_global(
|
|
|
358
750
|
return K_flat.reshape(n_dofs, n_dofs)
|
|
359
751
|
|
|
360
752
|
|
|
361
|
-
def
|
|
753
|
+
def assemble_jacobian_elementwise(
|
|
362
754
|
space: SpaceLike,
|
|
363
755
|
res_form: ResidualForm[P],
|
|
364
756
|
u: jnp.ndarray,
|
|
@@ -366,9 +758,9 @@ def assemble_jacobian_elementwise_xla(
|
|
|
366
758
|
*,
|
|
367
759
|
sparse: bool = False,
|
|
368
760
|
return_flux_matrix: bool = False,
|
|
369
|
-
):
|
|
761
|
+
) -> JacobianReturn:
|
|
370
762
|
"""
|
|
371
|
-
Assemble Jacobian with element kernels
|
|
763
|
+
Assemble Jacobian with element kernels via vmap + scatter_add.
|
|
372
764
|
Recompiles if n_dofs changes, but independent of element count.
|
|
373
765
|
"""
|
|
374
766
|
from ..solver import FluxSparseMatrix # local import to avoid circular
|
|
@@ -388,8 +780,13 @@ def assemble_jacobian_elementwise_xla(
|
|
|
388
780
|
u_elems = u[elem_dofs]
|
|
389
781
|
J_e_all = jax.vmap(jac_fun)(u_elems, ctxs) # (n_elems, m, m)
|
|
390
782
|
|
|
391
|
-
|
|
392
|
-
|
|
783
|
+
pat = _get_pattern(space, with_idx=False)
|
|
784
|
+
if pat is None:
|
|
785
|
+
rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
|
|
786
|
+
cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
|
|
787
|
+
else:
|
|
788
|
+
rows = pat.rows
|
|
789
|
+
cols = pat.cols
|
|
393
790
|
data = J_e_all.reshape(-1)
|
|
394
791
|
|
|
395
792
|
if sparse:
|
|
@@ -406,7 +803,7 @@ def assemble_jacobian_elementwise_xla(
|
|
|
406
803
|
)
|
|
407
804
|
K_flat = jnp.zeros(n_entries, dtype=data.dtype)
|
|
408
805
|
K_flat = jax.lax.scatter_add(K_flat, idx[:, None], data, sdn)
|
|
409
|
-
return K_flat.reshape(
|
|
806
|
+
return K_flat.reshape(n_dofs, n_dofs)
|
|
410
807
|
|
|
411
808
|
|
|
412
809
|
def assemble_residual_global(
|
|
@@ -416,7 +813,7 @@ def assemble_residual_global(
|
|
|
416
813
|
params: P,
|
|
417
814
|
*,
|
|
418
815
|
sparse: bool = False
|
|
419
|
-
):
|
|
816
|
+
) -> LinearReturn:
|
|
420
817
|
"""
|
|
421
818
|
Assemble residual vector that depends on u.
|
|
422
819
|
form(ctx, u_elem, params) -> (n_q, n_ldofs)
|
|
@@ -435,10 +832,10 @@ def assemble_residual_global(
|
|
|
435
832
|
fe = (integrand * wJ[:, None]).sum(axis=0)
|
|
436
833
|
return fe
|
|
437
834
|
|
|
438
|
-
elem_ids = jnp.arange(elem_dofs.shape[0], dtype=
|
|
835
|
+
elem_ids = jnp.arange(elem_dofs.shape[0], dtype=INDEX_DTYPE)
|
|
439
836
|
F_e_all = jax.vmap(per_element)(elem_data, elem_dofs, elem_ids) # (n_elems, m)
|
|
440
837
|
|
|
441
|
-
rows =
|
|
838
|
+
rows = _get_elem_rows(space)
|
|
442
839
|
data = F_e_all.reshape(-1)
|
|
443
840
|
|
|
444
841
|
if sparse:
|
|
@@ -448,16 +845,16 @@ def assemble_residual_global(
|
|
|
448
845
|
return F
|
|
449
846
|
|
|
450
847
|
|
|
451
|
-
def
|
|
848
|
+
def assemble_residual_elementwise(
|
|
452
849
|
space: SpaceLike,
|
|
453
850
|
res_form: ResidualForm[P],
|
|
454
851
|
u: jnp.ndarray,
|
|
455
852
|
params: P,
|
|
456
853
|
*,
|
|
457
854
|
sparse: bool = False,
|
|
458
|
-
):
|
|
855
|
+
) -> LinearReturn:
|
|
459
856
|
"""
|
|
460
|
-
Assemble residual using element kernels
|
|
857
|
+
Assemble residual using element kernels via vmap + scatter_add.
|
|
461
858
|
Recompiles if n_dofs changes, but independent of element count.
|
|
462
859
|
"""
|
|
463
860
|
elem_dofs = space.elem_dofs
|
|
@@ -471,7 +868,7 @@ def assemble_residual_elementwise_xla(
|
|
|
471
868
|
|
|
472
869
|
u_elems = u[elem_dofs]
|
|
473
870
|
F_e_all = jax.vmap(per_element)(ctxs, u_elems) # (n_elems, m)
|
|
474
|
-
rows =
|
|
871
|
+
rows = _get_elem_rows(space)
|
|
475
872
|
data = F_e_all.reshape(-1)
|
|
476
873
|
|
|
477
874
|
if sparse:
|
|
@@ -487,7 +884,44 @@ def assemble_residual_elementwise_xla(
|
|
|
487
884
|
return F
|
|
488
885
|
|
|
489
886
|
|
|
490
|
-
|
|
887
|
+
# Backward compatibility aliases (prefer assemble_*_elementwise).
|
|
888
|
+
assemble_jacobian_elementwise_xla = assemble_jacobian_elementwise
|
|
889
|
+
assemble_residual_elementwise_xla = assemble_residual_elementwise
|
|
890
|
+
|
|
891
|
+
|
|
892
|
+
def make_element_bilinear_kernel(
    form: FormKernel[P], params: P, *, jit: bool = True
) -> ElementBilinearKernel:
    """Build an element kernel ``ctx -> Ke`` from a bilinear form.

    The kernel evaluates ``form(ctx, params)`` and sums it over the leading
    axis. Unless ``form._includes_measure`` is truthy, the integrand is first
    multiplied by ``ctx.w * ctx.test.detJ`` along that axis. The kernel is
    wrapped in ``jax.jit`` when ``jit`` is true.
    """

    def integrate(ctx: FormContext):
        values = form(ctx, params)
        if not getattr(form, "_includes_measure", False):
            # Weight each leading-axis slice by w * detJ before summing.
            values = values * (ctx.w * ctx.test.detJ)[:, None, None]
        return values.sum(axis=0)

    if jit:
        return jax.jit(integrate)
    return integrate
|
|
905
|
+
|
|
906
|
+
|
|
907
|
+
def make_element_linear_kernel(
    form: FormKernel[P], params: P, *, jit: bool = True
) -> ElementLinearKernel:
    """
    Build the per-element kernel of a linear form: ``kernel(ctx) -> fe``.

    The kernel evaluates ``form(ctx, params)`` and sums the result over its
    leading axis. Unless ``form`` carries a truthy ``_includes_measure``
    attribute, every leading-axis row is first scaled by
    ``ctx.w * ctx.test.detJ``.

    Args:
        form: Callable invoked as ``form(ctx, params)``; its result is indexed
            as a 2-axis array (leading axis presumably the quadrature points).
        params: Forwarded unchanged to ``form`` on every kernel call.
        jit: When true the kernel is wrapped in ``jax.jit``; otherwise the
            plain Python closure is returned.

    Returns:
        The element kernel, taking a single ``ctx`` argument.
    """

    def per_element(ctx: FormContext):
        values = form(ctx, params)
        if not getattr(form, "_includes_measure", False):
            # The form did not apply the measure itself: weight each row by
            # w * detJ before reducing.
            measure = ctx.w * ctx.test.detJ
            values = values * measure[:, None]
        return values.sum(axis=0)

    if jit:
        return jax.jit(per_element)
    return per_element
|
|
920
|
+
|
|
921
|
+
|
|
922
|
+
def make_element_residual_kernel(
|
|
923
|
+
res_form: ResidualForm[P], params: P
|
|
924
|
+
) -> ElementResidualKernel:
|
|
491
925
|
"""Jitted element residual kernel: (ctx, u_elem) -> fe."""
|
|
492
926
|
|
|
493
927
|
def per_element(ctx: FormContext, u_elem: jnp.ndarray):
|
|
@@ -500,7 +934,9 @@ def make_element_residual_kernel(res_form: ResidualForm[P], params: P):
|
|
|
500
934
|
return jax.jit(per_element)
|
|
501
935
|
|
|
502
936
|
|
|
503
|
-
def make_element_jacobian_kernel(
|
|
937
|
+
def make_element_jacobian_kernel(
|
|
938
|
+
res_form: ResidualForm[P], params: P
|
|
939
|
+
) -> ElementJacobianKernel:
|
|
504
940
|
"""Jitted element Jacobian kernel: (ctx, u_elem) -> Ke."""
|
|
505
941
|
|
|
506
942
|
def fe_fun(u_elem, ctx: FormContext):
|
|
@@ -513,7 +949,9 @@ def make_element_jacobian_kernel(res_form: ResidualForm[P], params: P):
|
|
|
513
949
|
return jax.jit(jax.jacrev(fe_fun, argnums=0))
|
|
514
950
|
|
|
515
951
|
|
|
516
|
-
def element_residual(
|
|
952
|
+
def element_residual(
|
|
953
|
+
res_form: ResidualFormLike[P], ctx: FormContext, u_elem: ResidualInput, params: P
|
|
954
|
+
) -> ResidualValue:
|
|
517
955
|
"""
|
|
518
956
|
Element residual vector r_e(u_e) = sum_q w_q * detJ_q * res_form(ctx, u_e, params).
|
|
519
957
|
Returns shape (n_ldofs,).
|
|
@@ -538,7 +976,9 @@ def element_residual(res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.nd
|
|
|
538
976
|
return jax.tree_util.tree_map(lambda x: jnp.einsum("qa,q->a", x, ctx.w * ctx.test.detJ), integrand)
|
|
539
977
|
|
|
540
978
|
|
|
541
|
-
def element_jacobian(
|
|
979
|
+
def element_jacobian(
|
|
980
|
+
res_form: ResidualFormLike[P], ctx: FormContext, u_elem: ResidualInput, params: P
|
|
981
|
+
) -> ResidualValue:
|
|
542
982
|
"""
|
|
543
983
|
Element Jacobian K_e = d r_e / d u_e (AD via jacfwd), shape (n_ldofs, n_ldofs).
|
|
544
984
|
"""
|
|
@@ -548,7 +988,46 @@ def element_jacobian(res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.nd
|
|
|
548
988
|
return jax.jacfwd(_r_elem)(u_elem)
|
|
549
989
|
|
|
550
990
|
|
|
551
|
-
def
|
|
991
|
+
def make_element_kernel(
    form: FormKernel[P] | ResidualForm[P],
    params: P,
    *,
    kind: Literal["bilinear", "linear", "residual", "jacobian"],
    jit: bool = True,
) -> ElementKernel:
    """
    Unified entry point for element kernels.

    Dispatches to the dedicated ``make_element_*_kernel`` builder for the
    requested kind.

    kind (matched case-insensitively):
      - "bilinear": kernel(ctx) -> (n_ldofs, n_ldofs)
      - "linear": kernel(ctx) -> (n_ldofs,)
      - "residual": kernel(ctx, u_elem) -> (n_ldofs,)
      - "jacobian": kernel(u_elem, ctx) -> (n_ldofs, n_ldofs)

    jit:
      Forwarded to the bilinear and linear builders only. The residual and
      jacobian builders do not take this flag and always return jitted
      kernels, so ``jit`` has no effect for those kinds.

    Raises:
      ValueError: if ``kind`` is not one of the four supported names.
    """
    kind = cast(Literal["bilinear", "linear", "residual", "jacobian"], kind.lower())
    if kind == "bilinear":
        form_bilinear = cast(FormKernel[P], form)
        return make_element_bilinear_kernel(form_bilinear, params, jit=jit)
    if kind == "linear":
        # Delegate to the dedicated builder rather than keeping a second copy
        # of the linear quadrature loop in this dispatcher.
        form_linear = cast(FormKernel[P], form)
        return make_element_linear_kernel(form_linear, params, jit=jit)
    if kind == "residual":
        form_residual = cast(ResidualForm[P], form)
        return make_element_residual_kernel(form_residual, params)
    if kind == "jacobian":
        form_residual = cast(ResidualForm[P], form)
        return make_element_jacobian_kernel(form_residual, params)
    raise ValueError(f"Unknown kernel kind: {kind}")
|
|
1028
|
+
|
|
1029
|
+
|
|
1030
|
+
def make_sparsity_pattern(space: SpaceLike, *, with_idx: bool = True) -> SparsityPattern:
|
|
552
1031
|
"""
|
|
553
1032
|
Build a SparsityPattern (rows/cols[/idx]) that is independent of the solution.
|
|
554
1033
|
NOTE: rows/cols ordering matches assemble_jacobian_values(...).reshape(-1)
|
|
@@ -557,24 +1036,24 @@ def make_sparsity_pattern(space: SpaceLike, *, with_idx: bool = True):
|
|
|
557
1036
|
"""
|
|
558
1037
|
from ..solver import SparsityPattern # local import to avoid circular
|
|
559
1038
|
|
|
560
|
-
elem_dofs = jnp.asarray(space.elem_dofs, dtype=
|
|
1039
|
+
elem_dofs = jnp.asarray(space.elem_dofs, dtype=INDEX_DTYPE)
|
|
561
1040
|
n_dofs = int(space.n_dofs)
|
|
562
1041
|
n_ldofs = int(space.n_ldofs)
|
|
563
1042
|
|
|
564
|
-
rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1).astype(
|
|
565
|
-
cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1).astype(
|
|
1043
|
+
rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1).astype(INDEX_DTYPE)
|
|
1044
|
+
cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1).astype(INDEX_DTYPE)
|
|
566
1045
|
|
|
567
1046
|
key = rows.astype(jnp.int64) * jnp.int64(n_dofs) + cols.astype(jnp.int64)
|
|
568
|
-
order = jnp.argsort(key).astype(
|
|
1047
|
+
order = jnp.argsort(key).astype(INDEX_DTYPE)
|
|
569
1048
|
rows_sorted = rows[order]
|
|
570
1049
|
cols_sorted = cols[order]
|
|
571
|
-
counts = jnp.bincount(rows_sorted, length=n_dofs).astype(
|
|
572
|
-
indptr_j = jnp.concatenate([jnp.array([0], dtype=
|
|
573
|
-
indices_j = cols_sorted.astype(
|
|
1050
|
+
counts = jnp.bincount(rows_sorted, length=n_dofs).astype(INDEX_DTYPE)
|
|
1051
|
+
indptr_j = jnp.concatenate([jnp.array([0], dtype=INDEX_DTYPE), jnp.cumsum(counts)])
|
|
1052
|
+
indices_j = cols_sorted.astype(INDEX_DTYPE)
|
|
574
1053
|
perm = order
|
|
575
1054
|
|
|
576
1055
|
if with_idx:
|
|
577
|
-
idx = (rows.astype(jnp.int64) * jnp.int64(n_dofs) + cols.astype(jnp.int64)).astype(
|
|
1056
|
+
idx = (rows.astype(jnp.int64) * jnp.int64(n_dofs) + cols.astype(jnp.int64)).astype(INDEX_DTYPE)
|
|
578
1057
|
return SparsityPattern(
|
|
579
1058
|
rows=rows,
|
|
580
1059
|
cols=cols,
|
|
@@ -601,8 +1080,10 @@ def assemble_jacobian_values(
|
|
|
601
1080
|
u: jnp.ndarray,
|
|
602
1081
|
params: P,
|
|
603
1082
|
*,
|
|
604
|
-
kernel=None,
|
|
605
|
-
|
|
1083
|
+
kernel: ElementJacobianKernel | None = None,
|
|
1084
|
+
n_chunks: Optional[int] = None,
|
|
1085
|
+
pad_trace: bool = False,
|
|
1086
|
+
) -> Array:
|
|
606
1087
|
"""
|
|
607
1088
|
Assemble only the numeric values for the Jacobian (pattern-free).
|
|
608
1089
|
"""
|
|
@@ -610,8 +1091,49 @@ def assemble_jacobian_values(
|
|
|
610
1091
|
ker = kernel if kernel is not None else make_element_jacobian_kernel(res_form, params)
|
|
611
1092
|
|
|
612
1093
|
u_elems = u[space.elem_dofs]
|
|
613
|
-
|
|
614
|
-
|
|
1094
|
+
if n_chunks is None:
|
|
1095
|
+
J_e_all = jax.vmap(ker)(u_elems, ctxs) # (n_elem, m, m)
|
|
1096
|
+
return J_e_all.reshape(-1)
|
|
1097
|
+
|
|
1098
|
+
n_elems = int(u_elems.shape[0])
|
|
1099
|
+
if n_chunks <= 0:
|
|
1100
|
+
raise ValueError("n_chunks must be a positive integer.")
|
|
1101
|
+
n_chunks = min(int(n_chunks), int(n_elems))
|
|
1102
|
+
chunk_size = (n_elems + n_chunks - 1) // n_chunks
|
|
1103
|
+
stats = chunk_pad_stats(n_elems, n_chunks)
|
|
1104
|
+
_maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
|
|
1105
|
+
pad = (-n_elems) % chunk_size
|
|
1106
|
+
if pad:
|
|
1107
|
+
ctxs_pad = jax.tree_util.tree_map(
|
|
1108
|
+
lambda x: jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0),
|
|
1109
|
+
ctxs,
|
|
1110
|
+
)
|
|
1111
|
+
u_elems_pad = jnp.concatenate([u_elems, jnp.repeat(u_elems[-1:], pad, axis=0)], axis=0)
|
|
1112
|
+
else:
|
|
1113
|
+
ctxs_pad = ctxs
|
|
1114
|
+
u_elems_pad = u_elems
|
|
1115
|
+
|
|
1116
|
+
n_pad = n_elems + pad
|
|
1117
|
+
n_chunks = n_pad // chunk_size
|
|
1118
|
+
m = int(space.n_ldofs)
|
|
1119
|
+
|
|
1120
|
+
def _slice_first_dim(x, start, size):
|
|
1121
|
+
start_idx = (start,) + (0,) * (x.ndim - 1)
|
|
1122
|
+
slice_sizes = (size,) + x.shape[1:]
|
|
1123
|
+
return jax.lax.dynamic_slice(x, start_idx, slice_sizes)
|
|
1124
|
+
|
|
1125
|
+
def chunk_fn(i):
|
|
1126
|
+
start = i * chunk_size
|
|
1127
|
+
ctx_chunk = jax.tree_util.tree_map(
|
|
1128
|
+
lambda x: _slice_first_dim(x, start, chunk_size),
|
|
1129
|
+
ctxs_pad,
|
|
1130
|
+
)
|
|
1131
|
+
u_chunk = _slice_first_dim(u_elems_pad, start, chunk_size)
|
|
1132
|
+
J_e = jax.vmap(ker)(u_chunk, ctx_chunk)
|
|
1133
|
+
return J_e.reshape(-1)
|
|
1134
|
+
|
|
1135
|
+
data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
|
|
1136
|
+
return data_chunks.reshape(-1)[: n_elems * m * m]
|
|
615
1137
|
|
|
616
1138
|
|
|
617
1139
|
def assemble_residual_scatter(
|
|
@@ -620,9 +1142,11 @@ def assemble_residual_scatter(
|
|
|
620
1142
|
u: jnp.ndarray,
|
|
621
1143
|
params: P,
|
|
622
1144
|
*,
|
|
623
|
-
kernel=None,
|
|
1145
|
+
kernel: ElementResidualKernel | None = None,
|
|
624
1146
|
sparse: bool = False,
|
|
625
|
-
|
|
1147
|
+
n_chunks: Optional[int] = None,
|
|
1148
|
+
pad_trace: bool = False,
|
|
1149
|
+
) -> LinearReturn:
|
|
626
1150
|
"""
|
|
627
1151
|
Assemble residual using jitted element kernel + vmap + scatter_add.
|
|
628
1152
|
Avoids Python loops; good for JIT stability.
|
|
@@ -633,20 +1157,62 @@ def assemble_residual_scatter(
|
|
|
633
1157
|
"""
|
|
634
1158
|
elem_dofs = space.elem_dofs
|
|
635
1159
|
n_dofs = space.n_dofs
|
|
636
|
-
if
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
1160
|
+
if jax.core.trace_ctx.is_top_level():
|
|
1161
|
+
if np.max(elem_dofs) >= n_dofs:
|
|
1162
|
+
raise ValueError("elem_dofs contains index outside n_dofs")
|
|
1163
|
+
if np.min(elem_dofs) < 0:
|
|
1164
|
+
raise ValueError("elem_dofs contains negative index")
|
|
640
1165
|
ctxs = space.build_form_contexts()
|
|
641
1166
|
ker = kernel if kernel is not None else make_element_residual_kernel(res_form, params)
|
|
642
1167
|
|
|
643
1168
|
u_elems = u[elem_dofs]
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
1169
|
+
if n_chunks is None:
|
|
1170
|
+
elem_res = jax.vmap(ker)(ctxs, u_elems) # (n_elem, n_ldofs)
|
|
1171
|
+
else:
|
|
1172
|
+
n_elems = int(u_elems.shape[0])
|
|
1173
|
+
if n_chunks <= 0:
|
|
1174
|
+
raise ValueError("n_chunks must be a positive integer.")
|
|
1175
|
+
n_chunks = min(int(n_chunks), int(n_elems))
|
|
1176
|
+
chunk_size = (n_elems + n_chunks - 1) // n_chunks
|
|
1177
|
+
stats = chunk_pad_stats(n_elems, n_chunks)
|
|
1178
|
+
_maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
|
|
1179
|
+
pad = (-n_elems) % chunk_size
|
|
1180
|
+
if pad:
|
|
1181
|
+
ctxs_pad = jax.tree_util.tree_map(
|
|
1182
|
+
lambda x: jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0),
|
|
1183
|
+
ctxs,
|
|
1184
|
+
)
|
|
1185
|
+
u_elems_pad = jnp.concatenate([u_elems, jnp.repeat(u_elems[-1:], pad, axis=0)], axis=0)
|
|
1186
|
+
else:
|
|
1187
|
+
ctxs_pad = ctxs
|
|
1188
|
+
u_elems_pad = u_elems
|
|
1189
|
+
|
|
1190
|
+
n_pad = n_elems + pad
|
|
1191
|
+
n_chunks = n_pad // chunk_size
|
|
1192
|
+
|
|
1193
|
+
def _slice_first_dim(x, start, size):
|
|
1194
|
+
start_idx = (start,) + (0,) * (x.ndim - 1)
|
|
1195
|
+
slice_sizes = (size,) + x.shape[1:]
|
|
1196
|
+
return jax.lax.dynamic_slice(x, start_idx, slice_sizes)
|
|
648
1197
|
|
|
649
|
-
|
|
1198
|
+
def chunk_fn(i):
|
|
1199
|
+
start = i * chunk_size
|
|
1200
|
+
ctx_chunk = jax.tree_util.tree_map(
|
|
1201
|
+
lambda x: _slice_first_dim(x, start, chunk_size),
|
|
1202
|
+
ctxs_pad,
|
|
1203
|
+
)
|
|
1204
|
+
u_chunk = _slice_first_dim(u_elems_pad, start, chunk_size)
|
|
1205
|
+
res_chunk = jax.vmap(ker)(ctx_chunk, u_chunk)
|
|
1206
|
+
return res_chunk.reshape(-1)
|
|
1207
|
+
|
|
1208
|
+
data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
|
|
1209
|
+
elem_res = data_chunks.reshape(-1)[: n_elems * int(space.n_ldofs)].reshape(n_elems, -1)
|
|
1210
|
+
if jax.core.trace_ctx.is_top_level():
|
|
1211
|
+
if not bool(jax.block_until_ready(jnp.all(jnp.isfinite(elem_res)))):
|
|
1212
|
+
bad = int(jnp.count_nonzero(~jnp.isfinite(elem_res)))
|
|
1213
|
+
raise RuntimeError(f"[assemble_residual_scatter] elem_res nonfinite: {bad}")
|
|
1214
|
+
|
|
1215
|
+
rows = _get_elem_rows(space)
|
|
650
1216
|
data = elem_res.reshape(-1)
|
|
651
1217
|
|
|
652
1218
|
if sparse:
|
|
@@ -668,11 +1234,13 @@ def assemble_jacobian_scatter(
|
|
|
668
1234
|
u: jnp.ndarray,
|
|
669
1235
|
params: P,
|
|
670
1236
|
*,
|
|
671
|
-
kernel=None,
|
|
1237
|
+
kernel: ElementJacobianKernel | None = None,
|
|
672
1238
|
sparse: bool = False,
|
|
673
1239
|
return_flux_matrix: bool = False,
|
|
674
|
-
pattern=None,
|
|
675
|
-
|
|
1240
|
+
pattern: SparsityPattern | None = None,
|
|
1241
|
+
n_chunks: Optional[int] = None,
|
|
1242
|
+
pad_trace: bool = False,
|
|
1243
|
+
) -> JacobianReturn:
|
|
676
1244
|
"""
|
|
677
1245
|
Assemble Jacobian using jitted element kernel + vmap + scatter_add.
|
|
678
1246
|
If a SparsityPattern is provided, rows/cols are reused without regeneration.
|
|
@@ -682,7 +1250,9 @@ def assemble_jacobian_scatter(
|
|
|
682
1250
|
from ..solver import FluxSparseMatrix # local import to avoid circular
|
|
683
1251
|
|
|
684
1252
|
pat = pattern if pattern is not None else make_sparsity_pattern(space, with_idx=not sparse)
|
|
685
|
-
data = assemble_jacobian_values(
|
|
1253
|
+
data = assemble_jacobian_values(
|
|
1254
|
+
space, res_form, u, params, kernel=kernel, n_chunks=n_chunks, pad_trace=pad_trace
|
|
1255
|
+
)
|
|
686
1256
|
|
|
687
1257
|
if sparse:
|
|
688
1258
|
if return_flux_matrix:
|
|
@@ -691,7 +1261,7 @@ def assemble_jacobian_scatter(
|
|
|
691
1261
|
|
|
692
1262
|
idx = pat.idx
|
|
693
1263
|
if idx is None:
|
|
694
|
-
idx = (pat.rows.astype(jnp.int64) * int(pat.n_dofs) + pat.cols.astype(jnp.int64)).astype(
|
|
1264
|
+
idx = (pat.rows.astype(jnp.int64) * int(pat.n_dofs) + pat.cols.astype(jnp.int64)).astype(INDEX_DTYPE)
|
|
695
1265
|
|
|
696
1266
|
n_entries = pat.n_dofs * pat.n_dofs
|
|
697
1267
|
sdn = jax.lax.ScatterDimensionNumbers(
|
|
@@ -708,12 +1278,21 @@ def assemble_jacobian_scatter(
|
|
|
708
1278
|
def assemble_residual(
    space: SpaceLike,
    form: ResidualForm[P],
    u: jnp.ndarray,
    params: P,
    *,
    kernel: ElementResidualKernel | None = None,
    sparse: bool = False,
    n_chunks: Optional[int] = None,
    pad_trace: bool = False,
) -> LinearReturn:
    """
    Assemble the global residual vector (scatter-based).

    Thin front-end over ``assemble_residual_scatter``: every argument is
    forwarded unchanged and its result is returned as-is.

    If kernel is provided: kernel(ctx, u_elem) -> (n_ldofs,).
    """
    return assemble_residual_scatter(
        space,
        form,
        u,
        params,
        kernel=kernel,
        sparse=sparse,
        n_chunks=n_chunks,
        pad_trace=pad_trace,
    )
|
|
717
1296
|
|
|
718
1297
|
|
|
719
1298
|
def assemble_jacobian(
|
|
@@ -722,19 +1301,28 @@ def assemble_jacobian(
|
|
|
722
1301
|
u: jnp.ndarray,
|
|
723
1302
|
params: P,
|
|
724
1303
|
*,
|
|
1304
|
+
kernel: ElementJacobianKernel | None = None,
|
|
725
1305
|
sparse: bool = True,
|
|
726
1306
|
return_flux_matrix: bool = False,
|
|
727
|
-
pattern=None,
|
|
728
|
-
|
|
729
|
-
|
|
1307
|
+
pattern: SparsityPattern | None = None,
|
|
1308
|
+
n_chunks: Optional[int] = None,
|
|
1309
|
+
pad_trace: bool = False,
|
|
1310
|
+
) -> JacobianReturn:
|
|
1311
|
+
"""
|
|
1312
|
+
Assemble the global Jacobian (scatter-based).
|
|
1313
|
+
If kernel is provided: kernel(u_elem, ctx) -> (n_ldofs, n_ldofs).
|
|
1314
|
+
"""
|
|
730
1315
|
return assemble_jacobian_scatter(
|
|
731
1316
|
space,
|
|
732
1317
|
res_form,
|
|
733
1318
|
u,
|
|
734
1319
|
params,
|
|
1320
|
+
kernel=kernel,
|
|
735
1321
|
sparse=sparse,
|
|
736
1322
|
return_flux_matrix=return_flux_matrix,
|
|
737
1323
|
pattern=pattern,
|
|
1324
|
+
n_chunks=n_chunks,
|
|
1325
|
+
pad_trace=pad_trace,
|
|
738
1326
|
)
|
|
739
1327
|
|
|
740
1328
|
|
|
@@ -748,13 +1336,19 @@ def scalar_body_force_form(ctx: FormContext, load: float) -> jnp.ndarray:
|
|
|
748
1336
|
return load * ctx.test.N # (n_q, n_ldofs)
|
|
749
1337
|
|
|
750
1338
|
|
|
751
|
-
|
|
1339
|
+
# Tag the form function object with metadata attributes.
# NOTE(review): presumably read by form-dispatch code elsewhere to identify
# this as a linear, volume-domain form — confirm against the consumers.
scalar_body_force_form._ff_kind = "linear"  # type: ignore[attr-defined]
scalar_body_force_form._ff_domain = "volume"  # type: ignore[attr-defined]
|
|
1341
|
+
|
|
1342
|
+
|
|
1343
|
+
def make_scalar_body_force_form(body_force: Callable[[Array], Array]) -> FormKernel[Any]:
    """
    Build a scalar linear form from a callable f(x_q) -> (n_q,).

    The returned form has the signature ``(ctx, params)``; ``params`` is
    accepted but ignored. It evaluates ``body_force(ctx.x_q)``, appends a
    trailing axis and multiplies by ``ctx.test.N``. The form is tagged with
    ``_ff_kind = "linear"`` and ``_ff_domain = "volume"``.
    """

    def _form(ctx: FormContext, _params):
        # Trailing axis lets the per-point values broadcast against ctx.test.N.
        source = body_force(ctx.x_q)[..., None]
        return source * ctx.test.N

    for attr_name, attr_value in (("_ff_kind", "linear"), ("_ff_domain", "volume")):
        setattr(_form, attr_name, attr_value)
    return _form
|
|
759
1353
|
|
|
760
1354
|
|
|
@@ -762,7 +1356,7 @@ def make_scalar_body_force_form(body_force):
|
|
|
762
1356
|
constant_body_force_form = scalar_body_force_form
|
|
763
1357
|
|
|
764
1358
|
|
|
765
|
-
def _check_structured_box_connectivity():
|
|
1359
|
+
def _check_structured_box_connectivity() -> None:
|
|
766
1360
|
"""Quick connectivity check for nx=2, ny=1, nz=1 (non-structured order)."""
|
|
767
1361
|
box = StructuredHexBox(nx=2, ny=1, nz=1, lx=2.0, ly=1.0, lz=1.0)
|
|
768
1362
|
mesh = box.build()
|
|
@@ -775,7 +1369,7 @@ def _check_structured_box_connectivity():
|
|
|
775
1369
|
[0, 1, 4, 3, 6, 7, 10, 9], # element at i=0
|
|
776
1370
|
[1, 2, 5, 4, 7, 8, 11, 10], # element at i=1
|
|
777
1371
|
],
|
|
778
|
-
dtype=
|
|
1372
|
+
dtype=INDEX_DTYPE,
|
|
779
1373
|
)
|
|
780
1374
|
max_diff = int(jnp.max(jnp.abs(mesh.conn - expected_conn)))
|
|
781
1375
|
print("StructuredHexBox nx=2,ny=1,nz=1 conn matches expected:", max_diff == 0)
|