fluxfem 0.1.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +343 -0
- fluxfem/core/__init__.py +316 -0
- fluxfem/core/assembly.py +788 -0
- fluxfem/core/basis.py +996 -0
- fluxfem/core/data.py +64 -0
- fluxfem/core/dtypes.py +4 -0
- fluxfem/core/forms.py +234 -0
- fluxfem/core/interp.py +55 -0
- fluxfem/core/solver.py +113 -0
- fluxfem/core/space.py +419 -0
- fluxfem/core/weakform.py +818 -0
- fluxfem/helpers_num.py +11 -0
- fluxfem/helpers_wf.py +42 -0
- fluxfem/mesh/__init__.py +29 -0
- fluxfem/mesh/base.py +244 -0
- fluxfem/mesh/hex.py +327 -0
- fluxfem/mesh/io.py +87 -0
- fluxfem/mesh/predicate.py +45 -0
- fluxfem/mesh/surface.py +257 -0
- fluxfem/mesh/tet.py +246 -0
- fluxfem/physics/__init__.py +53 -0
- fluxfem/physics/diffusion.py +18 -0
- fluxfem/physics/elasticity/__init__.py +39 -0
- fluxfem/physics/elasticity/hyperelastic.py +99 -0
- fluxfem/physics/elasticity/linear.py +58 -0
- fluxfem/physics/elasticity/materials.py +32 -0
- fluxfem/physics/elasticity/stress.py +46 -0
- fluxfem/physics/operators.py +109 -0
- fluxfem/physics/postprocess.py +113 -0
- fluxfem/solver/__init__.py +47 -0
- fluxfem/solver/bc.py +439 -0
- fluxfem/solver/cg.py +326 -0
- fluxfem/solver/dirichlet.py +126 -0
- fluxfem/solver/history.py +31 -0
- fluxfem/solver/newton.py +400 -0
- fluxfem/solver/result.py +62 -0
- fluxfem/solver/solve_runner.py +534 -0
- fluxfem/solver/solver.py +148 -0
- fluxfem/solver/sparse.py +188 -0
- fluxfem/tools/__init__.py +7 -0
- fluxfem/tools/jit.py +51 -0
- fluxfem/tools/timer.py +659 -0
- fluxfem/tools/visualizer.py +101 -0
- fluxfem-0.1.1a0.dist-info/METADATA +111 -0
- fluxfem-0.1.1a0.dist-info/RECORD +47 -0
- fluxfem-0.1.1a0.dist-info/WHEEL +4 -0
- fluxfem-0.1.1a0.dist-info/licenses/LICENSE +201 -0
fluxfem/core/assembly.py
ADDED
|
@@ -0,0 +1,788 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Callable, Protocol, TypeVar, Optional
|
|
3
|
+
import numpy as np
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
|
|
7
|
+
from ..mesh import HexMesh, StructuredHexBox
|
|
8
|
+
from .forms import FormContext
|
|
9
|
+
from .space import FESpaceBase
|
|
10
|
+
|
|
11
|
+
# Shared call signatures for kernels/forms.
# ``Array`` is the array type used in the signatures below.
Array = jnp.ndarray
# ``P`` is the type of the user ``params`` object; assembly routines forward it
# unchanged to the kernel/form.
P = TypeVar("P")

# Kernel / linear-form signature: ``kernel(ctx, params) -> Array``.
Kernel = Callable[[FormContext, P], Array]
# Residual-form signature: ``res_form(ctx, u_elem, params) -> Array``.
ResidualForm = Callable[[FormContext, Array, P], Array]
# Callable mapping one array to another. Not used by the visible code of this
# module (presumably maps element data to DOF indices — TODO confirm).
ElementDofMapper = Callable[[Array], Array]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SpaceLike(FESpaceBase, Protocol):
    """Structural type for the FE spaces accepted by the assembly routines.

    The functions in this module read ``elem_dofs``, ``n_dofs`` and ``n_ldofs``
    from it and call ``build_form_contexts()`` on it.

    NOTE(review): ``typing.Protocol`` only allows protocol base classes, so this
    definition relies on ``FESpaceBase`` itself being a Protocol — confirm in
    ``core/space.py``.
    """

    pass
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def assemble_bilinear_dense(
    space: SpaceLike,
    kernel: Kernel[P],
    params: P,
    *,
    sparse: bool = False,
    return_flux_matrix: bool = False,
):
    """Assemble a bilinear form from per-element matrices (cf. scikit-fem ``asm``).

    ``kernel(ctx, params)`` must return the already-integrated element matrix of
    shape ``(n_ldofs, n_ldofs)``; it is vmapped over the element contexts.

    Returns:
        * ``sparse=False``: dense ``(n_dofs, n_dofs)`` array, duplicate entries summed.
        * ``sparse=True``: COO triplets ``(rows, cols, data, n_dofs)``, or a
          ``FluxSparseMatrix`` built from them when ``return_flux_matrix=True``.
    """
    dofs = space.elem_dofs  # (n_elems, n_ldofs)
    ndof = space.n_dofs
    nloc = space.n_ldofs

    contexts = space.build_form_contexts()  # pytree, leading axis = elements
    element_mats = jax.vmap(lambda c: kernel(c, params))(contexts)

    # Entry (e, a, b) of element_mats belongs at global (dofs[e, a], dofs[e, b]).
    row_idx = jnp.repeat(dofs, nloc, axis=1).reshape(-1)
    col_idx = jnp.tile(dofs, (1, nloc)).reshape(-1)
    values = element_mats.reshape(-1)

    if sparse:
        if not return_flux_matrix:
            return row_idx, col_idx, values, ndof
        from ..solver import FluxSparseMatrix  # local import to avoid circular

        return FluxSparseMatrix(row_idx, col_idx, values, ndof)

    # Dense path: scatter-add into a flattened (row * n_dofs + col) buffer.
    flat_idx = row_idx * ndof + col_idx
    dense = jnp.zeros((ndof * ndof,), dtype=values.dtype).at[flat_idx].add(values)
    return dense.reshape(ndof, ndof)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def assemble_bilinear_form(
    space,
    form,
    params,
    *,
    pattern=None,
    chunk_size: Optional[int] = None,  # None -> no chunking (single vmap over all elements)
    dep: jnp.ndarray | None = None,
):
    """Assemble a sparse bilinear form into a FluxSparseMatrix.

    Args:
        space: FE space providing ``elem_dofs``, ``build_form_contexts(dep=...)``
            and optionally ``get_sparsity_pattern(with_idx=...)``.
        form: ``form(ctx, params) -> (n_q, n_ldofs, n_ldofs)`` integrand. If it has
            a truthy ``_includes_measure`` attribute, quadrature weights/detJ are
            assumed applied and the integrand is only summed over quadrature points;
            otherwise it is weighted by ``ctx.w * ctx.test.detJ``.
        params: forwarded unchanged to ``form``.
        pattern: precomputed sparsity pattern; built from ``space`` when None.
        chunk_size: if given, elements are evaluated sequentially in chunks of this
            many elements, bounding the peak memory of the element matrices.
        dep: forwarded to ``space.build_form_contexts``.

    Returns:
        ``FluxSparseMatrix(pattern, data)`` with element entries flattened in
        (element, local row, local col) order.

    Raises:
        ValueError: if ``chunk_size`` is not a positive integer.
    """
    if chunk_size is not None and chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    from ..solver import FluxSparseMatrix  # local import to avoid circular

    if pattern is None:
        if hasattr(space, "get_sparsity_pattern"):
            pat = space.get_sparsity_pattern(with_idx=True)
        else:
            pat = make_sparsity_pattern(space, with_idx=True)
    else:
        pat = pattern
    elem_data = space.build_form_contexts(dep=dep)

    includes_measure = getattr(form, "_includes_measure", False)

    def per_element(ctx):
        integrand = form(ctx, params)  # (n_q, m, m)
        if includes_measure:
            return integrand.sum(axis=0)
        wJ = ctx.w * ctx.test.detJ  # (n_q,)
        return (integrand * wJ[:, None, None]).sum(axis=0)  # (m, m)

    # --- no-chunk path ---
    if chunk_size is None:
        K_e_all = jax.vmap(per_element)(elem_data)  # (n_elems, m, m)
        return FluxSparseMatrix(pat, K_e_all.reshape(-1))

    # --- chunked path ---
    n_elems = space.elem_dofs.shape[0]
    # Prefer m from the pattern; otherwise infer it from one element.
    m = getattr(pat, "n_ldofs", None)
    if m is None:
        m = per_element(jax.tree_util.tree_map(lambda x: x[0], elem_data)).shape[0]

    # Pad with copies of the last element so every chunk has the same static size
    # (a single compiled chunk body); padded results are sliced off below.
    pad = (-n_elems) % chunk_size
    if pad:
        elem_data = jax.tree_util.tree_map(
            lambda x: jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0),
            elem_data,
        )
    n_chunks = (n_elems + pad) // chunk_size

    def _slice_first_dim(x, start, size):
        start_idx = (start,) + (0,) * (x.ndim - 1)
        slice_sizes = (size,) + x.shape[1:]
        return jax.lax.dynamic_slice(x, start_idx, slice_sizes)

    def chunk_fn(i):
        start = i * chunk_size
        ctx_chunk = jax.tree_util.tree_map(
            lambda x: _slice_first_dim(x, start, chunk_size),
            elem_data,
        )
        Ke = jax.vmap(per_element)(ctx_chunk)  # (chunk, m, m)
        return Ke.reshape(-1)  # (chunk*m*m,)

    # lax.map evaluates the chunks sequentially, so only one chunk's element
    # matrices are live at a time. vmap over chunks would evaluate them all at
    # once and defeat the purpose of chunking.
    data_chunks = jax.lax.map(chunk_fn, jnp.arange(n_chunks))
    data = data_chunks.reshape(-1)[: n_elems * m * m]
    return FluxSparseMatrix(pat, data)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def assemble_mass_matrix(space: SpaceLike, *, lumped: bool = False, chunk_size: Optional[int] = None):
    """Assemble the mass matrix M_ij = ∫ N_i N_j dΩ.

    Supports scalar and vector spaces. When ``ctx.test`` has a ``value_dim``
    attribute, the element matrix is ``kron(Nᵀ N, I_value_dim)``. Local DOF
    ``a * value_dim + i`` is component ``i`` of node ``a``, and components
    couple only with themselves.
    NOTE(review): this node-major interleaving must match the ordering of
    ``space.elem_dofs`` — confirm against the space implementation.

    Args:
        space: FE space (``elem_dofs``, ``n_dofs``, ``n_ldofs``, ``build_form_contexts``).
        lumped: if True, return the row-sum lumped mass as a length-``n_dofs`` vector.
        chunk_size: if given, elements are evaluated sequentially in chunks of this
            size to bound peak memory.

    Returns:
        ``(n_dofs,)`` array when ``lumped``, otherwise a ``FluxSparseMatrix``.

    Raises:
        ValueError: if ``chunk_size`` is not a positive integer.
    """
    if chunk_size is not None and chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    ctxs = space.build_form_contexts()
    n_ldofs = space.n_ldofs

    def per_element(ctx):
        N = ctx.test.N  # (n_q, n_nodes)
        base = jnp.einsum("qa,qb->qab", N, N)  # (n_q, n_nodes, n_nodes)
        if hasattr(ctx.test, "value_dim"):
            vd = int(ctx.test.value_dim)
            I = jnp.eye(vd, dtype=N.dtype)
            # Axes (q, a, i, b, j): node a / component i couples with node b /
            # component j only when i == j. Keeping (a, i) and (b, j) adjacent
            # makes the reshape produce row a*vd+i and column b*vd+j.
            block = base[:, :, None, :, None] * I[None, None, :, None, :]
            base = block.reshape(base.shape[0], n_ldofs, n_ldofs)
        wJ = ctx.w * ctx.test.detJ
        return jnp.einsum("qab,q->ab", base, wJ)

    if chunk_size is None:
        M_e_all = jax.vmap(per_element)(ctxs)  # (n_elems, n_ldofs, n_ldofs)
        data = M_e_all.reshape(-1)
    else:
        n_elems = space.elem_dofs.shape[0]
        # Pad with copies of the last element so all chunks share one static
        # size; padded results are sliced off below.
        pad = (-n_elems) % chunk_size
        if pad:
            ctxs_pad = jax.tree_util.tree_map(
                lambda x: jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0),
                ctxs,
            )
        else:
            ctxs_pad = ctxs
        n_chunks = (n_elems + pad) // chunk_size

        def _slice_first_dim(x, start, size):
            start_idx = (start,) + (0,) * (x.ndim - 1)
            slice_sizes = (size,) + x.shape[1:]
            return jax.lax.dynamic_slice(x, start_idx, slice_sizes)

        def chunk_fn(i):
            start = i * chunk_size
            ctx_chunk = jax.tree_util.tree_map(
                lambda x: _slice_first_dim(x, start, chunk_size),
                ctxs_pad,
            )
            Me = jax.vmap(per_element)(ctx_chunk)  # (chunk, n_ldofs, n_ldofs)
            return Me.reshape(-1)

        # Sequential over chunks (vmap here would evaluate all chunks at once).
        data_chunks = jax.lax.map(chunk_fn, jnp.arange(n_chunks))
        data = data_chunks.reshape(-1)[: n_elems * n_ldofs * n_ldofs]

    elem_dofs = space.elem_dofs
    rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
    cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)

    if lumped:
        # Row-sum lumping: every element entry is added to its row's diagonal.
        M = jnp.zeros((space.n_dofs,), dtype=data.dtype)
        return M.at[rows].add(data)

    from ..solver import FluxSparseMatrix  # local import to avoid circular

    return FluxSparseMatrix(rows, cols, data, n_dofs=space.n_dofs)
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def assemble_linear_form(
    space: SpaceLike,
    form: Kernel[P],
    params: P,
    *,
    sparse: bool = False,
    chunk_size: Optional[int] = None,
    dep: jnp.ndarray | None = None,
) -> jnp.ndarray:
    """Assemble a linear form (right-hand side) F_a = Σ_q form(ctx, params)[q, a] * w_q * detJ_q.

    Args:
        space: FE space (``elem_dofs``, ``n_dofs``, ``n_ldofs``, ``build_form_contexts``).
        form: ``form(ctx, params) -> (n_q, n_ldofs)``. If it has a truthy
            ``_includes_measure`` attribute the integrand is only summed over
            quadrature points (no ``w * detJ`` weighting).
        params: forwarded unchanged to ``form``.
        sparse: return ``(rows, data, n_dofs)`` instead of the summed vector.
        chunk_size: if given, elements are evaluated sequentially in chunks of this
            size to bound peak memory.
        dep: forwarded to ``space.build_form_contexts``.

    Returns:
        ``(n_dofs,)`` vector, or ``(rows, data, n_dofs)`` when ``sparse``.

    Raises:
        ValueError: if ``chunk_size`` is not a positive integer.
    """
    if chunk_size is not None and chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    elem_dofs = space.elem_dofs
    n_dofs = space.n_dofs
    n_ldofs = space.n_ldofs

    elem_data = space.build_form_contexts(dep=dep)

    includes_measure = getattr(form, "_includes_measure", False)

    def per_element(ctx: FormContext):
        integrand = form(ctx, params)  # (n_q, m)
        if includes_measure:
            return integrand.sum(axis=0)
        wJ = ctx.w * ctx.test.detJ  # (n_q,)
        return (integrand * wJ[:, None]).sum(axis=0)  # (m,)

    if chunk_size is None:
        F_e_all = jax.vmap(per_element)(elem_data)  # (n_elems, m)
        data = F_e_all.reshape(-1)
    else:
        n_elems = elem_dofs.shape[0]
        m = n_ldofs
        # Pad with copies of the last element so all chunks share one static
        # size; padded results are sliced off below.
        pad = (-n_elems) % chunk_size
        if pad:
            elem_data_pad = jax.tree_util.tree_map(
                lambda x: jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0),
                elem_data,
            )
        else:
            elem_data_pad = elem_data
        n_chunks = (n_elems + pad) // chunk_size

        def _slice_first_dim(x, start, size):
            start_idx = (start,) + (0,) * (x.ndim - 1)
            slice_sizes = (size,) + x.shape[1:]
            return jax.lax.dynamic_slice(x, start_idx, slice_sizes)

        def chunk_fn(i):
            start = i * chunk_size
            ctx_chunk = jax.tree_util.tree_map(
                lambda x: _slice_first_dim(x, start, chunk_size),
                elem_data_pad,
            )
            fe = jax.vmap(per_element)(ctx_chunk)  # (chunk, m)
            return fe.reshape(-1)

        # Sequential over chunks (vmap here would evaluate all chunks at once).
        data_chunks = jax.lax.map(chunk_fn, jnp.arange(n_chunks))
        data = data_chunks.reshape(-1)[: n_elems * m]

    rows = elem_dofs.reshape(-1)

    if sparse:
        return rows, data, n_dofs

    return jax.ops.segment_sum(data, rows, n_dofs)
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def assemble_functional(space: SpaceLike, form: Kernel[P], params: P) -> jnp.ndarray:
    """Assemble the scalar functional J = ∫ form(ctx, params) dΩ.

    ``form`` yields one value per quadrature point, shape ``(n_q,)`` or
    ``(n_q, 1)``. Unless ``form._includes_measure`` is truthy, each value is
    weighted by ``ctx.w * ctx.test.detJ`` before the per-element sum.
    """
    contexts = space.build_form_contexts()
    already_weighted = getattr(form, "_includes_measure", False)

    def element_value(ctx: FormContext):
        vals = form(ctx, params)
        # Accept a trailing singleton axis by dropping it.
        if vals.ndim == 2 and vals.shape[1] == 1:
            vals = vals[:, 0]
        if not already_weighted:
            vals = vals * (ctx.w * ctx.test.detJ)
        return jnp.sum(vals)

    per_element = jax.vmap(element_value)(contexts)
    return jnp.sum(per_element)
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def assemble_jacobian_global(
    space: SpaceLike,
    res_form: ResidualForm[P],
    u: jnp.ndarray,
    params: P,
    *,
    sparse: bool = False,
    return_flux_matrix: bool = False,
):
    """Assemble the Jacobian dR/du from an element residual form.

    Each element's ``FormContext`` is rebuilt with its element index, so
    ``res_form`` can look up per-element data.

    Args:
        space: FE space (``elem_dofs``, ``n_dofs``, ``n_ldofs``, ``build_form_contexts``).
        res_form: ``res_form(ctx, u_elem, params) -> (n_q, n_ldofs)``. If it has a
            truthy ``_includes_measure`` attribute the integrand is only summed over
            quadrature points; otherwise it is weighted by ``ctx.w * ctx.test.detJ``
            (the same convention as the other assemblers in this module).
        u: global DOF vector.
        params: forwarded unchanged to ``res_form``.
        sparse: return COO triplets instead of a dense matrix.
        return_flux_matrix: with ``sparse=True``, wrap the triplets in a FluxSparseMatrix.

    Returns:
        Dense ``(n_dofs, n_dofs)`` array, ``(rows, cols, data, n_dofs)`` or a
        ``FluxSparseMatrix``.
    """
    elem_dofs = space.elem_dofs
    n_dofs = space.n_dofs
    n_ldofs = space.n_ldofs

    elem_data = space.build_form_contexts()
    includes_measure = getattr(res_form, "_includes_measure", False)

    def fe_fun(u_elem, ctx: FormContext, elem_id):
        ctx_with_id = FormContext(ctx.test, ctx.trial, ctx.x_q, ctx.w, elem_id)
        integrand = res_form(ctx_with_id, u_elem, params)  # (n_q, m)
        if includes_measure:
            # The form already applied quadrature weights and detJ.
            return integrand.sum(axis=0)
        wJ = ctx.w * ctx.test.detJ
        return (integrand * wJ[:, None]).sum(axis=0)  # (m,)

    # Differentiate only with respect to the element DOF values.
    jac_fun = jax.jacrev(fe_fun, argnums=0)

    u_elems = u[elem_dofs]  # (n_elems, n_ldofs)
    elem_ids = jnp.arange(elem_dofs.shape[0], dtype=jnp.int32)
    J_e_all = jax.vmap(jac_fun)(u_elems, elem_data, elem_ids)  # (n_elems, m, m)

    rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
    cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
    data = J_e_all.reshape(-1)

    if sparse:
        if return_flux_matrix:
            from ..solver import FluxSparseMatrix  # local import to avoid circular

            return FluxSparseMatrix(rows, cols, data, n_dofs)
        return rows, cols, data, n_dofs

    n_entries = n_dofs * n_dofs
    idx = rows * n_dofs + cols
    K_flat = jax.ops.segment_sum(data, idx, n_entries)
    return K_flat.reshape(n_dofs, n_dofs)
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def assemble_jacobian_elementwise_xla(
    space: SpaceLike,
    res_form: ResidualForm[P],
    u: jnp.ndarray,
    params: P,
    *,
    sparse: bool = False,
    return_flux_matrix: bool = False,
):
    """Assemble the Jacobian with element kernels in XLA (vmap + scatter_add).

    Args:
        space: FE space (``elem_dofs``, ``n_dofs``, ``n_ldofs``, ``build_form_contexts``).
        res_form: ``res_form(ctx, u_elem, params) -> (n_q, n_ldofs)``. If it has a
            truthy ``_includes_measure`` attribute the integrand is only summed over
            quadrature points; otherwise it is weighted by ``ctx.w * ctx.test.detJ``.
        u: global DOF vector.
        params: forwarded unchanged to ``res_form``.
        sparse: return COO triplets instead of a dense matrix.
        return_flux_matrix: with ``sparse=True``, wrap the triplets in a FluxSparseMatrix.

    Returns:
        Dense ``(n_dofs, n_dofs)`` array, ``(rows, cols, data, n_dofs)`` or a
        ``FluxSparseMatrix``.
    """
    elem_dofs = space.elem_dofs
    n_dofs = space.n_dofs
    n_ldofs = space.n_ldofs

    ctxs = space.build_form_contexts()
    includes_measure = getattr(res_form, "_includes_measure", False)

    def fe_fun(u_elem, ctx: FormContext):
        integrand = res_form(ctx, u_elem, params)
        if includes_measure:
            return integrand.sum(axis=0)
        wJ = ctx.w * ctx.test.detJ
        return (integrand * wJ[:, None]).sum(axis=0)

    jac_fun = jax.jacrev(fe_fun, argnums=0)
    u_elems = u[elem_dofs]
    J_e_all = jax.vmap(jac_fun)(u_elems, ctxs)  # (n_elems, m, m)

    rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
    cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
    data = J_e_all.reshape(-1)

    if sparse:
        if return_flux_matrix:
            from ..solver import FluxSparseMatrix  # local import to avoid circular

            return FluxSparseMatrix(rows, cols, data, n_dofs)
        return rows, cols, data, n_dofs

    n_entries = n_dofs * n_dofs
    idx = rows * n_dofs + cols  # flat dense index of each entry
    sdn = jax.lax.ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    )
    K_flat = jnp.zeros(n_entries, dtype=data.dtype)
    K_flat = jax.lax.scatter_add(K_flat, idx[:, None], data, sdn)
    return K_flat.reshape(n_dofs, n_dofs)
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def assemble_residual_global(
    space: SpaceLike,
    form: ResidualForm[P],
    u: jnp.ndarray,
    params: P,
    *,
    sparse: bool = False
):
    """Assemble the residual vector R(u).

    Each element's ``FormContext`` is rebuilt with its element index, so
    ``form`` can look up per-element data.

    Args:
        space: FE space (``elem_dofs``, ``n_dofs``, ``build_form_contexts``).
        form: ``form(ctx, u_elem, params) -> (n_q, n_ldofs)``. If it has a truthy
            ``_includes_measure`` attribute the integrand is only summed over
            quadrature points; otherwise it is weighted by ``ctx.w * ctx.test.detJ``.
        u: global DOF vector.
        params: forwarded unchanged to ``form``.
        sparse: return ``(rows, data, n_dofs)`` instead of the summed vector.

    Returns:
        ``(n_dofs,)`` vector, or ``(rows, data, n_dofs)`` when ``sparse``.
    """
    elem_dofs = space.elem_dofs
    n_dofs = space.n_dofs

    elem_data = space.build_form_contexts()
    includes_measure = getattr(form, "_includes_measure", False)

    def per_element(ctx: FormContext, conn: jnp.ndarray, elem_id: jnp.ndarray):
        u_elem = u[conn]
        ctx_with_id = FormContext(ctx.test, ctx.trial, ctx.x_q, ctx.w, elem_id)
        integrand = form(ctx_with_id, u_elem, params)  # (n_q, m)
        if includes_measure:
            # The form already applied quadrature weights and detJ.
            return integrand.sum(axis=0)
        wJ = ctx.w * ctx.test.detJ
        return (integrand * wJ[:, None]).sum(axis=0)

    elem_ids = jnp.arange(elem_dofs.shape[0], dtype=jnp.int32)
    F_e_all = jax.vmap(per_element)(elem_data, elem_dofs, elem_ids)  # (n_elems, m)

    rows = elem_dofs.reshape(-1)
    data = F_e_all.reshape(-1)

    if sparse:
        return rows, data, n_dofs

    return jax.ops.segment_sum(data, rows, n_dofs)
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def assemble_residual_elementwise_xla(
    space: SpaceLike,
    res_form: ResidualForm[P],
    u: jnp.ndarray,
    params: P,
    *,
    sparse: bool = False,
):
    """Assemble the residual with element kernels in XLA (vmap + scatter_add).

    Args:
        space: FE space (``elem_dofs``, ``n_dofs``, ``build_form_contexts``).
        res_form: ``res_form(ctx, u_elem, params) -> (n_q, n_ldofs)``. If it has a
            truthy ``_includes_measure`` attribute the integrand is only summed over
            quadrature points; otherwise it is weighted by ``ctx.w * ctx.test.detJ``.
        u: global DOF vector.
        params: forwarded unchanged to ``res_form``.
        sparse: return ``(rows, data, n_dofs)`` instead of the summed vector.

    Returns:
        ``(n_dofs,)`` vector, or ``(rows, data, n_dofs)`` when ``sparse``.
    """
    elem_dofs = space.elem_dofs
    n_dofs = space.n_dofs
    ctxs = space.build_form_contexts()
    includes_measure = getattr(res_form, "_includes_measure", False)

    def per_element(ctx: FormContext, u_elem: jnp.ndarray):
        integrand = res_form(ctx, u_elem, params)
        if includes_measure:
            return integrand.sum(axis=0)
        wJ = ctx.w * ctx.test.detJ
        return (integrand * wJ[:, None]).sum(axis=0)

    u_elems = u[elem_dofs]
    F_e_all = jax.vmap(per_element)(ctxs, u_elems)  # (n_elems, m)
    rows = elem_dofs.reshape(-1)
    data = F_e_all.reshape(-1)

    if sparse:
        return rows, data, n_dofs

    sdn = jax.lax.ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    )
    F = jnp.zeros(n_dofs, dtype=data.dtype)
    F = jax.lax.scatter_add(F, rows[:, None], data, sdn)
    return F
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
def make_element_residual_kernel(res_form: ResidualForm[P], params: P):
    """Build a jitted element residual kernel ``(ctx, u_elem) -> fe``.

    ``fe[a] = Σ_q res_form(ctx, u_elem, params)[q, a] * w_q * detJ_q``; the
    ``w * detJ`` weighting is skipped when ``res_form._includes_measure`` is truthy.
    """

    def residual_kernel(ctx: FormContext, u_elem: jnp.ndarray):
        values = res_form(ctx, u_elem, params)
        if not getattr(res_form, "_includes_measure", False):
            values = values * (ctx.w * ctx.test.detJ)[:, None]
        return values.sum(axis=0)

    return jax.jit(residual_kernel)
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
def make_element_jacobian_kernel(res_form: ResidualForm[P], params: P):
    """Build a jitted element Jacobian kernel ``(u_elem, ctx) -> Ke``.

    ``Ke = d fe / d u_elem`` (reverse-mode AD), where ``fe`` is the quadrature
    sum of ``res_form(ctx, u_elem, params)``. The sum is weighted by ``w * detJ``
    unless ``res_form._includes_measure`` is truthy. Note that the kernel takes
    ``u_elem`` first.
    """

    def element_force(u_elem, ctx: FormContext):
        values = res_form(ctx, u_elem, params)
        if not getattr(res_form, "_includes_measure", False):
            values = values * (ctx.w * ctx.test.detJ)[:, None]
        return values.sum(axis=0)

    jacobian = jax.jacrev(element_force, argnums=0)
    return jax.jit(jacobian)
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
def element_residual(res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.ndarray, params: P):
    """Element residual r_e(u_e) = Σ_q w_q * detJ_q * res_form(ctx, u_e, params).

    Three integrand shapes are handled:
      * a single array ``(n_q, n_ldofs)`` -> returns ``(n_ldofs,)``;
      * a ``{name: array}`` dict when ``ctx.fields`` is set -> each entry is
        weighted with its own field's ``detJ``. ``_includes_measure`` may then be
        a dict selecting which fields skip weighting;
      * any other pytree -> every leaf is weighted with ``ctx.test.detJ``.
    A truthy non-dict ``res_form._includes_measure`` skips the ``w * detJ``
    weighting in the array/pytree cases.
    """
    raw = res_form(ctx, u_elem, params)  # (n_q, n_ldofs) or pytree
    measure_flag = getattr(res_form, "_includes_measure", False)

    def _quad_sum(values, weights=None):
        # Contract the leading quadrature axis, optionally with per-point weights.
        if weights is None:
            return jnp.einsum("qa->a", values)
        return jnp.einsum("qa,q->a", values, weights)

    if isinstance(raw, jnp.ndarray):
        return _quad_sum(raw, None if measure_flag else ctx.w * ctx.test.detJ)

    field_map = getattr(ctx, "fields", None)
    if field_map is not None:
        per_field_flags = measure_flag if isinstance(measure_flag, dict) else {}
        reduced = {}
        for name, val in raw.items():
            if per_field_flags.get(name, False):
                reduced[name] = _quad_sum(val)
            else:
                reduced[name] = _quad_sum(val, ctx.w * field_map[name].test.detJ)
        return reduced

    return jax.tree_util.tree_map(
        lambda leaf: _quad_sum(leaf, None if measure_flag else ctx.w * ctx.test.detJ),
        raw,
    )
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
def element_jacobian(res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.ndarray, params: P):
    """Element Jacobian K_e = d r_e / d u_e via forward-mode AD (``jax.jacfwd``).

    ``r_e`` is ``element_residual(res_form, ctx, u_e, params)``; for an array
    integrand the result has shape ``(n_ldofs, n_ldofs)``.
    """
    residual_of = lambda u_local: element_residual(res_form, ctx, u_local, params)
    return jax.jacfwd(residual_of)(u_elem)
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def make_sparsity_pattern(space: SpaceLike, *, with_idx: bool = True):
    """Build a SparsityPattern (rows/cols[/idx]) that is independent of the solution.

    NOTE: ``rows``/``cols`` hold one (row, col) pair per element-matrix entry in
    (element, local row, local col) order, matching
    ``assemble_jacobian_values(...).reshape(-1)`` so pattern and data align 1:1.
    If you change the flattening/compression strategy, keep this ordering
    contract in sync.

    Also stored on the pattern:
      * ``perm``: stable permutation that sorts the entries by (row, col);
      * ``indptr`` / ``indices``: CSR row pointers and column indices of the
        sorted entries (duplicates are kept, not compressed);
      * ``idx`` (only when ``with_idx``): flat dense index ``row * n_dofs + col``
        as int32. It is only meaningful when ``n_dofs**2`` fits in int32, which
        is the case for dense assembly.
    """
    from ..solver import SparsityPattern  # local import to avoid circular

    elem_dofs = jnp.asarray(space.elem_dofs, dtype=jnp.int32)
    n_dofs = int(space.n_dofs)
    n_ldofs = int(space.n_ldofs)

    rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1).astype(jnp.int32)
    cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1).astype(jnp.int32)

    # Stable lexicographic sort by (row, col). This avoids forming the fused key
    # row * n_dofs + col, which overflows once n_dofs > 46340 when JAX runs
    # without x64 (jnp.int64 silently degrades to int32 there).
    perm = jnp.lexsort((cols, rows)).astype(jnp.int32)
    rows_sorted = rows[perm]
    cols_sorted = cols[perm]

    counts = jnp.bincount(rows_sorted, length=n_dofs).astype(jnp.int32)
    indptr_j = jnp.concatenate([jnp.array([0], dtype=jnp.int32), jnp.cumsum(counts)])
    indices_j = cols_sorted.astype(jnp.int32)

    idx = (rows * n_dofs + cols).astype(jnp.int32) if with_idx else None

    return SparsityPattern(
        rows=rows,
        cols=cols,
        n_dofs=n_dofs,
        idx=idx,
        perm=perm,
        indptr=indptr_j,
        indices=indices_j,
    )
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
def assemble_jacobian_values(
    space: SpaceLike,
    res_form: ResidualForm[P],
    u: jnp.ndarray,
    params: P,
    *,
    kernel=None,
):
    """Compute only the numeric Jacobian entries (no sparsity pattern).

    Element matrices come from ``kernel(u_elem, ctx)``, which defaults to
    ``make_element_jacobian_kernel(res_form, params)``. They are returned
    flattened in (element, local row, local col) order.
    """
    contexts = space.build_form_contexts()
    if kernel is None:
        jac_kernel = make_element_jacobian_kernel(res_form, params)
    else:
        jac_kernel = kernel

    local_u = u[space.elem_dofs]
    element_mats = jax.vmap(jac_kernel)(local_u, contexts)  # (n_elem, m, m)
    return element_mats.reshape(-1)
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
def assemble_residual_scatter(
|
|
618
|
+
space: SpaceLike,
|
|
619
|
+
res_form: ResidualForm[P],
|
|
620
|
+
u: jnp.ndarray,
|
|
621
|
+
params: P,
|
|
622
|
+
*,
|
|
623
|
+
kernel=None,
|
|
624
|
+
sparse: bool = False,
|
|
625
|
+
):
|
|
626
|
+
"""
|
|
627
|
+
Assemble residual using jitted element kernel + vmap + scatter_add.
|
|
628
|
+
Avoids Python loops; good for JIT stability.
|
|
629
|
+
|
|
630
|
+
Note: `res_form` should return the integrand only; quadrature weights and detJ
|
|
631
|
+
are applied in the element kernel (make_element_residual_kernel). Do not multiply
|
|
632
|
+
by w or detJ inside `res_form`.
|
|
633
|
+
"""
|
|
634
|
+
elem_dofs = space.elem_dofs
|
|
635
|
+
n_dofs = space.n_dofs
|
|
636
|
+
if np.max(elem_dofs) >= n_dofs:
|
|
637
|
+
raise ValueError("elem_dofs contains index outside n_dofs")
|
|
638
|
+
if np.min(elem_dofs) < 0:
|
|
639
|
+
raise ValueError("elem_dofs contains negative index")
|
|
640
|
+
ctxs = space.build_form_contexts()
|
|
641
|
+
ker = kernel if kernel is not None else make_element_residual_kernel(res_form, params)
|
|
642
|
+
|
|
643
|
+
u_elems = u[elem_dofs]
|
|
644
|
+
elem_res = jax.vmap(ker)(ctxs, u_elems) # (n_elem, n_ldofs)
|
|
645
|
+
if not bool(jax.block_until_ready(jnp.all(jnp.isfinite(elem_res)))):
|
|
646
|
+
bad = int(jnp.count_nonzero(~jnp.isfinite(elem_res)))
|
|
647
|
+
raise RuntimeError(f"[assemble_residual_scatter] elem_res nonfinite: {bad}")
|
|
648
|
+
|
|
649
|
+
rows = elem_dofs.reshape(-1)
|
|
650
|
+
data = elem_res.reshape(-1)
|
|
651
|
+
|
|
652
|
+
if sparse:
|
|
653
|
+
return rows, data, n_dofs
|
|
654
|
+
|
|
655
|
+
sdn = jax.lax.ScatterDimensionNumbers(
|
|
656
|
+
update_window_dims=(),
|
|
657
|
+
inserted_window_dims=(0,),
|
|
658
|
+
scatter_dims_to_operand_dims=(0,),
|
|
659
|
+
)
|
|
660
|
+
F = jnp.zeros((n_dofs,), dtype=data.dtype)
|
|
661
|
+
F = jax.lax.scatter_add(F, rows[:, None], data, sdn)
|
|
662
|
+
return F
|
|
663
|
+
|
|
664
|
+
|
|
665
|
+
def assemble_jacobian_scatter(
    space: SpaceLike,
    res_form: ResidualForm[P],
    u: jnp.ndarray,
    params: P,
    *,
    kernel=None,
    sparse: bool = False,
    return_flux_matrix: bool = False,
    pattern=None,
):
    """Assemble the Jacobian from per-element kernels (vmap + scatter_add).

    A supplied SparsityPattern is reused as-is. Otherwise one is built from
    ``space``, including the dense flat index only when a dense result is
    requested.

    CONTRACT: the returned ``data`` ordering matches ``pattern.rows/cols``
    exactly. Any change to pattern generation or data flattening must preserve
    this.

    Returns:
        * ``sparse=True, return_flux_matrix=True``: ``FluxSparseMatrix(pattern, data)``;
        * ``sparse=True``: ``(rows, cols, data, n_dofs)``;
        * otherwise: dense ``(n_dofs, n_dofs)`` array with duplicates summed.
    """
    from ..solver import FluxSparseMatrix  # local import to avoid circular

    sp = make_sparsity_pattern(space, with_idx=not sparse) if pattern is None else pattern
    values = assemble_jacobian_values(space, res_form, u, params, kernel=kernel)

    if sparse:
        if return_flux_matrix:
            return FluxSparseMatrix(sp, values)
        return sp.rows, sp.cols, values, sp.n_dofs

    flat_idx = sp.idx
    if flat_idx is None:
        # Pattern built without a dense index: derive row * n_dofs + col here.
        flat_idx = (sp.rows.astype(jnp.int64) * int(sp.n_dofs) + sp.cols.astype(jnp.int64)).astype(jnp.int32)

    size = sp.n_dofs
    scatter_dims = jax.lax.ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    )
    dense = jax.lax.scatter_add(
        jnp.zeros(size * size, dtype=values.dtype),
        flat_idx[:, None],
        values,
        scatter_dims,
    )
    return dense.reshape(size, size)
|
|
705
|
+
|
|
706
|
+
|
|
707
|
+
# Default public residual API: delegates to the scatter-based implementation.
def assemble_residual(
    space: SpaceLike,
    form: ResidualForm[P],
    u: jnp.ndarray, params: P,
    *,
    sparse: bool = False
):
    """Assemble the global residual vector (scatter-based)."""
    return assemble_residual_scatter(
        space=space,
        res_form=form,
        u=u,
        params=params,
        sparse=sparse,
    )
|
|
717
|
+
|
|
718
|
+
|
|
719
|
+
def assemble_jacobian(
    space: SpaceLike,
    res_form: ResidualForm[P],
    u: jnp.ndarray,
    params: P,
    *,
    sparse: bool = True,
    return_flux_matrix: bool = False,
    pattern=None,
):
    """Assemble the global Jacobian (scatter-based); sparse output by default."""
    options = {
        "sparse": sparse,
        "return_flux_matrix": return_flux_matrix,
        "pattern": pattern,
    }
    return assemble_jacobian_scatter(space, res_form, u, params, **options)
|
|
739
|
+
|
|
740
|
+
|
|
741
|
+
def _make_unit_cube_mesh() -> HexMesh:
    """Build a mesh made of one hexahedral element spanning [0, 1]^3."""
    unit_box = StructuredHexBox(nx=1, ny=1, nz=1, lx=1.0, ly=1.0, lz=1.0)
    return unit_box.build()
|
|
744
|
+
|
|
745
|
+
|
|
746
|
+
def scalar_body_force_form(ctx: FormContext, load: float) -> jnp.ndarray:
    """Linear-form integrand for a constant scalar body force: ``load * N``.

    Returns an array shaped like ``ctx.test.N`` (one row per quadrature point).
    """
    shape_values = ctx.test.N  # (n_q, n_ldofs)
    return load * shape_values
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
def make_scalar_body_force_form(body_force):
    """Build a scalar linear form from a body-force callable.

    Args:
        body_force: ``f(x_q)`` returning the force at the quadrature points, shape
            ``(n_q,)``. A scalar (Python number or 0-d array) is also accepted and
            treated as a constant force.

    Returns:
        ``form(ctx, params) -> (n_q, n_ldofs)`` computing ``f(x_q)[:, None] * N``;
        ``params`` is ignored.
    """

    def _form(ctx: FormContext, _params):
        # jnp.asarray lets a constant Python-scalar force be indexed and broadcast
        # like a per-point array.
        f_q = jnp.asarray(body_force(ctx.x_q))
        return f_q[..., None] * ctx.test.N

    return _form
|
|
759
|
+
|
|
760
|
+
|
|
761
|
+
# Backward-compatibility alias: ``constant_body_force_form`` is the same function
# object as ``scalar_body_force_form`` (``load * ctx.test.N``).
constant_body_force_form = scalar_body_force_form
|
|
763
|
+
|
|
764
|
+
|
|
765
|
+
def _check_structured_box_connectivity():
    """Smoke-check StructuredHexBox connectivity for an nx=2, ny=1, nz=1 box.

    Asserts the node and element array shapes. Then prints whether the element
    connectivity equals a hand-written reference, and prints both arrays on a
    mismatch.
    """
    mesh = StructuredHexBox(nx=2, ny=1, nz=1, lx=2.0, ly=1.0, lz=1.0).build()

    assert mesh.coords.shape == (12, 3)
    assert mesh.conn.shape == (2, 8)

    reference = jnp.array(
        [
            [0, 1, 4, 3, 6, 7, 10, 9],  # element at i=0
            [1, 2, 5, 4, 7, 8, 11, 10],  # element at i=1
        ],
        dtype=jnp.int32,
    )
    worst = int(jnp.abs(mesh.conn - reference).max())
    matches = worst == 0
    print("StructuredHexBox nx=2,ny=1,nz=1 conn matches expected:", matches)
    if not matches:
        print("expected conn:\n", reference)
        print("got conn:\n", mesh.conn)
|
|
785
|
+
|
|
786
|
+
|
|
787
|
+
if __name__ == "__main__":
    # Running the module directly executes the connectivity smoke check.
    _check_structured_box_connectivity()
|