fluxfem 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in the supported public registries. It is provided for informational purposes only.
Files changed (46)
  1. fluxfem/__init__.py +136 -161
  2. fluxfem/core/__init__.py +172 -41
  3. fluxfem/core/assembly.py +676 -91
  4. fluxfem/core/basis.py +73 -52
  5. fluxfem/core/context_types.py +36 -0
  6. fluxfem/core/dtypes.py +9 -1
  7. fluxfem/core/forms.py +15 -1
  8. fluxfem/core/mixed_assembly.py +263 -0
  9. fluxfem/core/mixed_space.py +348 -0
  10. fluxfem/core/mixed_weakform.py +97 -0
  11. fluxfem/core/solver.py +2 -0
  12. fluxfem/core/space.py +262 -17
  13. fluxfem/core/weakform.py +1503 -312
  14. fluxfem/helpers_wf.py +53 -0
  15. fluxfem/mesh/__init__.py +54 -2
  16. fluxfem/mesh/base.py +322 -8
  17. fluxfem/mesh/contact.py +825 -0
  18. fluxfem/mesh/dtypes.py +12 -0
  19. fluxfem/mesh/hex.py +18 -16
  20. fluxfem/mesh/io.py +8 -4
  21. fluxfem/mesh/mortar.py +3907 -0
  22. fluxfem/mesh/supermesh.py +316 -0
  23. fluxfem/mesh/surface.py +22 -4
  24. fluxfem/mesh/tet.py +10 -4
  25. fluxfem/physics/diffusion.py +3 -0
  26. fluxfem/physics/elasticity/hyperelastic.py +3 -0
  27. fluxfem/physics/elasticity/linear.py +9 -2
  28. fluxfem/solver/__init__.py +42 -2
  29. fluxfem/solver/bc.py +38 -2
  30. fluxfem/solver/block_matrix.py +132 -0
  31. fluxfem/solver/block_system.py +454 -0
  32. fluxfem/solver/cg.py +115 -33
  33. fluxfem/solver/dirichlet.py +334 -4
  34. fluxfem/solver/newton.py +237 -60
  35. fluxfem/solver/petsc.py +439 -0
  36. fluxfem/solver/preconditioner.py +106 -0
  37. fluxfem/solver/result.py +18 -0
  38. fluxfem/solver/solve_runner.py +168 -1
  39. fluxfem/solver/solver.py +12 -1
  40. fluxfem/solver/sparse.py +124 -9
  41. fluxfem-0.2.0.dist-info/METADATA +303 -0
  42. fluxfem-0.2.0.dist-info/RECORD +59 -0
  43. fluxfem-0.1.3.dist-info/METADATA +0 -125
  44. fluxfem-0.1.3.dist-info/RECORD +0 -47
  45. {fluxfem-0.1.3.dist-info → fluxfem-0.2.0.dist-info}/LICENSE +0 -0
  46. {fluxfem-0.1.3.dist-info → fluxfem-0.2.0.dist-info}/WHEEL +0 -0
fluxfem/core/assembly.py CHANGED
@@ -1,10 +1,11 @@
1
1
  from __future__ import annotations
2
- from typing import Callable, Protocol, TypeVar, Optional
2
+ from typing import Any, Callable, Literal, Optional, Protocol, TYPE_CHECKING, TypeAlias, TypeVar, Union
3
3
  import numpy as np
4
4
  import jax
5
5
  import jax.numpy as jnp
6
6
 
7
7
  from ..mesh import HexMesh, StructuredHexBox
8
+ from .dtypes import INDEX_DTYPE
8
9
  from .forms import FormContext
9
10
  from .space import FESpaceBase
10
11
 
@@ -16,6 +17,341 @@ Kernel = Callable[[FormContext, P], Array]
16
17
  ResidualForm = Callable[[FormContext, Array, P], Array]
17
18
  ElementDofMapper = Callable[[Array], Array]
18
19
 
20
+ if TYPE_CHECKING:
21
+ from ..solver import FluxSparseMatrix, SparsityPattern
22
+ else:
23
+ FluxSparseMatrix = Any
24
+ SparsityPattern = Any
25
+
26
+ SparseCOO: TypeAlias = tuple[Array, Array, Array, int]
27
+ LinearCOO: TypeAlias = tuple[Array, Array, int]
28
+ JacobianReturn: TypeAlias = Union[Array, FluxSparseMatrix, SparseCOO]
29
+ BilinearReturn: TypeAlias = Union[Array, FluxSparseMatrix, SparseCOO]
30
+ LinearReturn: TypeAlias = Union[Array, LinearCOO]
31
+ MassReturn: TypeAlias = Union[FluxSparseMatrix, Array]
32
+
33
+
34
+ class ElementBilinearKernel(Protocol):
35
+ def __call__(self, ctx: FormContext) -> Array: ...
36
+
37
+
38
+ class ElementLinearKernel(Protocol):
39
+ def __call__(self, ctx: FormContext) -> Array: ...
40
+
41
+
42
+ class ElementResidualKernel(Protocol):
43
+ def __call__(self, ctx: FormContext, u_elem: Array) -> Array: ...
44
+
45
+
46
+ class ElementJacobianKernel(Protocol):
47
+ def __call__(self, u_elem: Array, ctx: FormContext) -> Array: ...
48
+
49
+ ElementKernel: TypeAlias = (
50
+ ElementBilinearKernel
51
+ | ElementLinearKernel
52
+ | ElementResidualKernel
53
+ | ElementJacobianKernel
54
+ )
55
+ def _get_pattern(space: SpaceLike, *, with_idx: bool) -> SparsityPattern | None:
56
+ if hasattr(space, "get_sparsity_pattern"):
57
+ return space.get_sparsity_pattern(with_idx=with_idx)
58
+ return None
59
+
60
+
61
+ def _get_elem_rows(space: SpaceLike) -> Array:
62
+ if hasattr(space, "get_elem_rows"):
63
+ return space.get_elem_rows()
64
+ return space.elem_dofs.reshape(-1)
65
+
66
+
67
+ def chunk_pad_stats(n_elems: int, n_chunks: Optional[int]) -> dict[str, int | float | None]:
68
+ """
69
+ Compute padding overhead for chunked assembly.
70
+ Returns dict with chunk_size, pad, n_pad, and pad_ratio.
71
+ """
72
+ n_elems = int(n_elems)
73
+ if n_chunks is None or n_elems <= 0:
74
+ return {"chunk_size": None, "pad": 0, "n_pad": n_elems, "pad_ratio": 0.0}
75
+ n_chunks = min(int(n_chunks), n_elems)
76
+ chunk_size = (n_elems + n_chunks - 1) // n_chunks
77
+ pad = (-n_elems) % chunk_size
78
+ n_pad = n_elems + pad
79
+ pad_ratio = float(pad) / float(n_elems) if n_elems else 0.0
80
+ return {"chunk_size": int(chunk_size), "pad": int(pad), "n_pad": int(n_pad), "pad_ratio": pad_ratio}
81
+
82
+
83
+ def _maybe_trace_pad(
84
+ stats: dict[str, int | float | None], *, n_chunks: Optional[int], pad_trace: bool
85
+ ) -> None:
86
+ if not pad_trace or not jax.core.trace_ctx.is_top_level():
87
+ return
88
+ if n_chunks is None:
89
+ return
90
+ print(
91
+ "[pad]",
92
+ f"n_chunks={int(n_chunks)}",
93
+ f"chunk_size={stats['chunk_size']}",
94
+ f"pad={stats['pad']}",
95
+ f"pad_ratio={stats['pad_ratio']:.4f}",
96
+ flush=True,
97
+ )
98
+
99
+
100
+ class BatchedAssembler:
101
+ """
102
+ Assemble on a fixed space with optional masking to keep shapes static.
103
+
104
+ Use `mask` to zero padded elements while keeping input shapes fixed.
105
+ """
106
+
107
+ def __init__(
108
+ self,
109
+ space: SpaceLike,
110
+ elem_data: Any,
111
+ elem_dofs: Array,
112
+ *,
113
+ pattern: SparsityPattern | None = None,
114
+ ) -> None:
115
+ self.space = space
116
+ self.elem_data = elem_data
117
+ self.elem_dofs = elem_dofs
118
+ self.n_elems = int(elem_dofs.shape[0])
119
+ self.n_ldofs = int(space.n_ldofs)
120
+ self.n_dofs = int(space.n_dofs)
121
+ self.pattern = pattern
122
+ self._rows = None
123
+ self._cols = None
124
+
125
+ @classmethod
126
+ def from_space(
127
+ cls,
128
+ space: SpaceLike,
129
+ *,
130
+ dep: jnp.ndarray | None = None,
131
+ pattern: SparsityPattern | None = None,
132
+ ) -> "BatchedAssembler":
133
+ elem_data = space.build_form_contexts(dep=dep)
134
+ return cls(space, elem_data, space.elem_dofs, pattern=pattern)
135
+
136
+ def make_mask(self, n_active: int) -> Array:
137
+ n_active = max(0, min(int(n_active), self.n_elems))
138
+ mask = np.zeros((self.n_elems,), dtype=float)
139
+ if n_active:
140
+ mask[:n_active] = 1.0
141
+ return jnp.asarray(mask)
142
+
143
+ def slice(self, n_active: int) -> "BatchedAssembler":
144
+ n_active = max(0, min(int(n_active), self.n_elems))
145
+ elem_data = jax.tree_util.tree_map(lambda x: x[:n_active], self.elem_data)
146
+ elem_dofs = self.elem_dofs[:n_active]
147
+ return BatchedAssembler(self.space, elem_data, elem_dofs, pattern=None)
148
+
149
+ def _rows_cols(self) -> tuple[Array, Array]:
150
+ if self.pattern is not None:
151
+ return self.pattern.rows, self.pattern.cols
152
+ if self._rows is None or self._cols is None:
153
+ elem_dofs = self.elem_dofs
154
+ n_ldofs = int(elem_dofs.shape[1])
155
+ rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
156
+ cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
157
+ self._rows = rows
158
+ self._cols = cols
159
+ return self._rows, self._cols
160
+
161
+ def assemble_bilinear_with_kernel(
162
+ self, kernel: ElementBilinearKernel, *, mask: Array | None = None
163
+ ) -> FluxSparseMatrix:
164
+ """
165
+ kernel(ctx) -> (n_ldofs, n_ldofs)
166
+ """
167
+ from ..solver import FluxSparseMatrix
168
+
169
+ Ke = jax.vmap(kernel)(self.elem_data)
170
+ if mask is not None:
171
+ Ke = Ke * jnp.asarray(mask)[:, None, None]
172
+ data = Ke.reshape(-1)
173
+ if self.pattern is not None:
174
+ return FluxSparseMatrix(self.pattern, data)
175
+ rows, cols = self._rows_cols()
176
+ return FluxSparseMatrix(rows, cols, data, n_dofs=self.n_dofs)
177
+
178
+ def assemble_bilinear(
179
+ self,
180
+ form: Kernel[P],
181
+ params: P,
182
+ *,
183
+ mask: Array | None = None,
184
+ kernel: ElementBilinearKernel | None = None,
185
+ jit: bool = True,
186
+ ) -> FluxSparseMatrix:
187
+ if kernel is None:
188
+ kernel = make_element_bilinear_kernel(form, params, jit=jit)
189
+ return self.assemble_bilinear_with_kernel(kernel, mask=mask)
190
+
191
+ def assemble_linear_with_kernel(
192
+ self,
193
+ kernel: ElementLinearKernel,
194
+ *,
195
+ mask: Array | None = None,
196
+ dep: jnp.ndarray | None = None,
197
+ ) -> Array:
198
+ """
199
+ kernel(ctx) -> (n_ldofs,)
200
+ """
201
+ elem_data = self.elem_data if dep is None else self.space.build_form_contexts(dep=dep)
202
+ Fe = jax.vmap(kernel)(elem_data)
203
+ if mask is not None:
204
+ Fe = Fe * jnp.asarray(mask)[:, None]
205
+ rows = self.elem_dofs.reshape(-1)
206
+ data = Fe.reshape(-1)
207
+ return jax.ops.segment_sum(data, rows, self.n_dofs)
208
+
209
+ def assemble_linear(
210
+ self,
211
+ form: Kernel[P],
212
+ params: P,
213
+ *,
214
+ mask: Array | None = None,
215
+ dep: jnp.ndarray | None = None,
216
+ kernel: ElementLinearKernel | None = None,
217
+ ) -> Array:
218
+ if kernel is not None:
219
+ return self.assemble_linear_with_kernel(kernel, mask=mask, dep=dep)
220
+ elem_data = self.elem_data if dep is None else self.space.build_form_contexts(dep=dep)
221
+ includes_measure = getattr(form, "_includes_measure", False)
222
+
223
+ def per_element(ctx: FormContext):
224
+ integrand = form(ctx, params)
225
+ if includes_measure:
226
+ return integrand.sum(axis=0)
227
+ wJ = ctx.w * ctx.test.detJ
228
+ return (integrand * wJ[:, None]).sum(axis=0)
229
+
230
+ Fe = jax.vmap(per_element)(elem_data)
231
+ if mask is not None:
232
+ Fe = Fe * jnp.asarray(mask)[:, None]
233
+ rows = self.elem_dofs.reshape(-1)
234
+ data = Fe.reshape(-1)
235
+ return jax.ops.segment_sum(data, rows, self.n_dofs)
236
+
237
+ def assemble_mass_matrix(
238
+ self, *, mask: Array | None = None, lumped: bool = False
239
+ ) -> MassReturn:
240
+ from ..solver import FluxSparseMatrix
241
+
242
+ n_ldofs = self.n_ldofs
243
+
244
+ def per_element(ctx: FormContext):
245
+ N = ctx.test.N
246
+ base = jnp.einsum("qa,qb->qab", N, N)
247
+ if hasattr(ctx.test, "value_dim"):
248
+ vd = int(ctx.test.value_dim)
249
+ I = jnp.eye(vd, dtype=N.dtype)
250
+ base = base[:, :, :, None, None] * I[None, None, None, :, :]
251
+ base = base.reshape(base.shape[0], n_ldofs, n_ldofs)
252
+ wJ = ctx.w * ctx.test.detJ
253
+ return jnp.einsum("qab,q->ab", base, wJ)
254
+
255
+ Me = jax.vmap(per_element)(self.elem_data)
256
+ if mask is not None:
257
+ Me = Me * jnp.asarray(mask)[:, None, None]
258
+ data = Me.reshape(-1)
259
+ rows, cols = self._rows_cols()
260
+
261
+ if lumped:
262
+ M = jnp.zeros((self.n_dofs,), dtype=data.dtype)
263
+ M = M.at[rows].add(data)
264
+ return M
265
+
266
+ return FluxSparseMatrix(rows, cols, data, n_dofs=self.n_dofs)
267
+
268
+ def assemble_residual_with_kernel(
269
+ self, kernel: ElementResidualKernel, u: Array, *, mask: Array | None = None
270
+ ) -> Array:
271
+ """
272
+ kernel(ctx, u_elem) -> (n_ldofs,)
273
+ """
274
+ u_elems = jnp.asarray(u)[self.elem_dofs]
275
+ elem_res = jax.vmap(kernel)(self.elem_data, u_elems)
276
+ if mask is not None:
277
+ elem_res = elem_res * jnp.asarray(mask)[:, None]
278
+ rows = self.elem_dofs.reshape(-1)
279
+ data = elem_res.reshape(-1)
280
+ return jax.ops.segment_sum(data, rows, self.n_dofs)
281
+
282
+ def assemble_residual(
283
+ self,
284
+ res_form: ResidualForm[P],
285
+ u: Array,
286
+ params: P,
287
+ *,
288
+ mask: Array | None = None,
289
+ kernel: ElementResidualKernel | None = None,
290
+ ) -> Array:
291
+ if kernel is None:
292
+ kernel = make_element_residual_kernel(res_form, params)
293
+ return self.assemble_residual_with_kernel(kernel, u, mask=mask)
294
+
295
+ def assemble_jacobian_with_kernel(
296
+ self,
297
+ kernel: ElementJacobianKernel,
298
+ u: Array,
299
+ *,
300
+ mask: Array | None = None,
301
+ sparse: bool = True,
302
+ return_flux_matrix: bool = False,
303
+ ) -> JacobianReturn:
304
+ """
305
+ kernel(u_elem, ctx) -> (n_ldofs, n_ldofs)
306
+ """
307
+ from ..solver import FluxSparseMatrix # local import to avoid circular
308
+
309
+ u_elems = jnp.asarray(u)[self.elem_dofs]
310
+ J_e = jax.vmap(kernel)(u_elems, self.elem_data)
311
+ if mask is not None:
312
+ J_e = J_e * jnp.asarray(mask)[:, None, None]
313
+ data = J_e.reshape(-1)
314
+ if sparse:
315
+ if self.pattern is not None:
316
+ if return_flux_matrix:
317
+ return FluxSparseMatrix(self.pattern, data)
318
+ return self.pattern.rows, self.pattern.cols, data, self.n_dofs
319
+ rows, cols = self._rows_cols()
320
+ if return_flux_matrix:
321
+ return FluxSparseMatrix(rows, cols, data, n_dofs=self.n_dofs)
322
+ return rows, cols, data, self.n_dofs
323
+ rows, cols = self._rows_cols()
324
+ idx = (rows.astype(jnp.int64) * int(self.n_dofs) + cols.astype(jnp.int64)).astype(INDEX_DTYPE)
325
+ n_entries = self.n_dofs * self.n_dofs
326
+ sdn = jax.lax.ScatterDimensionNumbers(
327
+ update_window_dims=(),
328
+ inserted_window_dims=(0,),
329
+ scatter_dims_to_operand_dims=(0,),
330
+ )
331
+ K_flat = jnp.zeros(n_entries, dtype=data.dtype)
332
+ K_flat = jax.lax.scatter_add(K_flat, idx[:, None], data, sdn)
333
+ return K_flat.reshape(self.n_dofs, self.n_dofs)
334
+
335
+ def assemble_jacobian(
336
+ self,
337
+ res_form: ResidualForm[P],
338
+ u: Array,
339
+ params: P,
340
+ *,
341
+ mask: Array | None = None,
342
+ kernel: ElementJacobianKernel | None = None,
343
+ sparse: bool = True,
344
+ return_flux_matrix: bool = False,
345
+ ) -> JacobianReturn:
346
+ if kernel is None:
347
+ kernel = make_element_jacobian_kernel(res_form, params)
348
+ return self.assemble_jacobian_with_kernel(
349
+ kernel,
350
+ u,
351
+ mask=mask,
352
+ sparse=sparse,
353
+ return_flux_matrix=return_flux_matrix,
354
+ )
19
355
 
20
356
  class SpaceLike(FESpaceBase, Protocol):
21
357
  pass
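For orientation, a minimal usage sketch of the BatchedAssembler added above; `space`, `stiffness_form`, and `params` are placeholders for a user-defined FE space, bilinear form, and parameter pytree, not part of the package:

    from fluxfem.core.assembly import BatchedAssembler

    # Build element contexts once for a fixed space (placeholder `space`).
    assembler = BatchedAssembler.from_space(space)

    # Mask of active elements (all ones here); padded elements would get zeros,
    # which keeps input shapes static under jit as the class docstring notes.
    mask = assembler.make_mask(n_active=space.elem_dofs.shape[0])

    K = assembler.assemble_bilinear(stiffness_form, params, mask=mask)  # FluxSparseMatrix
    M = assembler.assemble_mass_matrix(lumped=True)                     # (n_dofs,) lumped mass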
@@ -28,7 +364,7 @@ def assemble_bilinear_dense(
28
364
  *,
29
365
  sparse: bool = False,
30
366
  return_flux_matrix: bool = False,
31
- ):
367
+ ) -> BilinearReturn:
32
368
  """
33
369
  Similar to scikit-fem's asm(biform, basis).
34
370
  kernel: FormContext, params -> (n_ldofs, n_ldofs)
@@ -47,11 +383,15 @@ def assemble_bilinear_dense(
47
383
 
48
384
  # ---- scatter into COO format ----
49
385
  # row/col indices (n_elems, n_ldofs, n_ldofs)
50
- rows = jnp.repeat(elem_dofs, n_ldofs, axis=1) # (n_elems, n_ldofs*n_ldofs)
51
- cols = jnp.tile(elem_dofs, (1, n_ldofs)) # (n_elems, n_ldofs*n_ldofs)
52
-
53
- rows = rows.reshape(-1)
54
- cols = cols.reshape(-1)
386
+ pat = _get_pattern(space, with_idx=False)
387
+ if pat is None:
388
+ rows = jnp.repeat(elem_dofs, n_ldofs, axis=1) # (n_elems, n_ldofs*n_ldofs)
389
+ cols = jnp.tile(elem_dofs, (1, n_ldofs)) # (n_elems, n_ldofs*n_ldofs)
390
+ rows = rows.reshape(-1)
391
+ cols = cols.reshape(-1)
392
+ else:
393
+ rows = pat.rows
394
+ cols = pat.cols
55
395
  data = K_e_all.reshape(-1)
56
396
 
57
397
  # Flatten indices for segment_sum via (row * n_dofs + col)
@@ -71,18 +411,22 @@ def assemble_bilinear_dense(
71
411
 
72
412
 
73
413
  def assemble_bilinear_form(
74
- space,
75
- form,
76
- params,
414
+ space: SpaceLike,
415
+ form: Kernel[P],
416
+ params: P,
77
417
  *,
78
- pattern=None,
79
- chunk_size: Optional[int] = None, # None -> no-chunk (old behavior)
418
+ pattern: SparsityPattern | None = None,
419
+ n_chunks: Optional[int] = None, # None -> no chunking
80
420
  dep: jnp.ndarray | None = None,
81
- ):
421
+ kernel: ElementBilinearKernel | None = None,
422
+ jit: bool = True,
423
+ pad_trace: bool = False,
424
+ ) -> FluxSparseMatrix:
82
425
  """
83
426
  Assemble a sparse bilinear form into a FluxSparseMatrix.
84
427
 
85
428
  Expects form(ctx, params) -> (n_q, n_ldofs, n_ldofs).
429
+ If kernel is provided: kernel(ctx) -> (n_ldofs, n_ldofs).
86
430
  """
87
431
  from ..solver import FluxSparseMatrix
88
432
 
@@ -104,18 +448,27 @@ def assemble_bilinear_form(
104
448
  wJ = ctx.w * ctx.test.detJ # (n_q,)
105
449
  return (integrand * wJ[:, None, None]).sum(axis=0) # (m, m)
106
450
 
451
+ if kernel is None:
452
+ kernel = make_element_bilinear_kernel(form, params, jit=jit)
453
+
107
454
  # --- no-chunk path (your current implementation) ---
108
- if chunk_size is None:
109
- K_e_all = jax.vmap(per_element)(elem_data) # (n_elems, m, m)
455
+ if n_chunks is None:
456
+ K_e_all = jax.vmap(kernel)(elem_data) # (n_elems, m, m)
110
457
  data = K_e_all.reshape(-1)
111
458
  return FluxSparseMatrix(pat, data)
112
459
 
113
460
  # --- chunked path ---
114
461
  n_elems = space.elem_dofs.shape[0]
462
+ if n_chunks <= 0:
463
+ raise ValueError("n_chunks must be a positive integer.")
464
+ n_chunks = min(int(n_chunks), int(n_elems))
465
+ chunk_size = (n_elems + n_chunks - 1) // n_chunks
466
+ stats = chunk_pad_stats(n_elems, n_chunks)
467
+ _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
115
468
  # Ideally get m from pat (otherwise infer from one element).
116
469
  m = getattr(pat, "n_ldofs", None)
117
470
  if m is None:
118
- m = per_element(jax.tree_util.tree_map(lambda x: x[0], elem_data)).shape[0]
471
+ m = kernel(jax.tree_util.tree_map(lambda x: x[0], elem_data)).shape[0]
119
472
 
120
473
  # Pad to fixed-size chunks for JIT stability.
121
474
  pad = (-n_elems) % chunk_size
@@ -141,7 +494,7 @@ def assemble_bilinear_form(
141
494
  lambda x: _slice_first_dim(x, start, chunk_size),
142
495
  elem_data_pad,
143
496
  )
144
- Ke = jax.vmap(per_element)(ctx_chunk) # (chunk, m, m)
497
+ Ke = jax.vmap(kernel)(ctx_chunk) # (chunk, m, m)
145
498
  return Ke.reshape(-1) # (chunk*m*m,)
146
499
 
147
500
  data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
@@ -149,7 +502,13 @@ def assemble_bilinear_form(
149
502
  return FluxSparseMatrix(pat, data)
150
503
 
151
504
 
152
- def assemble_mass_matrix(space: SpaceLike, *, lumped: bool = False, chunk_size: Optional[int] = None):
505
+ def assemble_mass_matrix(
506
+ space: SpaceLike,
507
+ *,
508
+ lumped: bool = False,
509
+ n_chunks: Optional[int] = None,
510
+ pad_trace: bool = False,
511
+ ) -> MassReturn:
153
512
  """
154
513
  Assemble mass matrix M_ij = ∫ N_i N_j dΩ.
155
514
  Supports scalar and vector spaces. If lumped=True, rows are summed to diagonal.
@@ -170,11 +529,17 @@ def assemble_mass_matrix(space: SpaceLike, *, lumped: bool = False, chunk_size:
170
529
  wJ = ctx.w * ctx.test.detJ
171
530
  return jnp.einsum("qab,q->ab", base, wJ)
172
531
 
173
- if chunk_size is None:
532
+ if n_chunks is None:
174
533
  M_e_all = jax.vmap(per_element)(ctxs) # (n_elems, n_ldofs, n_ldofs)
175
534
  data = M_e_all.reshape(-1)
176
535
  else:
177
536
  n_elems = space.elem_dofs.shape[0]
537
+ if n_chunks <= 0:
538
+ raise ValueError("n_chunks must be a positive integer.")
539
+ n_chunks = min(int(n_chunks), int(n_elems))
540
+ chunk_size = (n_elems + n_chunks - 1) // n_chunks
541
+ stats = chunk_pad_stats(n_elems, n_chunks)
542
+ _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
178
543
  pad = (-n_elems) % chunk_size
179
544
  if pad:
180
545
  ctxs_pad = jax.tree_util.tree_map(
@@ -205,8 +570,13 @@ def assemble_mass_matrix(space: SpaceLike, *, lumped: bool = False, chunk_size:
205
570
  data = data_chunks.reshape(-1)[: n_elems * n_ldofs * n_ldofs]
206
571
 
207
572
  elem_dofs = space.elem_dofs
208
- rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
209
- cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
573
+ pat = _get_pattern(space, with_idx=False)
574
+ if pat is None:
575
+ rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
576
+ cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
577
+ else:
578
+ rows = pat.rows
579
+ cols = pat.cols
210
580
 
211
581
  if lumped:
212
582
  n_dofs = space.n_dofs
@@ -222,12 +592,15 @@ def assemble_linear_form(
222
592
  form: Kernel[P],
223
593
  params: P,
224
594
  *,
595
+ kernel: ElementLinearKernel | None = None,
225
596
  sparse: bool = False,
226
- chunk_size: Optional[int] = None,
597
+ n_chunks: Optional[int] = None,
227
598
  dep: jnp.ndarray | None = None,
228
- ) -> jnp.ndarray:
599
+ pad_trace: bool = False,
600
+ ) -> LinearReturn:
229
601
  """
230
602
  Expects form(ctx, params) -> (n_q, n_ldofs) and integrates Σ_q form * wJ for RHS.
603
+ If kernel is provided: kernel(ctx) -> (n_ldofs,).
231
604
  """
232
605
  elem_dofs = space.elem_dofs
233
606
  n_dofs = space.n_dofs
@@ -237,19 +610,28 @@ def assemble_linear_form(
237
610
 
238
611
  includes_measure = getattr(form, "_includes_measure", False)
239
612
 
240
- def per_element(ctx: FormContext):
241
- integrand = form(ctx, params) # (n_q, m)
242
- if includes_measure:
243
- return integrand.sum(axis=0)
244
- wJ = ctx.w * ctx.test.detJ # (n_q,)
245
- return (integrand * wJ[:, None]).sum(axis=0) # (m,)
613
+ if kernel is None:
614
+ def per_element(ctx: FormContext):
615
+ integrand = form(ctx, params) # (n_q, m)
616
+ if includes_measure:
617
+ return integrand.sum(axis=0)
618
+ wJ = ctx.w * ctx.test.detJ # (n_q,)
619
+ return (integrand * wJ[:, None]).sum(axis=0) # (m,)
620
+ else:
621
+ per_element = kernel
246
622
 
247
- if chunk_size is None:
623
+ if n_chunks is None:
248
624
  F_e_all = jax.vmap(per_element)(elem_data) # (n_elems, m)
249
625
  data = F_e_all.reshape(-1)
250
626
  else:
251
627
  n_elems = space.elem_dofs.shape[0]
252
628
  m = n_ldofs
629
+ if n_chunks <= 0:
630
+ raise ValueError("n_chunks must be a positive integer.")
631
+ n_chunks = min(int(n_chunks), int(n_elems))
632
+ chunk_size = (n_elems + n_chunks - 1) // n_chunks
633
+ stats = chunk_pad_stats(n_elems, n_chunks)
634
+ _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
253
635
  pad = (-n_elems) % chunk_size
254
636
  if pad:
255
637
  elem_data_pad = jax.tree_util.tree_map(
@@ -279,7 +661,7 @@ def assemble_linear_form(
279
661
  data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
280
662
  data = data_chunks.reshape(-1)[: n_elems * m]
281
663
 
282
- rows = elem_dofs.reshape(-1)
664
+ rows = _get_elem_rows(space)
283
665
 
284
666
  if sparse:
285
667
  return rows, data, n_dofs
@@ -318,7 +700,7 @@ def assemble_jacobian_global(
318
700
  *,
319
701
  sparse: bool = False,
320
702
  return_flux_matrix: bool = False,
321
- ):
703
+ ) -> JacobianReturn:
322
704
  """
323
705
  Assemble Jacobian (dR/du) from element residual res_form.
324
706
  res_form(ctx, u_elem, params) -> (n_q, n_ldofs)
@@ -339,11 +721,16 @@ def assemble_jacobian_global(
339
721
  jac_fun = jax.jacrev(fe_fun, argnums=0)
340
722
 
341
723
  u_elems = u[elem_dofs] # (n_elems, n_ldofs)
342
- elem_ids = jnp.arange(elem_dofs.shape[0], dtype=jnp.int32)
724
+ elem_ids = jnp.arange(elem_dofs.shape[0], dtype=INDEX_DTYPE)
343
725
  J_e_all = jax.vmap(jac_fun)(u_elems, elem_data, elem_ids) # (n_elems, m, m)
344
726
 
345
- rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
346
- cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
727
+ pat = _get_pattern(space, with_idx=False)
728
+ if pat is None:
729
+ rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
730
+ cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
731
+ else:
732
+ rows = pat.rows
733
+ cols = pat.cols
347
734
  data = J_e_all.reshape(-1)
348
735
 
349
736
  if sparse:
@@ -358,7 +745,7 @@ def assemble_jacobian_global(
358
745
  return K_flat.reshape(n_dofs, n_dofs)
359
746
 
360
747
 
361
- def assemble_jacobian_elementwise_xla(
748
+ def assemble_jacobian_elementwise(
362
749
  space: SpaceLike,
363
750
  res_form: ResidualForm[P],
364
751
  u: jnp.ndarray,
@@ -366,9 +753,9 @@ def assemble_jacobian_elementwise_xla(
366
753
  *,
367
754
  sparse: bool = False,
368
755
  return_flux_matrix: bool = False,
369
- ):
756
+ ) -> JacobianReturn:
370
757
  """
371
- Assemble Jacobian with element kernels in XLA (vmap + scatter_add).
758
+ Assemble Jacobian with element kernels via vmap + scatter_add.
372
759
  Recompiles if n_dofs changes, but independent of element count.
373
760
  """
374
761
  from ..solver import FluxSparseMatrix # local import to avoid circular
@@ -388,8 +775,13 @@ def assemble_jacobian_elementwise_xla(
388
775
  u_elems = u[elem_dofs]
389
776
  J_e_all = jax.vmap(jac_fun)(u_elems, ctxs) # (n_elems, m, m)
390
777
 
391
- rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
392
- cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
778
+ pat = _get_pattern(space, with_idx=False)
779
+ if pat is None:
780
+ rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
781
+ cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
782
+ else:
783
+ rows = pat.rows
784
+ cols = pat.cols
393
785
  data = J_e_all.reshape(-1)
394
786
 
395
787
  if sparse:
@@ -416,7 +808,7 @@ def assemble_residual_global(
416
808
  params: P,
417
809
  *,
418
810
  sparse: bool = False
419
- ):
811
+ ) -> LinearReturn:
420
812
  """
421
813
  Assemble residual vector that depends on u.
422
814
  form(ctx, u_elem, params) -> (n_q, n_ldofs)
@@ -435,10 +827,10 @@ def assemble_residual_global(
435
827
  fe = (integrand * wJ[:, None]).sum(axis=0)
436
828
  return fe
437
829
 
438
- elem_ids = jnp.arange(elem_dofs.shape[0], dtype=jnp.int32)
830
+ elem_ids = jnp.arange(elem_dofs.shape[0], dtype=INDEX_DTYPE)
439
831
  F_e_all = jax.vmap(per_element)(elem_data, elem_dofs, elem_ids) # (n_elems, m)
440
832
 
441
- rows = elem_dofs.reshape(-1)
833
+ rows = _get_elem_rows(space)
442
834
  data = F_e_all.reshape(-1)
443
835
 
444
836
  if sparse:
@@ -448,16 +840,16 @@ def assemble_residual_global(
448
840
  return F
449
841
 
450
842
 
451
- def assemble_residual_elementwise_xla(
843
+ def assemble_residual_elementwise(
452
844
  space: SpaceLike,
453
845
  res_form: ResidualForm[P],
454
846
  u: jnp.ndarray,
455
847
  params: P,
456
848
  *,
457
849
  sparse: bool = False,
458
- ):
850
+ ) -> LinearReturn:
459
851
  """
460
- Assemble residual using element kernels fully in XLA (vmap + scatter_add).
852
+ Assemble residual using element kernels via vmap + scatter_add.
461
853
  Recompiles if n_dofs changes, but independent of element count.
462
854
  """
463
855
  elem_dofs = space.elem_dofs
@@ -471,7 +863,7 @@ def assemble_residual_elementwise_xla(
471
863
 
472
864
  u_elems = u[elem_dofs]
473
865
  F_e_all = jax.vmap(per_element)(ctxs, u_elems) # (n_elems, m)
474
- rows = elem_dofs.reshape(-1)
866
+ rows = _get_elem_rows(space)
475
867
  data = F_e_all.reshape(-1)
476
868
 
477
869
  if sparse:
@@ -487,7 +879,44 @@ def assemble_residual_elementwise_xla(
487
879
  return F
488
880
 
489
881
 
490
- def make_element_residual_kernel(res_form: ResidualForm[P], params: P):
882
+ # Backward compatibility aliases (prefer assemble_*_elementwise).
883
+ assemble_jacobian_elementwise_xla = assemble_jacobian_elementwise
884
+ assemble_residual_elementwise_xla = assemble_residual_elementwise
885
+
886
+
887
+ def make_element_bilinear_kernel(
888
+ form: Kernel[P], params: P, *, jit: bool = True
889
+ ) -> ElementBilinearKernel:
890
+ """Element kernel: (ctx) -> Ke."""
891
+
892
+ def per_element(ctx: FormContext):
893
+ integrand = form(ctx, params)
894
+ if getattr(form, "_includes_measure", False):
895
+ return integrand.sum(axis=0)
896
+ wJ = ctx.w * ctx.test.detJ
897
+ return (integrand * wJ[:, None, None]).sum(axis=0)
898
+
899
+ return jax.jit(per_element) if jit else per_element
900
+
901
+
902
+ def make_element_linear_kernel(
903
+ form: Kernel[P], params: P, *, jit: bool = True
904
+ ) -> ElementLinearKernel:
905
+ """Element kernel: (ctx) -> fe."""
906
+
907
+ def per_element(ctx: FormContext):
908
+ integrand = form(ctx, params)
909
+ if getattr(form, "_includes_measure", False):
910
+ return integrand.sum(axis=0)
911
+ wJ = ctx.w * ctx.test.detJ
912
+ return (integrand * wJ[:, None]).sum(axis=0)
913
+
914
+ return jax.jit(per_element) if jit else per_element
915
+
916
+
917
+ def make_element_residual_kernel(
918
+ res_form: ResidualForm[P], params: P
919
+ ) -> ElementResidualKernel:
491
920
  """Jitted element residual kernel: (ctx, u_elem) -> fe."""
492
921
 
493
922
  def per_element(ctx: FormContext, u_elem: jnp.ndarray):
@@ -500,7 +929,9 @@ def make_element_residual_kernel(res_form: ResidualForm[P], params: P):
500
929
  return jax.jit(per_element)
501
930
 
502
931
 
503
- def make_element_jacobian_kernel(res_form: ResidualForm[P], params: P):
932
+ def make_element_jacobian_kernel(
933
+ res_form: ResidualForm[P], params: P
934
+ ) -> ElementJacobianKernel:
504
935
  """Jitted element Jacobian kernel: (ctx, u_elem) -> Ke."""
505
936
 
506
937
  def fe_fun(u_elem, ctx: FormContext):
@@ -513,7 +944,9 @@ def make_element_jacobian_kernel(res_form: ResidualForm[P], params: P):
513
944
  return jax.jit(jax.jacrev(fe_fun, argnums=0))
514
945
 
515
946
 
516
- def element_residual(res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.ndarray, params: P):
947
+ def element_residual(
948
+ res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.ndarray, params: P
949
+ ) -> Any:
517
950
  """
518
951
  Element residual vector r_e(u_e) = sum_q w_q * detJ_q * res_form(ctx, u_e, params).
519
952
  Returns shape (n_ldofs,).
@@ -538,7 +971,9 @@ def element_residual(res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.nd
538
971
  return jax.tree_util.tree_map(lambda x: jnp.einsum("qa,q->a", x, ctx.w * ctx.test.detJ), integrand)
539
972
 
540
973
 
541
- def element_jacobian(res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.ndarray, params: P):
974
+ def element_jacobian(
975
+ res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.ndarray, params: P
976
+ ) -> Any:
542
977
  """
543
978
  Element Jacobian K_e = d r_e / d u_e (AD via jacfwd), shape (n_ldofs, n_ldofs).
544
979
  """
@@ -548,7 +983,42 @@ def element_jacobian(res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.nd
548
983
  return jax.jacfwd(_r_elem)(u_elem)
549
984
 
550
985
 
551
- def make_sparsity_pattern(space: SpaceLike, *, with_idx: bool = True):
986
+ def make_element_kernel(
987
+ form: Kernel[P] | ResidualForm[P],
988
+ params: P,
989
+ *,
990
+ kind: Literal["bilinear", "linear", "residual", "jacobian"],
991
+ jit: bool = True,
992
+ ) -> ElementKernel:
993
+ """
994
+ Unified entry point for element kernels.
995
+
996
+ kind:
997
+ - "bilinear": kernel(ctx) -> (n_ldofs, n_ldofs)
998
+ - "linear": kernel(ctx) -> (n_ldofs,)
999
+ - "residual": kernel(ctx, u_elem) -> (n_ldofs,)
1000
+ - "jacobian": kernel(u_elem, ctx) -> (n_ldofs, n_ldofs)
1001
+ """
1002
+ kind = kind.lower()
1003
+ if kind == "bilinear":
1004
+ return make_element_bilinear_kernel(form, params, jit=jit)
1005
+ if kind == "linear":
1006
+ def per_element(ctx: FormContext):
1007
+ integrand = form(ctx, params)
1008
+ if getattr(form, "_includes_measure", False):
1009
+ return integrand.sum(axis=0)
1010
+ wJ = ctx.w * ctx.test.detJ
1011
+ return (integrand * wJ[:, None]).sum(axis=0)
1012
+
1013
+ return jax.jit(per_element) if jit else per_element
1014
+ if kind == "residual":
1015
+ return make_element_residual_kernel(form, params)
1016
+ if kind == "jacobian":
1017
+ return make_element_jacobian_kernel(form, params)
1018
+ raise ValueError(f"Unknown kernel kind: {kind}")
1019
+
1020
+
1021
+ def make_sparsity_pattern(space: SpaceLike, *, with_idx: bool = True) -> SparsityPattern:
552
1022
  """
553
1023
  Build a SparsityPattern (rows/cols[/idx]) that is independent of the solution.
554
1024
  NOTE: rows/cols ordering matches assemble_jacobian_values(...).reshape(-1)
@@ -557,24 +1027,24 @@ def make_sparsity_pattern(space: SpaceLike, *, with_idx: bool = True):
557
1027
  """
558
1028
  from ..solver import SparsityPattern # local import to avoid circular
559
1029
 
560
- elem_dofs = jnp.asarray(space.elem_dofs, dtype=jnp.int32)
1030
+ elem_dofs = jnp.asarray(space.elem_dofs, dtype=INDEX_DTYPE)
561
1031
  n_dofs = int(space.n_dofs)
562
1032
  n_ldofs = int(space.n_ldofs)
563
1033
 
564
- rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1).astype(jnp.int32)
565
- cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1).astype(jnp.int32)
1034
+ rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1).astype(INDEX_DTYPE)
1035
+ cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1).astype(INDEX_DTYPE)
566
1036
 
567
1037
  key = rows.astype(jnp.int64) * jnp.int64(n_dofs) + cols.astype(jnp.int64)
568
- order = jnp.argsort(key).astype(jnp.int32)
1038
+ order = jnp.argsort(key).astype(INDEX_DTYPE)
569
1039
  rows_sorted = rows[order]
570
1040
  cols_sorted = cols[order]
571
- counts = jnp.bincount(rows_sorted, length=n_dofs).astype(jnp.int32)
572
- indptr_j = jnp.concatenate([jnp.array([0], dtype=jnp.int32), jnp.cumsum(counts)])
573
- indices_j = cols_sorted.astype(jnp.int32)
1041
+ counts = jnp.bincount(rows_sorted, length=n_dofs).astype(INDEX_DTYPE)
1042
+ indptr_j = jnp.concatenate([jnp.array([0], dtype=INDEX_DTYPE), jnp.cumsum(counts)])
1043
+ indices_j = cols_sorted.astype(INDEX_DTYPE)
574
1044
  perm = order
575
1045
 
576
1046
  if with_idx:
577
- idx = (rows.astype(jnp.int64) * jnp.int64(n_dofs) + cols.astype(jnp.int64)).astype(jnp.int32)
1047
+ idx = (rows.astype(jnp.int64) * jnp.int64(n_dofs) + cols.astype(jnp.int64)).astype(INDEX_DTYPE)
578
1048
  return SparsityPattern(
579
1049
  rows=rows,
580
1050
  cols=cols,
@@ -601,8 +1071,10 @@ def assemble_jacobian_values(
601
1071
  u: jnp.ndarray,
602
1072
  params: P,
603
1073
  *,
604
- kernel=None,
605
- ):
1074
+ kernel: ElementJacobianKernel | None = None,
1075
+ n_chunks: Optional[int] = None,
1076
+ pad_trace: bool = False,
1077
+ ) -> Array:
606
1078
  """
607
1079
  Assemble only the numeric values for the Jacobian (pattern-free).
608
1080
  """
@@ -610,8 +1082,49 @@ def assemble_jacobian_values(
610
1082
  ker = kernel if kernel is not None else make_element_jacobian_kernel(res_form, params)
611
1083
 
612
1084
  u_elems = u[space.elem_dofs]
613
- J_e_all = jax.vmap(ker)(u_elems, ctxs) # (n_elem, m, m)
614
- return J_e_all.reshape(-1)
1085
+ if n_chunks is None:
1086
+ J_e_all = jax.vmap(ker)(u_elems, ctxs) # (n_elem, m, m)
1087
+ return J_e_all.reshape(-1)
1088
+
1089
+ n_elems = int(u_elems.shape[0])
1090
+ if n_chunks <= 0:
1091
+ raise ValueError("n_chunks must be a positive integer.")
1092
+ n_chunks = min(int(n_chunks), int(n_elems))
1093
+ chunk_size = (n_elems + n_chunks - 1) // n_chunks
1094
+ stats = chunk_pad_stats(n_elems, n_chunks)
1095
+ _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
1096
+ pad = (-n_elems) % chunk_size
1097
+ if pad:
1098
+ ctxs_pad = jax.tree_util.tree_map(
1099
+ lambda x: jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0),
1100
+ ctxs,
1101
+ )
1102
+ u_elems_pad = jnp.concatenate([u_elems, jnp.repeat(u_elems[-1:], pad, axis=0)], axis=0)
1103
+ else:
1104
+ ctxs_pad = ctxs
1105
+ u_elems_pad = u_elems
1106
+
1107
+ n_pad = n_elems + pad
1108
+ n_chunks = n_pad // chunk_size
1109
+ m = int(space.n_ldofs)
1110
+
1111
+ def _slice_first_dim(x, start, size):
1112
+ start_idx = (start,) + (0,) * (x.ndim - 1)
1113
+ slice_sizes = (size,) + x.shape[1:]
1114
+ return jax.lax.dynamic_slice(x, start_idx, slice_sizes)
1115
+
1116
+ def chunk_fn(i):
1117
+ start = i * chunk_size
1118
+ ctx_chunk = jax.tree_util.tree_map(
1119
+ lambda x: _slice_first_dim(x, start, chunk_size),
1120
+ ctxs_pad,
1121
+ )
1122
+ u_chunk = _slice_first_dim(u_elems_pad, start, chunk_size)
1123
+ J_e = jax.vmap(ker)(u_chunk, ctx_chunk)
1124
+ return J_e.reshape(-1)
1125
+
1126
+ data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
1127
+ return data_chunks.reshape(-1)[: n_elems * m * m]
615
1128
 
616
1129
 
617
1130
  def assemble_residual_scatter(
@@ -620,9 +1133,11 @@ def assemble_residual_scatter(
620
1133
  u: jnp.ndarray,
621
1134
  params: P,
622
1135
  *,
623
- kernel=None,
1136
+ kernel: ElementResidualKernel | None = None,
624
1137
  sparse: bool = False,
625
- ):
1138
+ n_chunks: Optional[int] = None,
1139
+ pad_trace: bool = False,
1140
+ ) -> LinearReturn:
626
1141
  """
627
1142
  Assemble residual using jitted element kernel + vmap + scatter_add.
628
1143
  Avoids Python loops; good for JIT stability.
@@ -633,20 +1148,62 @@ def assemble_residual_scatter(
633
1148
  """
634
1149
  elem_dofs = space.elem_dofs
635
1150
  n_dofs = space.n_dofs
636
- if np.max(elem_dofs) >= n_dofs:
637
- raise ValueError("elem_dofs contains index outside n_dofs")
638
- if np.min(elem_dofs) < 0:
639
- raise ValueError("elem_dofs contains negative index")
1151
+ if jax.core.trace_ctx.is_top_level():
1152
+ if np.max(elem_dofs) >= n_dofs:
1153
+ raise ValueError("elem_dofs contains index outside n_dofs")
1154
+ if np.min(elem_dofs) < 0:
1155
+ raise ValueError("elem_dofs contains negative index")
640
1156
  ctxs = space.build_form_contexts()
641
1157
  ker = kernel if kernel is not None else make_element_residual_kernel(res_form, params)
642
1158
 
643
1159
  u_elems = u[elem_dofs]
644
- elem_res = jax.vmap(ker)(ctxs, u_elems) # (n_elem, n_ldofs)
645
- if not bool(jax.block_until_ready(jnp.all(jnp.isfinite(elem_res)))):
646
- bad = int(jnp.count_nonzero(~jnp.isfinite(elem_res)))
647
- raise RuntimeError(f"[assemble_residual_scatter] elem_res nonfinite: {bad}")
1160
+ if n_chunks is None:
1161
+ elem_res = jax.vmap(ker)(ctxs, u_elems) # (n_elem, n_ldofs)
1162
+ else:
1163
+ n_elems = int(u_elems.shape[0])
1164
+ if n_chunks <= 0:
1165
+ raise ValueError("n_chunks must be a positive integer.")
1166
+ n_chunks = min(int(n_chunks), int(n_elems))
1167
+ chunk_size = (n_elems + n_chunks - 1) // n_chunks
1168
+ stats = chunk_pad_stats(n_elems, n_chunks)
1169
+ _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
1170
+ pad = (-n_elems) % chunk_size
1171
+ if pad:
1172
+ ctxs_pad = jax.tree_util.tree_map(
1173
+ lambda x: jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0),
1174
+ ctxs,
1175
+ )
1176
+ u_elems_pad = jnp.concatenate([u_elems, jnp.repeat(u_elems[-1:], pad, axis=0)], axis=0)
1177
+ else:
1178
+ ctxs_pad = ctxs
1179
+ u_elems_pad = u_elems
1180
+
1181
+ n_pad = n_elems + pad
1182
+ n_chunks = n_pad // chunk_size
1183
+
1184
+ def _slice_first_dim(x, start, size):
1185
+ start_idx = (start,) + (0,) * (x.ndim - 1)
1186
+ slice_sizes = (size,) + x.shape[1:]
1187
+ return jax.lax.dynamic_slice(x, start_idx, slice_sizes)
648
1188
 
649
- rows = elem_dofs.reshape(-1)
1189
+ def chunk_fn(i):
1190
+ start = i * chunk_size
1191
+ ctx_chunk = jax.tree_util.tree_map(
1192
+ lambda x: _slice_first_dim(x, start, chunk_size),
1193
+ ctxs_pad,
1194
+ )
1195
+ u_chunk = _slice_first_dim(u_elems_pad, start, chunk_size)
1196
+ res_chunk = jax.vmap(ker)(ctx_chunk, u_chunk)
1197
+ return res_chunk.reshape(-1)
1198
+
1199
+ data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
1200
+ elem_res = data_chunks.reshape(-1)[: n_elems * int(space.n_ldofs)].reshape(n_elems, -1)
1201
+ if jax.core.trace_ctx.is_top_level():
1202
+ if not bool(jax.block_until_ready(jnp.all(jnp.isfinite(elem_res)))):
1203
+ bad = int(jnp.count_nonzero(~jnp.isfinite(elem_res)))
1204
+ raise RuntimeError(f"[assemble_residual_scatter] elem_res nonfinite: {bad}")
1205
+
1206
+ rows = _get_elem_rows(space)
650
1207
  data = elem_res.reshape(-1)
651
1208
 
652
1209
  if sparse:
@@ -668,11 +1225,13 @@ def assemble_jacobian_scatter(
668
1225
  u: jnp.ndarray,
669
1226
  params: P,
670
1227
  *,
671
- kernel=None,
1228
+ kernel: ElementJacobianKernel | None = None,
672
1229
  sparse: bool = False,
673
1230
  return_flux_matrix: bool = False,
674
- pattern=None,
675
- ):
1231
+ pattern: SparsityPattern | None = None,
1232
+ n_chunks: Optional[int] = None,
1233
+ pad_trace: bool = False,
1234
+ ) -> JacobianReturn:
676
1235
  """
677
1236
  Assemble Jacobian using jitted element kernel + vmap + scatter_add.
678
1237
  If a SparsityPattern is provided, rows/cols are reused without regeneration.
@@ -682,7 +1241,9 @@ def assemble_jacobian_scatter(
682
1241
  from ..solver import FluxSparseMatrix # local import to avoid circular
683
1242
 
684
1243
  pat = pattern if pattern is not None else make_sparsity_pattern(space, with_idx=not sparse)
685
- data = assemble_jacobian_values(space, res_form, u, params, kernel=kernel)
1244
+ data = assemble_jacobian_values(
1245
+ space, res_form, u, params, kernel=kernel, n_chunks=n_chunks, pad_trace=pad_trace
1246
+ )
686
1247
 
687
1248
  if sparse:
688
1249
  if return_flux_matrix:
@@ -691,7 +1252,7 @@ def assemble_jacobian_scatter(
691
1252
 
692
1253
  idx = pat.idx
693
1254
  if idx is None:
694
- idx = (pat.rows.astype(jnp.int64) * int(pat.n_dofs) + pat.cols.astype(jnp.int64)).astype(jnp.int32)
1255
+ idx = (pat.rows.astype(jnp.int64) * int(pat.n_dofs) + pat.cols.astype(jnp.int64)).astype(INDEX_DTYPE)
695
1256
 
696
1257
  n_entries = pat.n_dofs * pat.n_dofs
697
1258
  sdn = jax.lax.ScatterDimensionNumbers(
@@ -708,12 +1269,21 @@ def assemble_jacobian_scatter(
708
1269
  def assemble_residual(
709
1270
  space: SpaceLike,
710
1271
  form: ResidualForm[P],
711
- u: jnp.ndarray, params: P,
1272
+ u: jnp.ndarray,
1273
+ params: P,
712
1274
  *,
713
- sparse: bool = False
714
- ):
715
- """Assemble the global residual vector (scatter-based)."""
716
- return assemble_residual_scatter(space, form, u, params, sparse=sparse)
1275
+ kernel: ElementResidualKernel | None = None,
1276
+ sparse: bool = False,
1277
+ n_chunks: Optional[int] = None,
1278
+ pad_trace: bool = False,
1279
+ ) -> LinearReturn:
1280
+ """
1281
+ Assemble the global residual vector (scatter-based).
1282
+ If kernel is provided: kernel(ctx, u_elem) -> (n_ldofs,).
1283
+ """
1284
+ return assemble_residual_scatter(
1285
+ space, form, u, params, kernel=kernel, sparse=sparse, n_chunks=n_chunks, pad_trace=pad_trace
1286
+ )
717
1287
 
718
1288
 
719
1289
  def assemble_jacobian(
@@ -722,19 +1292,28 @@ def assemble_jacobian(
722
1292
  u: jnp.ndarray,
723
1293
  params: P,
724
1294
  *,
1295
+ kernel: ElementJacobianKernel | None = None,
725
1296
  sparse: bool = True,
726
1297
  return_flux_matrix: bool = False,
727
- pattern=None,
728
- ):
729
- """Assemble the global Jacobian (scatter-based)."""
1298
+ pattern: SparsityPattern | None = None,
1299
+ n_chunks: Optional[int] = None,
1300
+ pad_trace: bool = False,
1301
+ ) -> JacobianReturn:
1302
+ """
1303
+ Assemble the global Jacobian (scatter-based).
1304
+ If kernel is provided: kernel(u_elem, ctx) -> (n_ldofs, n_ldofs).
1305
+ """
730
1306
  return assemble_jacobian_scatter(
731
1307
  space,
732
1308
  res_form,
733
1309
  u,
734
1310
  params,
1311
+ kernel=kernel,
735
1312
  sparse=sparse,
736
1313
  return_flux_matrix=return_flux_matrix,
737
1314
  pattern=pattern,
1315
+ n_chunks=n_chunks,
1316
+ pad_trace=pad_trace,
738
1317
  )
739
1318
 
740
1319
 
@@ -748,13 +1327,19 @@ def scalar_body_force_form(ctx: FormContext, load: float) -> jnp.ndarray:
748
1327
  return load * ctx.test.N # (n_q, n_ldofs)
749
1328
 
750
1329
 
751
- def make_scalar_body_force_form(body_force):
1330
+ scalar_body_force_form._ff_kind = "linear"
1331
+ scalar_body_force_form._ff_domain = "volume"
1332
+
1333
+
1334
+ def make_scalar_body_force_form(body_force: Callable[[Array], Array]) -> Kernel[Any]:
752
1335
  """
753
1336
  Build a scalar linear form from a callable f(x_q) -> (n_q,).
754
1337
  """
755
1338
  def _form(ctx: FormContext, _params):
756
1339
  f_q = body_force(ctx.x_q)
757
1340
  return f_q[..., None] * ctx.test.N
1341
+ _form._ff_kind = "linear"
1342
+ _form._ff_domain = "volume"
758
1343
  return _form
759
1344
 
760
1345
 
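A minimal sketch of driving the helper above, assuming a scalar FE space bound to a mesh; `space` is a placeholder:

    import jax.numpy as jnp
    from fluxfem.core.assembly import assemble_linear_form, make_scalar_body_force_form

    # Spatially varying source f(x) evaluated at the quadrature points x_q.
    source_form = make_scalar_body_force_form(lambda x_q: jnp.sin(x_q[..., 0]))

    # The generated form ignores its params argument, so None is enough here.
    F = assemble_linear_form(space, source_form, None)  # (n_dofs,) load vector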
@@ -762,7 +1347,7 @@ def make_scalar_body_force_form(body_force):
762
1347
  constant_body_force_form = scalar_body_force_form
763
1348
 
764
1349
 
765
- def _check_structured_box_connectivity():
1350
+ def _check_structured_box_connectivity() -> None:
766
1351
  """Quick connectivity check for nx=2, ny=1, nz=1 (non-structured order)."""
767
1352
  box = StructuredHexBox(nx=2, ny=1, nz=1, lx=2.0, ly=1.0, lz=1.0)
768
1353
  mesh = box.build()
@@ -775,7 +1360,7 @@ def _check_structured_box_connectivity():
775
1360
  [0, 1, 4, 3, 6, 7, 10, 9], # element at i=0
776
1361
  [1, 2, 5, 4, 7, 8, 11, 10], # element at i=1
777
1362
  ],
778
- dtype=jnp.int32,
1363
+ dtype=INDEX_DTYPE,
779
1364
  )
780
1365
  max_diff = int(jnp.max(jnp.abs(mesh.conn - expected_conn)))
781
1366
  print("StructuredHexBox nx=2,ny=1,nz=1 conn matches expected:", max_diff == 0)
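Taken together, the scatter-based entry points of this module compose as in the following sketch; every problem object (`space`, `res_form`, `u`, `params`) is a placeholder for user code:

    from fluxfem.core.assembly import (
        assemble_jacobian,
        assemble_residual,
        make_element_jacobian_kernel,
        make_element_residual_kernel,
        make_sparsity_pattern,
    )

    pattern = make_sparsity_pattern(space)                     # solution-independent rows/cols
    r_kernel = make_element_residual_kernel(res_form, params)  # (ctx, u_elem) -> (n_ldofs,)
    j_kernel = make_element_jacobian_kernel(res_form, params)  # (u_elem, ctx) -> (n_ldofs, n_ldofs)

    # n_chunks trades peak memory for a little padding overhead (see chunk_pad_stats).
    R = assemble_residual(space, res_form, u, params, kernel=r_kernel, n_chunks=8)
    K = assemble_jacobian(space, res_form, u, params, kernel=j_kernel,
                          pattern=pattern, return_flux_matrix=True, n_chunks=8)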