fluxfem 0.1.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. fluxfem/__init__.py +343 -0
  2. fluxfem/core/__init__.py +316 -0
  3. fluxfem/core/assembly.py +788 -0
  4. fluxfem/core/basis.py +996 -0
  5. fluxfem/core/data.py +64 -0
  6. fluxfem/core/dtypes.py +4 -0
  7. fluxfem/core/forms.py +234 -0
  8. fluxfem/core/interp.py +55 -0
  9. fluxfem/core/solver.py +113 -0
  10. fluxfem/core/space.py +419 -0
  11. fluxfem/core/weakform.py +818 -0
  12. fluxfem/helpers_num.py +11 -0
  13. fluxfem/helpers_wf.py +42 -0
  14. fluxfem/mesh/__init__.py +29 -0
  15. fluxfem/mesh/base.py +244 -0
  16. fluxfem/mesh/hex.py +327 -0
  17. fluxfem/mesh/io.py +87 -0
  18. fluxfem/mesh/predicate.py +45 -0
  19. fluxfem/mesh/surface.py +257 -0
  20. fluxfem/mesh/tet.py +246 -0
  21. fluxfem/physics/__init__.py +53 -0
  22. fluxfem/physics/diffusion.py +18 -0
  23. fluxfem/physics/elasticity/__init__.py +39 -0
  24. fluxfem/physics/elasticity/hyperelastic.py +99 -0
  25. fluxfem/physics/elasticity/linear.py +58 -0
  26. fluxfem/physics/elasticity/materials.py +32 -0
  27. fluxfem/physics/elasticity/stress.py +46 -0
  28. fluxfem/physics/operators.py +109 -0
  29. fluxfem/physics/postprocess.py +113 -0
  30. fluxfem/solver/__init__.py +47 -0
  31. fluxfem/solver/bc.py +439 -0
  32. fluxfem/solver/cg.py +326 -0
  33. fluxfem/solver/dirichlet.py +126 -0
  34. fluxfem/solver/history.py +31 -0
  35. fluxfem/solver/newton.py +400 -0
  36. fluxfem/solver/result.py +62 -0
  37. fluxfem/solver/solve_runner.py +534 -0
  38. fluxfem/solver/solver.py +148 -0
  39. fluxfem/solver/sparse.py +188 -0
  40. fluxfem/tools/__init__.py +7 -0
  41. fluxfem/tools/jit.py +51 -0
  42. fluxfem/tools/timer.py +659 -0
  43. fluxfem/tools/visualizer.py +101 -0
  44. fluxfem-0.1.1a0.dist-info/METADATA +111 -0
  45. fluxfem-0.1.1a0.dist-info/RECORD +47 -0
  46. fluxfem-0.1.1a0.dist-info/WHEEL +4 -0
  47. fluxfem-0.1.1a0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,788 @@
1
+ from __future__ import annotations
2
+ from typing import Callable, Protocol, TypeVar, Optional
3
+ import numpy as np
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+ from ..mesh import HexMesh, StructuredHexBox
8
+ from .forms import FormContext
9
+ from .space import FESpaceBase
10
+
11
+ # Shared call signatures for kernels/forms
12
+ Array = jnp.ndarray
13
+ P = TypeVar("P")
14
+
15
+ Kernel = Callable[[FormContext, P], Array]
16
+ ResidualForm = Callable[[FormContext, Array, P], Array]
17
+ ElementDofMapper = Callable[[Array], Array]
18
+
19
+
20
+ class SpaceLike(FESpaceBase, Protocol):
21
+ pass
22
+
23
+
24
+ def assemble_bilinear_dense(
25
+ space: SpaceLike,
26
+ kernel: Kernel[P],
27
+ params: P,
28
+ *,
29
+ sparse: bool = False,
30
+ return_flux_matrix: bool = False,
31
+ ):
32
+ """
33
+ Similar to scikit-fem's asm(biform, basis).
34
+ kernel(ctx: FormContext, params) -> (n_ldofs, n_ldofs); the kernel must include the quadrature measure (w * detJ) itself.
35
+ """
36
+ elem_dofs = space.elem_dofs # (n_elems, n_ldofs)
37
+ n_dofs = space.n_dofs
38
+ n_ldofs = space.n_ldofs
39
+
40
+ elem_data = space.build_form_contexts() # Pytree with leading n_elems in each field
41
+
42
+ # apply kernel per element
43
+ def ke_fun(ctx: FormContext):
44
+ return kernel(ctx, params)
45
+
46
+ K_e_all = jax.vmap(ke_fun)(elem_data) # (n_elems, n_ldofs, n_ldofs)
47
+
48
+ # ---- scatter into COO format ----
49
+ # row/col indices, flattened per element to (n_elems, n_ldofs*n_ldofs)
50
+ rows = jnp.repeat(elem_dofs, n_ldofs, axis=1) # (n_elems, n_ldofs*n_ldofs)
51
+ cols = jnp.tile(elem_dofs, (1, n_ldofs)) # (n_elems, n_ldofs*n_ldofs)
52
+
53
+ rows = rows.reshape(-1)
54
+ cols = cols.reshape(-1)
55
+ data = K_e_all.reshape(-1)
56
+
57
+ # Flatten indices for segment_sum via (row * n_dofs + col)
58
+ idx = rows * n_dofs + cols # (n_entries,)
59
+
60
+ if sparse:
61
+ if return_flux_matrix:
62
+ from ..solver import FluxSparseMatrix # local import to avoid circular
63
+ return FluxSparseMatrix(rows, cols, data, n_dofs)
64
+ return rows, cols, data, n_dofs
65
+
66
+ n_entries = n_dofs * n_dofs
67
+ out = jnp.zeros((n_entries,), dtype=data.dtype)
68
+ out = out.at[idx].add(data)
69
+ K = out.reshape(n_dofs, n_dofs)
70
+ return K
71
+
72
+
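A minimal usage sketch for assemble_bilinear_dense (not part of the package source): `space` is assumed to be a scalar FE space built elsewhere in fluxfem, and `mass_kernel` is an illustrative kernel. Note that the kernel must return the full element matrix including the quadrature measure, since assemble_bilinear_dense does not apply w * detJ itself.

    import jax.numpy as jnp

    def mass_kernel(ctx, params):
        # ctx.test.N: (n_q, n_ldofs), ctx.w and ctx.test.detJ: (n_q,)
        wJ = ctx.w * ctx.test.detJ
        # element mass matrix, measure included
        return jnp.einsum("qa,qb,q->ab", ctx.test.N, ctx.test.N, wJ)

    M = assemble_bilinear_dense(space, mass_kernel, params=None)  # (n_dofs, n_dofs)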
73
+ def assemble_bilinear_form(
74
+ space,
75
+ form,
76
+ params,
77
+ *,
78
+ pattern=None,
79
+ chunk_size: Optional[int] = None,  # None -> no chunking (assemble all elements in one vmap)
80
+ dep: jnp.ndarray | None = None,
81
+ ):
82
+ """
83
+ Assemble a sparse bilinear form into a FluxSparseMatrix.
84
+
85
+ Expects form(ctx, params) -> (n_q, n_ldofs, n_ldofs).
86
+ """
87
+ from ..solver import FluxSparseMatrix
88
+
89
+ if pattern is None:
90
+ if hasattr(space, "get_sparsity_pattern"):
91
+ pat = space.get_sparsity_pattern(with_idx=True)
92
+ else:
93
+ pat = make_sparsity_pattern(space, with_idx=True)
94
+ else:
95
+ pat = pattern
96
+ elem_data = space.build_form_contexts(dep=dep)
97
+
98
+ includes_measure = getattr(form, "_includes_measure", False)
99
+
100
+ def per_element(ctx):
101
+ integrand = form(ctx, params) # (n_q, m, m)
102
+ if includes_measure:
103
+ return integrand.sum(axis=0)
104
+ wJ = ctx.w * ctx.test.detJ # (n_q,)
105
+ return (integrand * wJ[:, None, None]).sum(axis=0) # (m, m)
106
+
107
+ # --- no-chunk path (chunk_size=None) ---
108
+ if chunk_size is None:
109
+ K_e_all = jax.vmap(per_element)(elem_data) # (n_elems, m, m)
110
+ data = K_e_all.reshape(-1)
111
+ return FluxSparseMatrix(pat, data)
112
+
113
+ # --- chunked path ---
114
+ n_elems = space.elem_dofs.shape[0]
115
+ # Take m from the pattern when available, otherwise infer it from one element.
116
+ m = getattr(pat, "n_ldofs", None)
117
+ if m is None:
118
+ m = per_element(jax.tree_util.tree_map(lambda x: x[0], elem_data)).shape[0]
119
+
120
+ # Pad to fixed-size chunks for JIT stability.
121
+ pad = (-n_elems) % chunk_size
122
+ if pad:
123
+ elem_data_pad = jax.tree_util.tree_map(
124
+ lambda x: jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0),
125
+ elem_data,
126
+ )
127
+ else:
128
+ elem_data_pad = elem_data
129
+
130
+ n_pad = n_elems + pad
131
+ n_chunks = n_pad // chunk_size
132
+
133
+ def _slice_first_dim(x, start, size):
134
+ start_idx = (start,) + (0,) * (x.ndim - 1)
135
+ slice_sizes = (size,) + x.shape[1:]
136
+ return jax.lax.dynamic_slice(x, start_idx, slice_sizes)
137
+
138
+ def chunk_fn(i):
139
+ start = i * chunk_size
140
+ ctx_chunk = jax.tree_util.tree_map(
141
+ lambda x: _slice_first_dim(x, start, chunk_size),
142
+ elem_data_pad,
143
+ )
144
+ Ke = jax.vmap(per_element)(ctx_chunk) # (chunk, m, m)
145
+ return Ke.reshape(-1) # (chunk*m*m,)
146
+
147
+ data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
148
+ data = data_chunks.reshape(-1)[: n_elems * m * m]
149
+ return FluxSparseMatrix(pat, data)
150
+
151
+
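A sketch of a quadrature-pointwise form for assemble_bilinear_form (illustrative, assuming a scalar space named `space`): the form returns the integrand per quadrature point and the assembler applies w * detJ, unless the form is marked with _includes_measure.

    import jax.numpy as jnp

    def mass_form(ctx, params):
        # integrand per quadrature point: (n_q, n_ldofs, n_ldofs)
        return jnp.einsum("qa,qb->qab", ctx.test.N, ctx.test.N)

    A = assemble_bilinear_form(space, mass_form, params=None)                  # FluxSparseMatrix
    A_chunked = assemble_bilinear_form(space, mass_form, None, chunk_size=1024)  # same result, chunked vmap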
152
+ def assemble_mass_matrix(space: SpaceLike, *, lumped: bool = False, chunk_size: Optional[int] = None):
153
+ """
154
+ Assemble mass matrix M_ij = ∫ N_i N_j dΩ.
155
+ Supports scalar and vector spaces. If lumped=True, rows are summed to diagonal.
156
+ """
157
+ from ..solver import FluxSparseMatrix # local import to avoid circular
158
+
159
+ ctxs = space.build_form_contexts()
160
+ n_ldofs = space.n_ldofs
161
+
162
+ def per_element(ctx: FormContext):
163
+ N = ctx.test.N # (n_q, n_nodes)
164
+ base = jnp.einsum("qa,qb->qab", N, N) # (n_q, n_nodes, n_nodes)
165
+ if hasattr(ctx.test, "value_dim"):
166
+ vd = int(ctx.test.value_dim)
167
+ I = jnp.eye(vd, dtype=N.dtype)
168
+ base = base[:, :, :, None, None] * I[None, None, None, :, :]
169
+ base = base.reshape(base.shape[0], n_ldofs, n_ldofs)
170
+ wJ = ctx.w * ctx.test.detJ
171
+ return jnp.einsum("qab,q->ab", base, wJ)
172
+
173
+ if chunk_size is None:
174
+ M_e_all = jax.vmap(per_element)(ctxs) # (n_elems, n_ldofs, n_ldofs)
175
+ data = M_e_all.reshape(-1)
176
+ else:
177
+ n_elems = space.elem_dofs.shape[0]
178
+ pad = (-n_elems) % chunk_size
179
+ if pad:
180
+ ctxs_pad = jax.tree_util.tree_map(
181
+ lambda x: jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0),
182
+ ctxs,
183
+ )
184
+ else:
185
+ ctxs_pad = ctxs
186
+
187
+ n_pad = n_elems + pad
188
+ n_chunks = n_pad // chunk_size
189
+
190
+ def _slice_first_dim(x, start, size):
191
+ start_idx = (start,) + (0,) * (x.ndim - 1)
192
+ slice_sizes = (size,) + x.shape[1:]
193
+ return jax.lax.dynamic_slice(x, start_idx, slice_sizes)
194
+
195
+ def chunk_fn(i):
196
+ start = i * chunk_size
197
+ ctx_chunk = jax.tree_util.tree_map(
198
+ lambda x: _slice_first_dim(x, start, chunk_size),
199
+ ctxs_pad,
200
+ )
201
+ Me = jax.vmap(per_element)(ctx_chunk) # (chunk, n_ldofs, n_ldofs)
202
+ return Me.reshape(-1)
203
+
204
+ data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
205
+ data = data_chunks.reshape(-1)[: n_elems * n_ldofs * n_ldofs]
206
+
207
+ elem_dofs = space.elem_dofs
208
+ rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
209
+ cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
210
+
211
+ if lumped:
212
+ n_dofs = space.n_dofs
213
+ M = jnp.zeros((n_dofs,), dtype=data.dtype)
214
+ M = M.at[rows].add(data)
215
+ return M
216
+
217
+ return FluxSparseMatrix(rows, cols, data, n_dofs=space.n_dofs)
218
+
219
+
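Usage sketch for assemble_mass_matrix, assuming `space` is available; the lumped variant returns the row-sum diagonal as a vector rather than a matrix.

    M = assemble_mass_matrix(space)                       # FluxSparseMatrix (consistent mass)
    m_lumped = assemble_mass_matrix(space, lumped=True)   # (n_dofs,) row-sum lumped diagonal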
220
+ def assemble_linear_form(
221
+ space: SpaceLike,
222
+ form: Kernel[P],
223
+ params: P,
224
+ *,
225
+ sparse: bool = False,
226
+ chunk_size: Optional[int] = None,
227
+ dep: jnp.ndarray | None = None,
228
+ ) -> jnp.ndarray:
229
+ """
230
+ Expects form(ctx, params) -> (n_q, n_ldofs) and integrates Σ_q form * wJ for RHS.
231
+ """
232
+ elem_dofs = space.elem_dofs
233
+ n_dofs = space.n_dofs
234
+ n_ldofs = space.n_ldofs
235
+
236
+ elem_data = space.build_form_contexts(dep=dep)
237
+
238
+ includes_measure = getattr(form, "_includes_measure", False)
239
+
240
+ def per_element(ctx: FormContext):
241
+ integrand = form(ctx, params) # (n_q, m)
242
+ if includes_measure:
243
+ return integrand.sum(axis=0)
244
+ wJ = ctx.w * ctx.test.detJ # (n_q,)
245
+ return (integrand * wJ[:, None]).sum(axis=0) # (m,)
246
+
247
+ if chunk_size is None:
248
+ F_e_all = jax.vmap(per_element)(elem_data) # (n_elems, m)
249
+ data = F_e_all.reshape(-1)
250
+ else:
251
+ n_elems = space.elem_dofs.shape[0]
252
+ m = n_ldofs
253
+ pad = (-n_elems) % chunk_size
254
+ if pad:
255
+ elem_data_pad = jax.tree_util.tree_map(
256
+ lambda x: jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0),
257
+ elem_data,
258
+ )
259
+ else:
260
+ elem_data_pad = elem_data
261
+
262
+ n_pad = n_elems + pad
263
+ n_chunks = n_pad // chunk_size
264
+
265
+ def _slice_first_dim(x, start, size):
266
+ start_idx = (start,) + (0,) * (x.ndim - 1)
267
+ slice_sizes = (size,) + x.shape[1:]
268
+ return jax.lax.dynamic_slice(x, start_idx, slice_sizes)
269
+
270
+ def chunk_fn(i):
271
+ start = i * chunk_size
272
+ ctx_chunk = jax.tree_util.tree_map(
273
+ lambda x: _slice_first_dim(x, start, chunk_size),
274
+ elem_data_pad,
275
+ )
276
+ fe = jax.vmap(per_element)(ctx_chunk) # (chunk, m)
277
+ return fe.reshape(-1)
278
+
279
+ data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
280
+ data = data_chunks.reshape(-1)[: n_elems * m]
281
+
282
+ rows = elem_dofs.reshape(-1)
283
+
284
+ if sparse:
285
+ return rows, data, n_dofs
286
+
287
+ F = jax.ops.segment_sum(data, rows, n_dofs)
288
+ return F
289
+
290
+
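Usage sketch for assemble_linear_form using the body-force forms defined further down in this module; `space` is assumed, and the x_q slicing assumes quadrature points of shape (n_q, dim).

    # constant scalar body force f = 2.0, via scalar_body_force_form (defined below)
    F = assemble_linear_form(space, scalar_body_force_form, params=2.0)   # (n_dofs,)

    # spatially varying load f(x) = x_0, via make_scalar_body_force_form (defined below)
    form_x = make_scalar_body_force_form(lambda x_q: x_q[:, 0])
    F_x = assemble_linear_form(space, form_x, params=None)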
291
+ def assemble_functional(space: SpaceLike, form: Kernel[P], params: P) -> jnp.ndarray:
292
+ """
293
+ Assemble scalar functional J = ∫ form(ctx, params) dΩ.
294
+ Expects form(ctx, params) -> (n_q,) or (n_q, 1).
295
+ """
296
+ elem_data = space.build_form_contexts()
297
+
298
+ includes_measure = getattr(form, "_includes_measure", False)
299
+
300
+ def per_element(ctx: FormContext):
301
+ integrand = form(ctx, params)
302
+ if integrand.ndim == 2 and integrand.shape[1] == 1:
303
+ integrand = integrand[:, 0]
304
+ if includes_measure:
305
+ return jnp.sum(integrand)
306
+ wJ = ctx.w * ctx.test.detJ
307
+ return jnp.sum(integrand * wJ)
308
+
309
+ vals = jax.vmap(per_element)(elem_data)
310
+ return jnp.sum(vals)
311
+
312
+
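A simple check of assemble_functional: integrating the constant 1 returns the domain volume (sketch, assuming `space`).

    import jax.numpy as jnp

    volume = assemble_functional(space, lambda ctx, _p: jnp.ones_like(ctx.w), params=None)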
313
+ def assemble_jacobian_global(
314
+ space: SpaceLike,
315
+ res_form: ResidualForm[P],
316
+ u: jnp.ndarray,
317
+ params: P,
318
+ *,
319
+ sparse: bool = False,
320
+ return_flux_matrix: bool = False,
321
+ ):
322
+ """
323
+ Assemble Jacobian (dR/du) from element residual res_form.
324
+ res_form(ctx, u_elem, params) -> (n_q, n_ldofs)
325
+ """
326
+ elem_dofs = space.elem_dofs
327
+ n_dofs = space.n_dofs
328
+ n_ldofs = space.n_ldofs
329
+
330
+ elem_data = space.build_form_contexts()
331
+
332
+ def fe_fun(u_elem, ctx: FormContext, elem_id):
333
+ ctx_with_id = FormContext(ctx.test, ctx.trial, ctx.x_q, ctx.w, elem_id)
334
+ integrand = res_form(ctx_with_id, u_elem, params) # (n_q, m)
335
+ wJ = ctx.w * ctx.test.detJ
336
+ fe = (integrand * wJ[:, None]).sum(axis=0) # (m,)
337
+ return fe
338
+
339
+ jac_fun = jax.jacrev(fe_fun, argnums=0)
340
+
341
+ u_elems = u[elem_dofs] # (n_elems, n_ldofs)
342
+ elem_ids = jnp.arange(elem_dofs.shape[0], dtype=jnp.int32)
343
+ J_e_all = jax.vmap(jac_fun)(u_elems, elem_data, elem_ids) # (n_elems, m, m)
344
+
345
+ rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
346
+ cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
347
+ data = J_e_all.reshape(-1)
348
+
349
+ if sparse:
350
+ if return_flux_matrix:
351
+ from ..solver import FluxSparseMatrix # local import to avoid circular
352
+ return FluxSparseMatrix(rows, cols, data, n_dofs)
353
+ return rows, cols, data, n_dofs
354
+
355
+ n_entries = n_dofs * n_dofs
356
+ idx = rows * n_dofs + cols
357
+ K_flat = jax.ops.segment_sum(data, idx, n_entries)
358
+ return K_flat.reshape(n_dofs, n_dofs)
359
+
360
+
361
+ def assemble_jacobian_elementwise_xla(
362
+ space: SpaceLike,
363
+ res_form: ResidualForm[P],
364
+ u: jnp.ndarray,
365
+ params: P,
366
+ *,
367
+ sparse: bool = False,
368
+ return_flux_matrix: bool = False,
369
+ ):
370
+ """
371
+ Assemble Jacobian with element kernels in XLA (vmap + scatter_add).
372
+ Recompiles if n_dofs changes, but independent of element count.
373
+ """
374
+ from ..solver import FluxSparseMatrix # local import to avoid circular
375
+
376
+ elem_dofs = space.elem_dofs
377
+ n_dofs = space.n_dofs
378
+ n_ldofs = space.n_ldofs
379
+
380
+ ctxs = space.build_form_contexts()
381
+
382
+ def fe_fun(u_elem, ctx: FormContext):
383
+ integrand = res_form(ctx, u_elem, params)
384
+ wJ = ctx.w * ctx.test.detJ
385
+ return (integrand * wJ[:, None]).sum(axis=0)
386
+
387
+ jac_fun = jax.jacrev(fe_fun, argnums=0)
388
+ u_elems = u[elem_dofs]
389
+ J_e_all = jax.vmap(jac_fun)(u_elems, ctxs) # (n_elems, m, m)
390
+
391
+ rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
392
+ cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
393
+ data = J_e_all.reshape(-1)
394
+
395
+ if sparse:
396
+ if return_flux_matrix:
397
+ return FluxSparseMatrix(rows, cols, data, n_dofs)
398
+ return rows, cols, data, n_dofs
399
+
400
+ n_entries = n_dofs * n_dofs
401
+ idx = rows * n_dofs + cols
402
+ sdn = jax.lax.ScatterDimensionNumbers(
403
+ update_window_dims=(),
404
+ inserted_window_dims=(0,),
405
+ scatter_dims_to_operand_dims=(0,),
406
+ )
407
+ K_flat = jnp.zeros(n_entries, dtype=data.dtype)
408
+ K_flat = jax.lax.scatter_add(K_flat, idx[:, None], data, sdn)
409
+ return K_flat.reshape(n_dofs, n_dofs)
410
+
411
+
412
+ def assemble_residual_global(
413
+ space: SpaceLike,
414
+ form: ResidualForm[P],
415
+ u: jnp.ndarray,
416
+ params: P,
417
+ *,
418
+ sparse: bool = False
419
+ ):
420
+ """
421
+ Assemble residual vector that depends on u.
422
+ form(ctx, u_elem, params) -> (n_q, n_ldofs)
423
+ """
424
+ elem_dofs = space.elem_dofs
425
+ n_dofs = space.n_dofs
426
+ n_ldofs = space.n_ldofs
427
+
428
+ elem_data = space.build_form_contexts()
429
+
430
+ def per_element(ctx: FormContext, conn: jnp.ndarray, elem_id: jnp.ndarray):
431
+ u_elem = u[conn]
432
+ ctx_with_id = FormContext(ctx.test, ctx.trial, ctx.x_q, ctx.w, elem_id)
433
+ integrand = form(ctx_with_id, u_elem, params) # (n_q, m)
434
+ wJ = ctx.w * ctx.test.detJ
435
+ fe = (integrand * wJ[:, None]).sum(axis=0)
436
+ return fe
437
+
438
+ elem_ids = jnp.arange(elem_dofs.shape[0], dtype=jnp.int32)
439
+ F_e_all = jax.vmap(per_element)(elem_data, elem_dofs, elem_ids) # (n_elems, m)
440
+
441
+ rows = elem_dofs.reshape(-1)
442
+ data = F_e_all.reshape(-1)
443
+
444
+ if sparse:
445
+ return rows, data, n_dofs
446
+
447
+ F = jax.ops.segment_sum(data, rows, n_dofs)
448
+ return F
449
+
450
+
451
+ def assemble_residual_elementwise_xla(
452
+ space: SpaceLike,
453
+ res_form: ResidualForm[P],
454
+ u: jnp.ndarray,
455
+ params: P,
456
+ *,
457
+ sparse: bool = False,
458
+ ):
459
+ """
460
+ Assemble residual using element kernels fully in XLA (vmap + scatter_add).
461
+ Recompiles if n_dofs changes, but independent of element count.
462
+ """
463
+ elem_dofs = space.elem_dofs
464
+ n_dofs = space.n_dofs
465
+ ctxs = space.build_form_contexts()
466
+
467
+ def per_element(ctx: FormContext, u_elem: jnp.ndarray):
468
+ integrand = res_form(ctx, u_elem, params)
469
+ wJ = ctx.w * ctx.test.detJ
470
+ return (integrand * wJ[:, None]).sum(axis=0)
471
+
472
+ u_elems = u[elem_dofs]
473
+ F_e_all = jax.vmap(per_element)(ctxs, u_elems) # (n_elems, m)
474
+ rows = elem_dofs.reshape(-1)
475
+ data = F_e_all.reshape(-1)
476
+
477
+ if sparse:
478
+ return rows, data, n_dofs
479
+
480
+ sdn = jax.lax.ScatterDimensionNumbers(
481
+ update_window_dims=(),
482
+ inserted_window_dims=(0,),
483
+ scatter_dims_to_operand_dims=(0,),
484
+ )
485
+ F = jnp.zeros(n_dofs, dtype=data.dtype)
486
+ F = jax.lax.scatter_add(F, rows[:, None], data, sdn)
487
+ return F
488
+
489
+
490
+ def make_element_residual_kernel(res_form: ResidualForm[P], params: P):
491
+ """Jitted element residual kernel: (ctx, u_elem) -> fe."""
492
+
493
+ def per_element(ctx: FormContext, u_elem: jnp.ndarray):
494
+ integrand = res_form(ctx, u_elem, params)
495
+ if getattr(res_form, "_includes_measure", False):
496
+ return integrand.sum(axis=0)
497
+ wJ = ctx.w * ctx.test.detJ
498
+ return (integrand * wJ[:, None]).sum(axis=0)
499
+
500
+ return jax.jit(per_element)
501
+
502
+
503
+ def make_element_jacobian_kernel(res_form: ResidualForm[P], params: P):
504
+ """Jitted element Jacobian kernel: (ctx, u_elem) -> Ke."""
505
+
506
+ def fe_fun(u_elem, ctx: FormContext):
507
+ integrand = res_form(ctx, u_elem, params)
508
+ if getattr(res_form, "_includes_measure", False):
509
+ return integrand.sum(axis=0)
510
+ wJ = ctx.w * ctx.test.detJ
511
+ return (integrand * wJ[:, None]).sum(axis=0)
512
+
513
+ return jax.jit(jax.jacrev(fe_fun, argnums=0))
514
+
515
+
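Note that the two factories return kernels with different argument orders, matching how they are consumed downstream (vmap(ker)(ctxs, u_elems) versus vmap(ker)(u_elems, ctxs)). A sketch of building them once per (form, params) pair and passing them to the scatter-based assemblers, with `space`, `res_form`, `u`, and `params` assumed to exist:

    res_kernel = make_element_residual_kernel(res_form, params)   # called as ker(ctx, u_elem)
    jac_kernel = make_element_jacobian_kernel(res_form, params)   # called as ker(u_elem, ctx)

    R = assemble_residual_scatter(space, res_form, u, params, kernel=res_kernel)
    vals = assemble_jacobian_values(space, res_form, u, params, kernel=jac_kernel)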
516
+ def element_residual(res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.ndarray, params: P):
517
+ """
518
+ Element residual vector r_e(u_e) = sum_q w_q * detJ_q * res_form(ctx, u_e, params).
519
+ Returns shape (n_ldofs,).
520
+ """
521
+ integrand = res_form(ctx, u_elem, params) # (n_q, n_ldofs) or pytree
522
+ includes_measure = getattr(res_form, "_includes_measure", False)
523
+ if isinstance(integrand, jnp.ndarray):
524
+ if includes_measure:
525
+ return jnp.einsum("qa->a", integrand)
526
+ wJ = ctx.w * ctx.test.detJ # (n_q,)
527
+ return jnp.einsum("qa,q->a", integrand, wJ)
528
+ if hasattr(ctx, "fields") and ctx.fields is not None:
529
+ def _reduce(name, val):
530
+ if isinstance(includes_measure, dict) and includes_measure.get(name, False):
531
+ return jnp.einsum("qa->a", val)
532
+ wJ = ctx.w * ctx.fields[name].test.detJ
533
+ return jnp.einsum("qa,q->a", val, wJ)
534
+
535
+ return {name: _reduce(name, val) for name, val in integrand.items()}
536
+ if includes_measure:
537
+ return jax.tree_util.tree_map(lambda x: jnp.einsum("qa->a", x), integrand)
538
+ return jax.tree_util.tree_map(lambda x: jnp.einsum("qa,q->a", x, ctx.w * ctx.test.detJ), integrand)
539
+
540
+
541
+ def element_jacobian(res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.ndarray, params: P):
542
+ """
543
+ Element Jacobian K_e = d r_e / d u_e (AD via jacfwd), shape (n_ldofs, n_ldofs).
544
+ """
545
+ def _r_elem(u_local):
546
+ return element_residual(res_form, ctx, u_local, params)
547
+
548
+ return jax.jacfwd(_r_elem)(u_elem)
549
+
550
+
551
+ def make_sparsity_pattern(space: SpaceLike, *, with_idx: bool = True):
552
+ """
553
+ Build a SparsityPattern (rows/cols[/idx]) that is independent of the solution.
554
+ NOTE: rows/cols ordering matches assemble_jacobian_values(...).reshape(-1)
555
+ so that pattern and data are aligned 1:1. If you change the flattening/
556
+ compression strategy, keep this ordering contract in sync.
557
+ """
558
+ from ..solver import SparsityPattern # local import to avoid circular
559
+
560
+ elem_dofs = jnp.asarray(space.elem_dofs, dtype=jnp.int32)
561
+ n_dofs = int(space.n_dofs)
562
+ n_ldofs = int(space.n_ldofs)
563
+
564
+ rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1).astype(jnp.int32)
565
+ cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1).astype(jnp.int32)
566
+
567
+ key = rows.astype(jnp.int64) * jnp.int64(n_dofs) + cols.astype(jnp.int64)
568
+ order = jnp.argsort(key).astype(jnp.int32)
569
+ rows_sorted = rows[order]
570
+ cols_sorted = cols[order]
571
+ counts = jnp.bincount(rows_sorted, length=n_dofs).astype(jnp.int32)
572
+ indptr_j = jnp.concatenate([jnp.array([0], dtype=jnp.int32), jnp.cumsum(counts)])
573
+ indices_j = cols_sorted.astype(jnp.int32)
574
+ perm = order
575
+
576
+ if with_idx:
577
+ idx = (rows.astype(jnp.int64) * jnp.int64(n_dofs) + cols.astype(jnp.int64)).astype(jnp.int32)
578
+ return SparsityPattern(
579
+ rows=rows,
580
+ cols=cols,
581
+ n_dofs=n_dofs,
582
+ idx=idx,
583
+ perm=perm,
584
+ indptr=indptr_j,
585
+ indices=indices_j,
586
+ )
587
+ return SparsityPattern(
588
+ rows=rows,
589
+ cols=cols,
590
+ n_dofs=n_dofs,
591
+ idx=None,
592
+ perm=perm,
593
+ indptr=indptr_j,
594
+ indices=indices_j,
595
+ )
596
+
597
+
598
+ def assemble_jacobian_values(
599
+ space: SpaceLike,
600
+ res_form: ResidualForm[P],
601
+ u: jnp.ndarray,
602
+ params: P,
603
+ *,
604
+ kernel=None,
605
+ ):
606
+ """
607
+ Assemble only the numeric values for the Jacobian (pattern-free).
608
+ """
609
+ ctxs = space.build_form_contexts()
610
+ ker = kernel if kernel is not None else make_element_jacobian_kernel(res_form, params)
611
+
612
+ u_elems = u[space.elem_dofs]
613
+ J_e_all = jax.vmap(ker)(u_elems, ctxs) # (n_elem, m, m)
614
+ return J_e_all.reshape(-1)
615
+
616
+
617
+ def assemble_residual_scatter(
618
+ space: SpaceLike,
619
+ res_form: ResidualForm[P],
620
+ u: jnp.ndarray,
621
+ params: P,
622
+ *,
623
+ kernel=None,
624
+ sparse: bool = False,
625
+ ):
626
+ """
627
+ Assemble residual using jitted element kernel + vmap + scatter_add.
628
+ Avoids Python loops; good for JIT stability.
629
+
630
+ Note: `res_form` should return the integrand only; quadrature weights and detJ
631
+ are applied in the element kernel (make_element_residual_kernel). Do not multiply
632
+ by w or detJ inside `res_form`.
633
+ """
634
+ elem_dofs = space.elem_dofs
635
+ n_dofs = space.n_dofs
636
+ if np.max(elem_dofs) >= n_dofs:
637
+ raise ValueError("elem_dofs contains index outside n_dofs")
638
+ if np.min(elem_dofs) < 0:
639
+ raise ValueError("elem_dofs contains negative index")
640
+ ctxs = space.build_form_contexts()
641
+ ker = kernel if kernel is not None else make_element_residual_kernel(res_form, params)
642
+
643
+ u_elems = u[elem_dofs]
644
+ elem_res = jax.vmap(ker)(ctxs, u_elems) # (n_elem, n_ldofs)
645
+ if not bool(jax.block_until_ready(jnp.all(jnp.isfinite(elem_res)))):
646
+ bad = int(jnp.count_nonzero(~jnp.isfinite(elem_res)))
647
+ raise RuntimeError(f"[assemble_residual_scatter] elem_res nonfinite: {bad}")
648
+
649
+ rows = elem_dofs.reshape(-1)
650
+ data = elem_res.reshape(-1)
651
+
652
+ if sparse:
653
+ return rows, data, n_dofs
654
+
655
+ sdn = jax.lax.ScatterDimensionNumbers(
656
+ update_window_dims=(),
657
+ inserted_window_dims=(0,),
658
+ scatter_dims_to_operand_dims=(0,),
659
+ )
660
+ F = jnp.zeros((n_dofs,), dtype=data.dtype)
661
+ F = jax.lax.scatter_add(F, rows[:, None], data, sdn)
662
+ return F
663
+
664
+
665
+ def assemble_jacobian_scatter(
666
+ space: SpaceLike,
667
+ res_form: ResidualForm[P],
668
+ u: jnp.ndarray,
669
+ params: P,
670
+ *,
671
+ kernel=None,
672
+ sparse: bool = False,
673
+ return_flux_matrix: bool = False,
674
+ pattern=None,
675
+ ):
676
+ """
677
+ Assemble Jacobian using jitted element kernel + vmap + scatter_add.
678
+ If a SparsityPattern is provided, rows/cols are reused without regeneration.
679
+ CONTRACT: The returned `data` ordering matches `pattern.rows/cols` exactly.
680
+ Any change to pattern generation or data flattening must preserve this.
681
+ """
682
+ from ..solver import FluxSparseMatrix # local import to avoid circular
683
+
684
+ pat = pattern if pattern is not None else make_sparsity_pattern(space, with_idx=not sparse)
685
+ data = assemble_jacobian_values(space, res_form, u, params, kernel=kernel)
686
+
687
+ if sparse:
688
+ if return_flux_matrix:
689
+ return FluxSparseMatrix(pat, data)
690
+ return pat.rows, pat.cols, data, pat.n_dofs
691
+
692
+ idx = pat.idx
693
+ if idx is None:
694
+ idx = (pat.rows.astype(jnp.int64) * int(pat.n_dofs) + pat.cols.astype(jnp.int64)).astype(jnp.int32)
695
+
696
+ n_entries = pat.n_dofs * pat.n_dofs
697
+ sdn = jax.lax.ScatterDimensionNumbers(
698
+ update_window_dims=(),
699
+ inserted_window_dims=(0,),
700
+ scatter_dims_to_operand_dims=(0,),
701
+ )
702
+ K_flat = jnp.zeros(n_entries, dtype=data.dtype)
703
+ K_flat = jax.lax.scatter_add(K_flat, idx[:, None], data, sdn)
704
+ return K_flat.reshape(pat.n_dofs, pat.n_dofs)
705
+
706
+
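Sketch of reusing a precomputed sparsity pattern across repeated assemblies (e.g. Newton iterations), so rows/cols are generated once; the iteration count and the linear solve are placeholders, not part of the package.

    pat = make_sparsity_pattern(space, with_idx=False)
    for _ in range(5):  # illustrative fixed iteration count
        rows, cols, data, n_dofs = assemble_jacobian_scatter(
            space, res_form, u, params, sparse=True, pattern=pat
        )
        # hand (rows, cols, data, n_dofs) to a sparse linear solver and update u here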
707
+ # Alias scatter-based assembly as the default public API
708
+ def assemble_residual(
709
+ space: SpaceLike,
710
+ form: ResidualForm[P],
711
+ u: jnp.ndarray, params: P,
712
+ *,
713
+ sparse: bool = False
714
+ ):
715
+ """Assemble the global residual vector (scatter-based)."""
716
+ return assemble_residual_scatter(space, form, u, params, sparse=sparse)
717
+
718
+
719
+ def assemble_jacobian(
720
+ space: SpaceLike,
721
+ res_form: ResidualForm[P],
722
+ u: jnp.ndarray,
723
+ params: P,
724
+ *,
725
+ sparse: bool = True,
726
+ return_flux_matrix: bool = False,
727
+ pattern=None,
728
+ ):
729
+ """Assemble the global Jacobian (scatter-based)."""
730
+ return assemble_jacobian_scatter(
731
+ space,
732
+ res_form,
733
+ u,
734
+ params,
735
+ sparse=sparse,
736
+ return_flux_matrix=return_flux_matrix,
737
+ pattern=pattern,
738
+ )
739
+
740
+
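A minimal dense Newton step using the public aliases above (purely illustrative; boundary conditions, line search, and convergence checks omitted, and `space`, `res_form`, `u`, `params` assumed):

    import jax.numpy as jnp

    R = assemble_residual(space, res_form, u, params)                 # (n_dofs,)
    K = assemble_jacobian(space, res_form, u, params, sparse=False)   # dense (n_dofs, n_dofs)
    u = u + jnp.linalg.solve(K, -R)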
741
+ def _make_unit_cube_mesh() -> HexMesh:
742
+ """Single hex element on [0, 1]^3."""
743
+ return StructuredHexBox(nx=1, ny=1, nz=1, lx=1.0, ly=1.0, lz=1.0).build()
744
+
745
+
746
+ def scalar_body_force_form(ctx: FormContext, load: float) -> jnp.ndarray:
747
+ """Linear form for constant scalar body force: f * N."""
748
+ return load * ctx.test.N # (n_q, n_ldofs)
749
+
750
+
751
+ def make_scalar_body_force_form(body_force):
752
+ """
753
+ Build a scalar linear form from a callable f(x_q) -> (n_q,).
754
+ """
755
+ def _form(ctx: FormContext, _params):
756
+ f_q = body_force(ctx.x_q)
757
+ return f_q[..., None] * ctx.test.N
758
+ return _form
759
+
760
+
761
+ # Backward compatibility alias
762
+ constant_body_force_form = scalar_body_force_form
763
+
764
+
765
+ def _check_structured_box_connectivity():
766
+ """Quick connectivity check for nx=2, ny=1, nz=1 (non-structured order)."""
767
+ box = StructuredHexBox(nx=2, ny=1, nz=1, lx=2.0, ly=1.0, lz=1.0)
768
+ mesh = box.build()
769
+
770
+ assert mesh.coords.shape == (12, 3)
771
+ assert mesh.conn.shape == (2, 8)
772
+
773
+ expected_conn = jnp.array(
774
+ [
775
+ [0, 1, 4, 3, 6, 7, 10, 9], # element at i=0
776
+ [1, 2, 5, 4, 7, 8, 11, 10], # element at i=1
777
+ ],
778
+ dtype=jnp.int32,
779
+ )
780
+ max_diff = int(jnp.max(jnp.abs(mesh.conn - expected_conn)))
781
+ print("StructuredHexBox nx=2,ny=1,nz=1 conn matches expected:", max_diff == 0)
782
+ if max_diff != 0:
783
+ print("expected conn:\n", expected_conn)
784
+ print("got conn:\n", mesh.conn)
785
+
786
+
787
+ if __name__ == "__main__":
788
+ _check_structured_box_connectivity()