fluxfem 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +136 -161
- fluxfem/core/__init__.py +172 -41
- fluxfem/core/assembly.py +676 -91
- fluxfem/core/basis.py +73 -52
- fluxfem/core/context_types.py +36 -0
- fluxfem/core/dtypes.py +9 -1
- fluxfem/core/forms.py +15 -1
- fluxfem/core/mixed_assembly.py +263 -0
- fluxfem/core/mixed_space.py +348 -0
- fluxfem/core/mixed_weakform.py +97 -0
- fluxfem/core/solver.py +2 -0
- fluxfem/core/space.py +262 -17
- fluxfem/core/weakform.py +1503 -312
- fluxfem/helpers_wf.py +53 -0
- fluxfem/mesh/__init__.py +54 -2
- fluxfem/mesh/base.py +322 -8
- fluxfem/mesh/contact.py +825 -0
- fluxfem/mesh/dtypes.py +12 -0
- fluxfem/mesh/hex.py +18 -16
- fluxfem/mesh/io.py +8 -4
- fluxfem/mesh/mortar.py +3907 -0
- fluxfem/mesh/supermesh.py +316 -0
- fluxfem/mesh/surface.py +22 -4
- fluxfem/mesh/tet.py +10 -4
- fluxfem/physics/diffusion.py +3 -0
- fluxfem/physics/elasticity/hyperelastic.py +3 -0
- fluxfem/physics/elasticity/linear.py +9 -2
- fluxfem/solver/__init__.py +42 -2
- fluxfem/solver/bc.py +38 -2
- fluxfem/solver/block_matrix.py +132 -0
- fluxfem/solver/block_system.py +454 -0
- fluxfem/solver/cg.py +115 -33
- fluxfem/solver/dirichlet.py +334 -4
- fluxfem/solver/newton.py +237 -60
- fluxfem/solver/petsc.py +439 -0
- fluxfem/solver/preconditioner.py +106 -0
- fluxfem/solver/result.py +18 -0
- fluxfem/solver/solve_runner.py +168 -1
- fluxfem/solver/solver.py +12 -1
- fluxfem/solver/sparse.py +124 -9
- fluxfem-0.2.0.dist-info/METADATA +303 -0
- fluxfem-0.2.0.dist-info/RECORD +59 -0
- fluxfem-0.1.3.dist-info/METADATA +0 -125
- fluxfem-0.1.3.dist-info/RECORD +0 -47
- {fluxfem-0.1.3.dist-info → fluxfem-0.2.0.dist-info}/LICENSE +0 -0
- {fluxfem-0.1.3.dist-info → fluxfem-0.2.0.dist-info}/WHEEL +0 -0
fluxfem/core/assembly.py
CHANGED
@@ -1,10 +1,11 @@
 from __future__ import annotations
-from typing import Callable, Protocol, TypeVar,
+from typing import Any, Callable, Literal, Optional, Protocol, TYPE_CHECKING, TypeAlias, TypeVar, Union
 import numpy as np
 import jax
 import jax.numpy as jnp

 from ..mesh import HexMesh, StructuredHexBox
+from .dtypes import INDEX_DTYPE
 from .forms import FormContext
 from .space import FESpaceBase

@@ -16,6 +17,341 @@ Kernel = Callable[[FormContext, P], Array]
 ResidualForm = Callable[[FormContext, Array, P], Array]
 ElementDofMapper = Callable[[Array], Array]

+if TYPE_CHECKING:
+    from ..solver import FluxSparseMatrix, SparsityPattern
+else:
+    FluxSparseMatrix = Any
+    SparsityPattern = Any
+
+SparseCOO: TypeAlias = tuple[Array, Array, Array, int]
+LinearCOO: TypeAlias = tuple[Array, Array, int]
+JacobianReturn: TypeAlias = Union[Array, FluxSparseMatrix, SparseCOO]
+BilinearReturn: TypeAlias = Union[Array, FluxSparseMatrix, SparseCOO]
+LinearReturn: TypeAlias = Union[Array, LinearCOO]
+MassReturn: TypeAlias = Union[FluxSparseMatrix, Array]
+
+
+class ElementBilinearKernel(Protocol):
+    def __call__(self, ctx: FormContext) -> Array: ...
+
+
+class ElementLinearKernel(Protocol):
+    def __call__(self, ctx: FormContext) -> Array: ...
+
+
+class ElementResidualKernel(Protocol):
+    def __call__(self, ctx: FormContext, u_elem: Array) -> Array: ...
+
+
+class ElementJacobianKernel(Protocol):
+    def __call__(self, u_elem: Array, ctx: FormContext) -> Array: ...
+
+ElementKernel: TypeAlias = (
+    ElementBilinearKernel
+    | ElementLinearKernel
+    | ElementResidualKernel
+    | ElementJacobianKernel
+)
+def _get_pattern(space: SpaceLike, *, with_idx: bool) -> SparsityPattern | None:
+    if hasattr(space, "get_sparsity_pattern"):
+        return space.get_sparsity_pattern(with_idx=with_idx)
+    return None
+
+
+def _get_elem_rows(space: SpaceLike) -> Array:
+    if hasattr(space, "get_elem_rows"):
+        return space.get_elem_rows()
+    return space.elem_dofs.reshape(-1)
+
+
+def chunk_pad_stats(n_elems: int, n_chunks: Optional[int]) -> dict[str, int | float | None]:
+    """
+    Compute padding overhead for chunked assembly.
+    Returns dict with chunk_size, pad, n_pad, and pad_ratio.
+    """
+    n_elems = int(n_elems)
+    if n_chunks is None or n_elems <= 0:
+        return {"chunk_size": None, "pad": 0, "n_pad": n_elems, "pad_ratio": 0.0}
+    n_chunks = min(int(n_chunks), n_elems)
+    chunk_size = (n_elems + n_chunks - 1) // n_chunks
+    pad = (-n_elems) % chunk_size
+    n_pad = n_elems + pad
+    pad_ratio = float(pad) / float(n_elems) if n_elems else 0.0
+    return {"chunk_size": int(chunk_size), "pad": int(pad), "n_pad": int(n_pad), "pad_ratio": pad_ratio}
+
+
+def _maybe_trace_pad(
+    stats: dict[str, int | float | None], *, n_chunks: Optional[int], pad_trace: bool
+) -> None:
+    if not pad_trace or not jax.core.trace_ctx.is_top_level():
+        return
+    if n_chunks is None:
+        return
+    print(
+        "[pad]",
+        f"n_chunks={int(n_chunks)}",
+        f"chunk_size={stats['chunk_size']}",
+        f"pad={stats['pad']}",
+        f"pad_ratio={stats['pad_ratio']:.4f}",
+        flush=True,
+    )
+
+
+class BatchedAssembler:
+    """
+    Assemble on a fixed space with optional masking to keep shapes static.
+
+    Use `mask` to zero padded elements while keeping input shapes fixed.
+    """
+
+    def __init__(
+        self,
+        space: SpaceLike,
+        elem_data: Any,
+        elem_dofs: Array,
+        *,
+        pattern: SparsityPattern | None = None,
+    ) -> None:
+        self.space = space
+        self.elem_data = elem_data
+        self.elem_dofs = elem_dofs
+        self.n_elems = int(elem_dofs.shape[0])
+        self.n_ldofs = int(space.n_ldofs)
+        self.n_dofs = int(space.n_dofs)
+        self.pattern = pattern
+        self._rows = None
+        self._cols = None
+
+    @classmethod
+    def from_space(
+        cls,
+        space: SpaceLike,
+        *,
+        dep: jnp.ndarray | None = None,
+        pattern: SparsityPattern | None = None,
+    ) -> "BatchedAssembler":
+        elem_data = space.build_form_contexts(dep=dep)
+        return cls(space, elem_data, space.elem_dofs, pattern=pattern)
+
+    def make_mask(self, n_active: int) -> Array:
+        n_active = max(0, min(int(n_active), self.n_elems))
+        mask = np.zeros((self.n_elems,), dtype=float)
+        if n_active:
+            mask[:n_active] = 1.0
+        return jnp.asarray(mask)
+
+    def slice(self, n_active: int) -> "BatchedAssembler":
+        n_active = max(0, min(int(n_active), self.n_elems))
+        elem_data = jax.tree_util.tree_map(lambda x: x[:n_active], self.elem_data)
+        elem_dofs = self.elem_dofs[:n_active]
+        return BatchedAssembler(self.space, elem_data, elem_dofs, pattern=None)
+
+    def _rows_cols(self) -> tuple[Array, Array]:
+        if self.pattern is not None:
+            return self.pattern.rows, self.pattern.cols
+        if self._rows is None or self._cols is None:
+            elem_dofs = self.elem_dofs
+            n_ldofs = int(elem_dofs.shape[1])
+            rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
+            cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
+            self._rows = rows
+            self._cols = cols
+        return self._rows, self._cols
+
+    def assemble_bilinear_with_kernel(
+        self, kernel: ElementBilinearKernel, *, mask: Array | None = None
+    ) -> FluxSparseMatrix:
+        """
+        kernel(ctx) -> (n_ldofs, n_ldofs)
+        """
+        from ..solver import FluxSparseMatrix
+
+        Ke = jax.vmap(kernel)(self.elem_data)
+        if mask is not None:
+            Ke = Ke * jnp.asarray(mask)[:, None, None]
+        data = Ke.reshape(-1)
+        if self.pattern is not None:
+            return FluxSparseMatrix(self.pattern, data)
+        rows, cols = self._rows_cols()
+        return FluxSparseMatrix(rows, cols, data, n_dofs=self.n_dofs)
+
+    def assemble_bilinear(
+        self,
+        form: Kernel[P],
+        params: P,
+        *,
+        mask: Array | None = None,
+        kernel: ElementBilinearKernel | None = None,
+        jit: bool = True,
+    ) -> FluxSparseMatrix:
+        if kernel is None:
+            kernel = make_element_bilinear_kernel(form, params, jit=jit)
+        return self.assemble_bilinear_with_kernel(kernel, mask=mask)
+
+    def assemble_linear_with_kernel(
+        self,
+        kernel: ElementLinearKernel,
+        *,
+        mask: Array | None = None,
+        dep: jnp.ndarray | None = None,
+    ) -> Array:
+        """
+        kernel(ctx) -> (n_ldofs,)
+        """
+        elem_data = self.elem_data if dep is None else self.space.build_form_contexts(dep=dep)
+        Fe = jax.vmap(kernel)(elem_data)
+        if mask is not None:
+            Fe = Fe * jnp.asarray(mask)[:, None]
+        rows = self.elem_dofs.reshape(-1)
+        data = Fe.reshape(-1)
+        return jax.ops.segment_sum(data, rows, self.n_dofs)
+
+    def assemble_linear(
+        self,
+        form: Kernel[P],
+        params: P,
+        *,
+        mask: Array | None = None,
+        dep: jnp.ndarray | None = None,
+        kernel: ElementLinearKernel | None = None,
+    ) -> Array:
+        if kernel is not None:
+            return self.assemble_linear_with_kernel(kernel, mask=mask, dep=dep)
+        elem_data = self.elem_data if dep is None else self.space.build_form_contexts(dep=dep)
+        includes_measure = getattr(form, "_includes_measure", False)
+
+        def per_element(ctx: FormContext):
+            integrand = form(ctx, params)
+            if includes_measure:
+                return integrand.sum(axis=0)
+            wJ = ctx.w * ctx.test.detJ
+            return (integrand * wJ[:, None]).sum(axis=0)
+
+        Fe = jax.vmap(per_element)(elem_data)
+        if mask is not None:
+            Fe = Fe * jnp.asarray(mask)[:, None]
+        rows = self.elem_dofs.reshape(-1)
+        data = Fe.reshape(-1)
+        return jax.ops.segment_sum(data, rows, self.n_dofs)
+
+    def assemble_mass_matrix(
+        self, *, mask: Array | None = None, lumped: bool = False
+    ) -> MassReturn:
+        from ..solver import FluxSparseMatrix
+
+        n_ldofs = self.n_ldofs
+
+        def per_element(ctx: FormContext):
+            N = ctx.test.N
+            base = jnp.einsum("qa,qb->qab", N, N)
+            if hasattr(ctx.test, "value_dim"):
+                vd = int(ctx.test.value_dim)
+                I = jnp.eye(vd, dtype=N.dtype)
+                base = base[:, :, :, None, None] * I[None, None, None, :, :]
+                base = base.reshape(base.shape[0], n_ldofs, n_ldofs)
+            wJ = ctx.w * ctx.test.detJ
+            return jnp.einsum("qab,q->ab", base, wJ)
+
+        Me = jax.vmap(per_element)(self.elem_data)
+        if mask is not None:
+            Me = Me * jnp.asarray(mask)[:, None, None]
+        data = Me.reshape(-1)
+        rows, cols = self._rows_cols()
+
+        if lumped:
+            M = jnp.zeros((self.n_dofs,), dtype=data.dtype)
+            M = M.at[rows].add(data)
+            return M
+
+        return FluxSparseMatrix(rows, cols, data, n_dofs=self.n_dofs)
+
+    def assemble_residual_with_kernel(
+        self, kernel: ElementResidualKernel, u: Array, *, mask: Array | None = None
+    ) -> Array:
+        """
+        kernel(ctx, u_elem) -> (n_ldofs,)
+        """
+        u_elems = jnp.asarray(u)[self.elem_dofs]
+        elem_res = jax.vmap(kernel)(self.elem_data, u_elems)
+        if mask is not None:
+            elem_res = elem_res * jnp.asarray(mask)[:, None]
+        rows = self.elem_dofs.reshape(-1)
+        data = elem_res.reshape(-1)
+        return jax.ops.segment_sum(data, rows, self.n_dofs)
+
+    def assemble_residual(
+        self,
+        res_form: ResidualForm[P],
+        u: Array,
+        params: P,
+        *,
+        mask: Array | None = None,
+        kernel: ElementResidualKernel | None = None,
+    ) -> Array:
+        if kernel is None:
+            kernel = make_element_residual_kernel(res_form, params)
+        return self.assemble_residual_with_kernel(kernel, u, mask=mask)
+
+    def assemble_jacobian_with_kernel(
+        self,
+        kernel: ElementJacobianKernel,
+        u: Array,
+        *,
+        mask: Array | None = None,
+        sparse: bool = True,
+        return_flux_matrix: bool = False,
+    ) -> JacobianReturn:
+        """
+        kernel(u_elem, ctx) -> (n_ldofs, n_ldofs)
+        """
+        from ..solver import FluxSparseMatrix  # local import to avoid circular
+
+        u_elems = jnp.asarray(u)[self.elem_dofs]
+        J_e = jax.vmap(kernel)(u_elems, self.elem_data)
+        if mask is not None:
+            J_e = J_e * jnp.asarray(mask)[:, None, None]
+        data = J_e.reshape(-1)
+        if sparse:
+            if self.pattern is not None:
+                if return_flux_matrix:
+                    return FluxSparseMatrix(self.pattern, data)
+                return self.pattern.rows, self.pattern.cols, data, self.n_dofs
+            rows, cols = self._rows_cols()
+            if return_flux_matrix:
+                return FluxSparseMatrix(rows, cols, data, n_dofs=self.n_dofs)
+            return rows, cols, data, self.n_dofs
+        rows, cols = self._rows_cols()
+        idx = (rows.astype(jnp.int64) * int(self.n_dofs) + cols.astype(jnp.int64)).astype(INDEX_DTYPE)
+        n_entries = self.n_dofs * self.n_dofs
+        sdn = jax.lax.ScatterDimensionNumbers(
+            update_window_dims=(),
+            inserted_window_dims=(0,),
+            scatter_dims_to_operand_dims=(0,),
+        )
+        K_flat = jnp.zeros(n_entries, dtype=data.dtype)
+        K_flat = jax.lax.scatter_add(K_flat, idx[:, None], data, sdn)
+        return K_flat.reshape(self.n_dofs, self.n_dofs)
+
+    def assemble_jacobian(
+        self,
+        res_form: ResidualForm[P],
+        u: Array,
+        params: P,
+        *,
+        mask: Array | None = None,
+        kernel: ElementJacobianKernel | None = None,
+        sparse: bool = True,
+        return_flux_matrix: bool = False,
+    ) -> JacobianReturn:
+        if kernel is None:
+            kernel = make_element_jacobian_kernel(res_form, params)
+        return self.assemble_jacobian_with_kernel(
+            kernel,
+            u,
+            mask=mask,
+            sparse=sparse,
+            return_flux_matrix=return_flux_matrix,
+        )

 class SpaceLike(FESpaceBase, Protocol):
     pass
@@ -28,7 +364,7 @@ def assemble_bilinear_dense(
     *,
     sparse: bool = False,
     return_flux_matrix: bool = False,
-):
+) -> BilinearReturn:
     """
     Similar to scikit-fem's asm(biform, basis).
     kernel: FormContext, params -> (n_ldofs, n_ldofs)
@@ -47,11 +383,15 @@ def assemble_bilinear_dense(

     # ---- scatter into COO format ----
     # row/col indices (n_elems, n_ldofs, n_ldofs)
-
-
-
-
-
+    pat = _get_pattern(space, with_idx=False)
+    if pat is None:
+        rows = jnp.repeat(elem_dofs, n_ldofs, axis=1)  # (n_elems, n_ldofs*n_ldofs)
+        cols = jnp.tile(elem_dofs, (1, n_ldofs))  # (n_elems, n_ldofs*n_ldofs)
+        rows = rows.reshape(-1)
+        cols = cols.reshape(-1)
+    else:
+        rows = pat.rows
+        cols = pat.cols
     data = K_e_all.reshape(-1)

     # Flatten indices for segment_sum via (row * n_dofs + col)
@@ -71,18 +411,22 @@ def assemble_bilinear_dense(


 def assemble_bilinear_form(
-    space,
-    form,
-    params,
+    space: SpaceLike,
+    form: Kernel[P],
+    params: P,
     *,
-    pattern=None,
-
+    pattern: SparsityPattern | None = None,
+    n_chunks: Optional[int] = None,  # None -> no chunking
     dep: jnp.ndarray | None = None,
-
+    kernel: ElementBilinearKernel | None = None,
+    jit: bool = True,
+    pad_trace: bool = False,
+) -> FluxSparseMatrix:
     """
     Assemble a sparse bilinear form into a FluxSparseMatrix.

     Expects form(ctx, params) -> (n_q, n_ldofs, n_ldofs).
+    If kernel is provided: kernel(ctx) -> (n_ldofs, n_ldofs).
     """
     from ..solver import FluxSparseMatrix

@@ -104,18 +448,27 @@ def assemble_bilinear_form(
         wJ = ctx.w * ctx.test.detJ  # (n_q,)
         return (integrand * wJ[:, None, None]).sum(axis=0)  # (m, m)

+    if kernel is None:
+        kernel = make_element_bilinear_kernel(form, params, jit=jit)
+
     # --- no-chunk path (your current implementation) ---
-    if
-        K_e_all = jax.vmap(
+    if n_chunks is None:
+        K_e_all = jax.vmap(kernel)(elem_data)  # (n_elems, m, m)
         data = K_e_all.reshape(-1)
         return FluxSparseMatrix(pat, data)

     # --- chunked path ---
     n_elems = space.elem_dofs.shape[0]
+    if n_chunks <= 0:
+        raise ValueError("n_chunks must be a positive integer.")
+    n_chunks = min(int(n_chunks), int(n_elems))
+    chunk_size = (n_elems + n_chunks - 1) // n_chunks
+    stats = chunk_pad_stats(n_elems, n_chunks)
+    _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
     # Ideally get m from pat (otherwise infer from one element).
     m = getattr(pat, "n_ldofs", None)
     if m is None:
-        m =
+        m = kernel(jax.tree_util.tree_map(lambda x: x[0], elem_data)).shape[0]

     # Pad to fixed-size chunks for JIT stability.
     pad = (-n_elems) % chunk_size
@@ -141,7 +494,7 @@ def assemble_bilinear_form(
             lambda x: _slice_first_dim(x, start, chunk_size),
             elem_data_pad,
         )
-        Ke = jax.vmap(
+        Ke = jax.vmap(kernel)(ctx_chunk)  # (chunk, m, m)
         return Ke.reshape(-1)  # (chunk*m*m,)

     data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
@@ -149,7 +502,13 @@ def assemble_bilinear_form(
     return FluxSparseMatrix(pat, data)


-def assemble_mass_matrix(
+def assemble_mass_matrix(
+    space: SpaceLike,
+    *,
+    lumped: bool = False,
+    n_chunks: Optional[int] = None,
+    pad_trace: bool = False,
+) -> MassReturn:
     """
     Assemble mass matrix M_ij = ∫ N_i N_j dΩ.
     Supports scalar and vector spaces. If lumped=True, rows are summed to diagonal.
@@ -170,11 +529,17 @@ def assemble_mass_matrix(space: SpaceLike, *, lumped: bool = False, chunk_size:
         wJ = ctx.w * ctx.test.detJ
         return jnp.einsum("qab,q->ab", base, wJ)

-    if
+    if n_chunks is None:
         M_e_all = jax.vmap(per_element)(ctxs)  # (n_elems, n_ldofs, n_ldofs)
         data = M_e_all.reshape(-1)
     else:
         n_elems = space.elem_dofs.shape[0]
+        if n_chunks <= 0:
+            raise ValueError("n_chunks must be a positive integer.")
+        n_chunks = min(int(n_chunks), int(n_elems))
+        chunk_size = (n_elems + n_chunks - 1) // n_chunks
+        stats = chunk_pad_stats(n_elems, n_chunks)
+        _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
         pad = (-n_elems) % chunk_size
         if pad:
             ctxs_pad = jax.tree_util.tree_map(
@@ -205,8 +570,13 @@ def assemble_mass_matrix(space: SpaceLike, *, lumped: bool = False, chunk_size:
         data = data_chunks.reshape(-1)[: n_elems * n_ldofs * n_ldofs]

     elem_dofs = space.elem_dofs
-
-
+    pat = _get_pattern(space, with_idx=False)
+    if pat is None:
+        rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
+        cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
+    else:
+        rows = pat.rows
+        cols = pat.cols

     if lumped:
         n_dofs = space.n_dofs
@@ -222,12 +592,15 @@ def assemble_linear_form(
     form: Kernel[P],
     params: P,
     *,
+    kernel: ElementLinearKernel | None = None,
     sparse: bool = False,
-
+    n_chunks: Optional[int] = None,
     dep: jnp.ndarray | None = None,
-
+    pad_trace: bool = False,
+) -> LinearReturn:
     """
     Expects form(ctx, params) -> (n_q, n_ldofs) and integrates Σ_q form * wJ for RHS.
+    If kernel is provided: kernel(ctx) -> (n_ldofs,).
     """
     elem_dofs = space.elem_dofs
     n_dofs = space.n_dofs
@@ -237,19 +610,28 @@ def assemble_linear_form(

     includes_measure = getattr(form, "_includes_measure", False)

-
-
-
-
-
-
+    if kernel is None:
+        def per_element(ctx: FormContext):
+            integrand = form(ctx, params)  # (n_q, m)
+            if includes_measure:
+                return integrand.sum(axis=0)
+            wJ = ctx.w * ctx.test.detJ  # (n_q,)
+            return (integrand * wJ[:, None]).sum(axis=0)  # (m,)
+    else:
+        per_element = kernel

-    if
+    if n_chunks is None:
         F_e_all = jax.vmap(per_element)(elem_data)  # (n_elems, m)
         data = F_e_all.reshape(-1)
     else:
         n_elems = space.elem_dofs.shape[0]
         m = n_ldofs
+        if n_chunks <= 0:
+            raise ValueError("n_chunks must be a positive integer.")
+        n_chunks = min(int(n_chunks), int(n_elems))
+        chunk_size = (n_elems + n_chunks - 1) // n_chunks
+        stats = chunk_pad_stats(n_elems, n_chunks)
+        _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
         pad = (-n_elems) % chunk_size
         if pad:
             elem_data_pad = jax.tree_util.tree_map(
@@ -279,7 +661,7 @@ def assemble_linear_form(
         data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
         data = data_chunks.reshape(-1)[: n_elems * m]

-    rows =
+    rows = _get_elem_rows(space)

     if sparse:
         return rows, data, n_dofs
@@ -318,7 +700,7 @@ def assemble_jacobian_global(
     *,
     sparse: bool = False,
     return_flux_matrix: bool = False,
-):
+) -> JacobianReturn:
     """
     Assemble Jacobian (dR/du) from element residual res_form.
     res_form(ctx, u_elem, params) -> (n_q, n_ldofs)
@@ -339,11 +721,16 @@ def assemble_jacobian_global(
     jac_fun = jax.jacrev(fe_fun, argnums=0)

     u_elems = u[elem_dofs]  # (n_elems, n_ldofs)
-    elem_ids = jnp.arange(elem_dofs.shape[0], dtype=
+    elem_ids = jnp.arange(elem_dofs.shape[0], dtype=INDEX_DTYPE)
     J_e_all = jax.vmap(jac_fun)(u_elems, elem_data, elem_ids)  # (n_elems, m, m)

-
-
+    pat = _get_pattern(space, with_idx=False)
+    if pat is None:
+        rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
+        cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
+    else:
+        rows = pat.rows
+        cols = pat.cols
     data = J_e_all.reshape(-1)

     if sparse:
@@ -358,7 +745,7 @@ def assemble_jacobian_global(
         return K_flat.reshape(n_dofs, n_dofs)


-def
+def assemble_jacobian_elementwise(
     space: SpaceLike,
     res_form: ResidualForm[P],
     u: jnp.ndarray,
@@ -366,9 +753,9 @@ def assemble_jacobian_elementwise_xla(
     *,
     sparse: bool = False,
     return_flux_matrix: bool = False,
-):
+) -> JacobianReturn:
     """
-    Assemble Jacobian with element kernels
+    Assemble Jacobian with element kernels via vmap + scatter_add.
     Recompiles if n_dofs changes, but independent of element count.
     """
     from ..solver import FluxSparseMatrix  # local import to avoid circular
@@ -388,8 +775,13 @@ def assemble_jacobian_elementwise_xla(
     u_elems = u[elem_dofs]
     J_e_all = jax.vmap(jac_fun)(u_elems, ctxs)  # (n_elems, m, m)

-
-
+    pat = _get_pattern(space, with_idx=False)
+    if pat is None:
+        rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
+        cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
+    else:
+        rows = pat.rows
+        cols = pat.cols
     data = J_e_all.reshape(-1)

     if sparse:
@@ -416,7 +808,7 @@ def assemble_residual_global(
     params: P,
     *,
     sparse: bool = False
-):
+) -> LinearReturn:
     """
     Assemble residual vector that depends on u.
     form(ctx, u_elem, params) -> (n_q, n_ldofs)
@@ -435,10 +827,10 @@ def assemble_residual_global(
         fe = (integrand * wJ[:, None]).sum(axis=0)
         return fe

-    elem_ids = jnp.arange(elem_dofs.shape[0], dtype=
+    elem_ids = jnp.arange(elem_dofs.shape[0], dtype=INDEX_DTYPE)
     F_e_all = jax.vmap(per_element)(elem_data, elem_dofs, elem_ids)  # (n_elems, m)

-    rows =
+    rows = _get_elem_rows(space)
     data = F_e_all.reshape(-1)

     if sparse:
@@ -448,16 +840,16 @@ def assemble_residual_global(
         return F


-def
+def assemble_residual_elementwise(
     space: SpaceLike,
     res_form: ResidualForm[P],
     u: jnp.ndarray,
     params: P,
     *,
     sparse: bool = False,
-):
+) -> LinearReturn:
     """
-    Assemble residual using element kernels
+    Assemble residual using element kernels via vmap + scatter_add.
     Recompiles if n_dofs changes, but independent of element count.
     """
     elem_dofs = space.elem_dofs
@@ -471,7 +863,7 @@ def assemble_residual_elementwise_xla(

     u_elems = u[elem_dofs]
     F_e_all = jax.vmap(per_element)(ctxs, u_elems)  # (n_elems, m)
-    rows =
+    rows = _get_elem_rows(space)
     data = F_e_all.reshape(-1)

     if sparse:
@@ -487,7 +879,44 @@ def assemble_residual_elementwise_xla(
         return F


-
+# Backward compatibility aliases (prefer assemble_*_elementwise).
+assemble_jacobian_elementwise_xla = assemble_jacobian_elementwise
+assemble_residual_elementwise_xla = assemble_residual_elementwise
+
+
+def make_element_bilinear_kernel(
+    form: Kernel[P], params: P, *, jit: bool = True
+) -> ElementBilinearKernel:
+    """Element kernel: (ctx) -> Ke."""
+
+    def per_element(ctx: FormContext):
+        integrand = form(ctx, params)
+        if getattr(form, "_includes_measure", False):
+            return integrand.sum(axis=0)
+        wJ = ctx.w * ctx.test.detJ
+        return (integrand * wJ[:, None, None]).sum(axis=0)
+
+    return jax.jit(per_element) if jit else per_element
+
+
+def make_element_linear_kernel(
+    form: Kernel[P], params: P, *, jit: bool = True
+) -> ElementLinearKernel:
+    """Element kernel: (ctx) -> fe."""
+
+    def per_element(ctx: FormContext):
+        integrand = form(ctx, params)
+        if getattr(form, "_includes_measure", False):
+            return integrand.sum(axis=0)
+        wJ = ctx.w * ctx.test.detJ
+        return (integrand * wJ[:, None]).sum(axis=0)
+
+    return jax.jit(per_element) if jit else per_element
+
+
+def make_element_residual_kernel(
+    res_form: ResidualForm[P], params: P
+) -> ElementResidualKernel:
     """Jitted element residual kernel: (ctx, u_elem) -> fe."""

     def per_element(ctx: FormContext, u_elem: jnp.ndarray):
@@ -500,7 +929,9 @@ def make_element_residual_kernel(res_form: ResidualForm[P], params: P):
     return jax.jit(per_element)


-def make_element_jacobian_kernel(
+def make_element_jacobian_kernel(
+    res_form: ResidualForm[P], params: P
+) -> ElementJacobianKernel:
     """Jitted element Jacobian kernel: (ctx, u_elem) -> Ke."""

     def fe_fun(u_elem, ctx: FormContext):
@@ -513,7 +944,9 @@ def make_element_jacobian_kernel(res_form: ResidualForm[P], params: P):
     return jax.jit(jax.jacrev(fe_fun, argnums=0))


-def element_residual(
+def element_residual(
+    res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.ndarray, params: P
+) -> Any:
     """
     Element residual vector r_e(u_e) = sum_q w_q * detJ_q * res_form(ctx, u_e, params).
     Returns shape (n_ldofs,).
@@ -538,7 +971,9 @@ def element_residual(res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.nd
     return jax.tree_util.tree_map(lambda x: jnp.einsum("qa,q->a", x, ctx.w * ctx.test.detJ), integrand)


-def element_jacobian(
+def element_jacobian(
+    res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.ndarray, params: P
+) -> Any:
     """
     Element Jacobian K_e = d r_e / d u_e (AD via jacfwd), shape (n_ldofs, n_ldofs).
     """
@@ -548,7 +983,42 @@ def element_jacobian(res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.nd
     return jax.jacfwd(_r_elem)(u_elem)


-def
+def make_element_kernel(
+    form: Kernel[P] | ResidualForm[P],
+    params: P,
+    *,
+    kind: Literal["bilinear", "linear", "residual", "jacobian"],
+    jit: bool = True,
+) -> ElementKernel:
+    """
+    Unified entry point for element kernels.
+
+    kind:
+      - "bilinear": kernel(ctx) -> (n_ldofs, n_ldofs)
+      - "linear": kernel(ctx) -> (n_ldofs,)
+      - "residual": kernel(ctx, u_elem) -> (n_ldofs,)
+      - "jacobian": kernel(u_elem, ctx) -> (n_ldofs, n_ldofs)
+    """
+    kind = kind.lower()
+    if kind == "bilinear":
+        return make_element_bilinear_kernel(form, params, jit=jit)
+    if kind == "linear":
+        def per_element(ctx: FormContext):
+            integrand = form(ctx, params)
+            if getattr(form, "_includes_measure", False):
+                return integrand.sum(axis=0)
+            wJ = ctx.w * ctx.test.detJ
+            return (integrand * wJ[:, None]).sum(axis=0)
+
+        return jax.jit(per_element) if jit else per_element
+    if kind == "residual":
+        return make_element_residual_kernel(form, params)
+    if kind == "jacobian":
+        return make_element_jacobian_kernel(form, params)
+    raise ValueError(f"Unknown kernel kind: {kind}")
+
+
+def make_sparsity_pattern(space: SpaceLike, *, with_idx: bool = True) -> SparsityPattern:
     """
     Build a SparsityPattern (rows/cols[/idx]) that is independent of the solution.
     NOTE: rows/cols ordering matches assemble_jacobian_values(...).reshape(-1)
@@ -557,24 +1027,24 @@ def make_sparsity_pattern(space: SpaceLike, *, with_idx: bool = True):
     """
     from ..solver import SparsityPattern  # local import to avoid circular

-    elem_dofs = jnp.asarray(space.elem_dofs, dtype=
+    elem_dofs = jnp.asarray(space.elem_dofs, dtype=INDEX_DTYPE)
     n_dofs = int(space.n_dofs)
     n_ldofs = int(space.n_ldofs)

-    rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1).astype(
-    cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1).astype(
+    rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1).astype(INDEX_DTYPE)
+    cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1).astype(INDEX_DTYPE)

     key = rows.astype(jnp.int64) * jnp.int64(n_dofs) + cols.astype(jnp.int64)
-    order = jnp.argsort(key).astype(
+    order = jnp.argsort(key).astype(INDEX_DTYPE)
     rows_sorted = rows[order]
     cols_sorted = cols[order]
-    counts = jnp.bincount(rows_sorted, length=n_dofs).astype(
-    indptr_j = jnp.concatenate([jnp.array([0], dtype=
-    indices_j = cols_sorted.astype(
+    counts = jnp.bincount(rows_sorted, length=n_dofs).astype(INDEX_DTYPE)
+    indptr_j = jnp.concatenate([jnp.array([0], dtype=INDEX_DTYPE), jnp.cumsum(counts)])
+    indices_j = cols_sorted.astype(INDEX_DTYPE)
     perm = order

     if with_idx:
-        idx = (rows.astype(jnp.int64) * jnp.int64(n_dofs) + cols.astype(jnp.int64)).astype(
+        idx = (rows.astype(jnp.int64) * jnp.int64(n_dofs) + cols.astype(jnp.int64)).astype(INDEX_DTYPE)
         return SparsityPattern(
             rows=rows,
             cols=cols,
@@ -601,8 +1071,10 @@ def assemble_jacobian_values(
     u: jnp.ndarray,
     params: P,
     *,
-    kernel=None,
-
+    kernel: ElementJacobianKernel | None = None,
+    n_chunks: Optional[int] = None,
+    pad_trace: bool = False,
+) -> Array:
     """
     Assemble only the numeric values for the Jacobian (pattern-free).
     """
@@ -610,8 +1082,49 @@ def assemble_jacobian_values(
     ker = kernel if kernel is not None else make_element_jacobian_kernel(res_form, params)

     u_elems = u[space.elem_dofs]
-
-
+    if n_chunks is None:
+        J_e_all = jax.vmap(ker)(u_elems, ctxs)  # (n_elem, m, m)
+        return J_e_all.reshape(-1)
+
+    n_elems = int(u_elems.shape[0])
+    if n_chunks <= 0:
+        raise ValueError("n_chunks must be a positive integer.")
+    n_chunks = min(int(n_chunks), int(n_elems))
+    chunk_size = (n_elems + n_chunks - 1) // n_chunks
+    stats = chunk_pad_stats(n_elems, n_chunks)
+    _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
+    pad = (-n_elems) % chunk_size
+    if pad:
+        ctxs_pad = jax.tree_util.tree_map(
+            lambda x: jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0),
+            ctxs,
+        )
+        u_elems_pad = jnp.concatenate([u_elems, jnp.repeat(u_elems[-1:], pad, axis=0)], axis=0)
+    else:
+        ctxs_pad = ctxs
+        u_elems_pad = u_elems
+
+    n_pad = n_elems + pad
+    n_chunks = n_pad // chunk_size
+    m = int(space.n_ldofs)
+
+    def _slice_first_dim(x, start, size):
+        start_idx = (start,) + (0,) * (x.ndim - 1)
+        slice_sizes = (size,) + x.shape[1:]
+        return jax.lax.dynamic_slice(x, start_idx, slice_sizes)
+
+    def chunk_fn(i):
+        start = i * chunk_size
+        ctx_chunk = jax.tree_util.tree_map(
+            lambda x: _slice_first_dim(x, start, chunk_size),
+            ctxs_pad,
+        )
+        u_chunk = _slice_first_dim(u_elems_pad, start, chunk_size)
+        J_e = jax.vmap(ker)(u_chunk, ctx_chunk)
+        return J_e.reshape(-1)
+
+    data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
+    return data_chunks.reshape(-1)[: n_elems * m * m]


 def assemble_residual_scatter(
@@ -620,9 +1133,11 @@ def assemble_residual_scatter(
     u: jnp.ndarray,
     params: P,
     *,
-    kernel=None,
+    kernel: ElementResidualKernel | None = None,
     sparse: bool = False,
-
+    n_chunks: Optional[int] = None,
+    pad_trace: bool = False,
+) -> LinearReturn:
     """
     Assemble residual using jitted element kernel + vmap + scatter_add.
     Avoids Python loops; good for JIT stability.
@@ -633,20 +1148,62 @@ def assemble_residual_scatter(
     """
     elem_dofs = space.elem_dofs
     n_dofs = space.n_dofs
-    if
-
-
-
+    if jax.core.trace_ctx.is_top_level():
+        if np.max(elem_dofs) >= n_dofs:
+            raise ValueError("elem_dofs contains index outside n_dofs")
+        if np.min(elem_dofs) < 0:
+            raise ValueError("elem_dofs contains negative index")
     ctxs = space.build_form_contexts()
     ker = kernel if kernel is not None else make_element_residual_kernel(res_form, params)

     u_elems = u[elem_dofs]
-
-
-
-
+    if n_chunks is None:
+        elem_res = jax.vmap(ker)(ctxs, u_elems)  # (n_elem, n_ldofs)
+    else:
+        n_elems = int(u_elems.shape[0])
+        if n_chunks <= 0:
+            raise ValueError("n_chunks must be a positive integer.")
+        n_chunks = min(int(n_chunks), int(n_elems))
+        chunk_size = (n_elems + n_chunks - 1) // n_chunks
+        stats = chunk_pad_stats(n_elems, n_chunks)
+        _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
+        pad = (-n_elems) % chunk_size
+        if pad:
+            ctxs_pad = jax.tree_util.tree_map(
+                lambda x: jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0),
+                ctxs,
+            )
+            u_elems_pad = jnp.concatenate([u_elems, jnp.repeat(u_elems[-1:], pad, axis=0)], axis=0)
+        else:
+            ctxs_pad = ctxs
+            u_elems_pad = u_elems
+
+        n_pad = n_elems + pad
+        n_chunks = n_pad // chunk_size
+
+        def _slice_first_dim(x, start, size):
+            start_idx = (start,) + (0,) * (x.ndim - 1)
+            slice_sizes = (size,) + x.shape[1:]
+            return jax.lax.dynamic_slice(x, start_idx, slice_sizes)

-
+        def chunk_fn(i):
+            start = i * chunk_size
+            ctx_chunk = jax.tree_util.tree_map(
+                lambda x: _slice_first_dim(x, start, chunk_size),
+                ctxs_pad,
+            )
+            u_chunk = _slice_first_dim(u_elems_pad, start, chunk_size)
+            res_chunk = jax.vmap(ker)(ctx_chunk, u_chunk)
+            return res_chunk.reshape(-1)
+
+        data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
+        elem_res = data_chunks.reshape(-1)[: n_elems * int(space.n_ldofs)].reshape(n_elems, -1)
+    if jax.core.trace_ctx.is_top_level():
+        if not bool(jax.block_until_ready(jnp.all(jnp.isfinite(elem_res)))):
+            bad = int(jnp.count_nonzero(~jnp.isfinite(elem_res)))
+            raise RuntimeError(f"[assemble_residual_scatter] elem_res nonfinite: {bad}")
+
+    rows = _get_elem_rows(space)
     data = elem_res.reshape(-1)

     if sparse:
@@ -668,11 +1225,13 @@ def assemble_jacobian_scatter(
     u: jnp.ndarray,
     params: P,
     *,
-    kernel=None,
+    kernel: ElementJacobianKernel | None = None,
     sparse: bool = False,
     return_flux_matrix: bool = False,
-    pattern=None,
-
+    pattern: SparsityPattern | None = None,
+    n_chunks: Optional[int] = None,
+    pad_trace: bool = False,
+) -> JacobianReturn:
     """
     Assemble Jacobian using jitted element kernel + vmap + scatter_add.
     If a SparsityPattern is provided, rows/cols are reused without regeneration.
@@ -682,7 +1241,9 @@ def assemble_jacobian_scatter(
     from ..solver import FluxSparseMatrix  # local import to avoid circular

     pat = pattern if pattern is not None else make_sparsity_pattern(space, with_idx=not sparse)
-    data = assemble_jacobian_values(
+    data = assemble_jacobian_values(
+        space, res_form, u, params, kernel=kernel, n_chunks=n_chunks, pad_trace=pad_trace
+    )

     if sparse:
         if return_flux_matrix:
@@ -691,7 +1252,7 @@ def assemble_jacobian_scatter(

     idx = pat.idx
     if idx is None:
-        idx = (pat.rows.astype(jnp.int64) * int(pat.n_dofs) + pat.cols.astype(jnp.int64)).astype(
+        idx = (pat.rows.astype(jnp.int64) * int(pat.n_dofs) + pat.cols.astype(jnp.int64)).astype(INDEX_DTYPE)

     n_entries = pat.n_dofs * pat.n_dofs
     sdn = jax.lax.ScatterDimensionNumbers(
@@ -708,12 +1269,21 @@ def assemble_jacobian_scatter(
 def assemble_residual(
     space: SpaceLike,
     form: ResidualForm[P],
-    u: jnp.ndarray,
+    u: jnp.ndarray,
+    params: P,
     *,
-
-
-
-
+    kernel: ElementResidualKernel | None = None,
+    sparse: bool = False,
+    n_chunks: Optional[int] = None,
+    pad_trace: bool = False,
+) -> LinearReturn:
+    """
+    Assemble the global residual vector (scatter-based).
+    If kernel is provided: kernel(ctx, u_elem) -> (n_ldofs,).
+    """
+    return assemble_residual_scatter(
+        space, form, u, params, kernel=kernel, sparse=sparse, n_chunks=n_chunks, pad_trace=pad_trace
+    )


 def assemble_jacobian(
@@ -722,19 +1292,28 @@ def assemble_jacobian(
     u: jnp.ndarray,
     params: P,
     *,
+    kernel: ElementJacobianKernel | None = None,
     sparse: bool = True,
     return_flux_matrix: bool = False,
-    pattern=None,
-
-
+    pattern: SparsityPattern | None = None,
+    n_chunks: Optional[int] = None,
+    pad_trace: bool = False,
+) -> JacobianReturn:
+    """
+    Assemble the global Jacobian (scatter-based).
+    If kernel is provided: kernel(u_elem, ctx) -> (n_ldofs, n_ldofs).
+    """
     return assemble_jacobian_scatter(
         space,
         res_form,
         u,
         params,
+        kernel=kernel,
         sparse=sparse,
         return_flux_matrix=return_flux_matrix,
         pattern=pattern,
+        n_chunks=n_chunks,
+        pad_trace=pad_trace,
     )


@@ -748,13 +1327,19 @@ def scalar_body_force_form(ctx: FormContext, load: float) -> jnp.ndarray:
     return load * ctx.test.N  # (n_q, n_ldofs)


-
+scalar_body_force_form._ff_kind = "linear"
+scalar_body_force_form._ff_domain = "volume"
+
+
+def make_scalar_body_force_form(body_force: Callable[[Array], Array]) -> Kernel[Any]:
     """
     Build a scalar linear form from a callable f(x_q) -> (n_q,).
     """
     def _form(ctx: FormContext, _params):
         f_q = body_force(ctx.x_q)
         return f_q[..., None] * ctx.test.N
+    _form._ff_kind = "linear"
+    _form._ff_domain = "volume"
     return _form


@@ -762,7 +1347,7 @@ def make_scalar_body_force_form(body_force):
 constant_body_force_form = scalar_body_force_form


-def _check_structured_box_connectivity():
+def _check_structured_box_connectivity() -> None:
     """Quick connectivity check for nx=2, ny=1, nz=1 (non-structured order)."""
     box = StructuredHexBox(nx=2, ny=1, nz=1, lx=2.0, ly=1.0, lz=1.0)
     mesh = box.build()
@@ -775,7 +1360,7 @@ def _check_structured_box_connectivity():
             [0, 1, 4, 3, 6, 7, 10, 9],  # element at i=0
             [1, 2, 5, 4, 7, 8, 11, 10],  # element at i=1
         ],
-        dtype=
+        dtype=INDEX_DTYPE,
     )
     max_diff = int(jnp.max(jnp.abs(mesh.conn - expected_conn)))
     print("StructuredHexBox nx=2,ny=1,nz=1 conn matches expected:", max_diff == 0)
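The assembly.py changes above center on three additions: reusable jitted element kernels (make_element_bilinear_kernel and friends, plus the unified make_element_kernel), chunked assembly controlled by n_chunks with padding reported by chunk_pad_stats, and the BatchedAssembler, which keeps array shapes static and masks padded elements. The sketch below is a minimal illustration of how these pieces fit together, based only on the signatures visible in this diff: the FE space construction is elided (any object satisfying the SpaceLike protocol), and reaction_form and assemble_reaction_matrix are illustrative names, not part of fluxfem.

# Usage sketch (assumptions noted above); only the fluxfem.core.assembly
# imports below appear in the diff, everything else is illustrative.
import jax.numpy as jnp

from fluxfem.core.assembly import (
    BatchedAssembler,
    assemble_bilinear_form,
    make_element_bilinear_kernel,
    make_sparsity_pattern,
)


def reaction_form(ctx, params):
    # form(ctx, params) -> (n_q, n_ldofs, n_ldofs); uses only ctx.test.N,
    # which the assembler integrates against w_q * detJ_q.
    N = ctx.test.N
    return params["c"] * jnp.einsum("qa,qb->qab", N, N)


def assemble_reaction_matrix(space, params, n_chunks=None):
    # `space` is assumed to satisfy the SpaceLike protocol
    # (elem_dofs, n_dofs, n_ldofs, build_form_contexts, ...).
    pattern = make_sparsity_pattern(space, with_idx=True)

    # One-shot sparse assembly; n_chunks trades peak memory for padded
    # recompute, and pad_trace=True prints the [pad] statistics outside jit.
    K = assemble_bilinear_form(
        space, reaction_form, params,
        pattern=pattern, n_chunks=n_chunks, pad_trace=True,
    )

    # Reuse a jitted element kernel and keep shapes static with
    # BatchedAssembler; the mask zeros out inactive (padded) elements
    # without changing array shapes.
    kernel = make_element_bilinear_kernel(reaction_form, params, jit=True)
    asm = BatchedAssembler.from_space(space, pattern=pattern)
    mask = asm.make_mask(asm.n_elems)  # all elements active
    K_again = asm.assemble_bilinear(reaction_form, params, mask=mask, kernel=kernel)
    return K, K_again

The same pattern extends to the nonlinear path: assemble_residual(space, form, u, params, n_chunks=...) and assemble_jacobian(space, res_form, u, params, pattern=pattern, n_chunks=...) accept a prebuilt kernel in the same way, so a Newton loop can build its element kernels once and reuse the sparsity pattern across iterations.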