fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. fluxfem/__init__.py +69 -13
  2. fluxfem/core/__init__.py +140 -53
  3. fluxfem/core/assembly.py +691 -97
  4. fluxfem/core/basis.py +75 -54
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/dtypes.py +9 -1
  7. fluxfem/core/forms.py +10 -0
  8. fluxfem/core/mixed_assembly.py +263 -0
  9. fluxfem/core/mixed_space.py +382 -0
  10. fluxfem/core/mixed_weakform.py +97 -0
  11. fluxfem/core/solver.py +2 -0
  12. fluxfem/core/space.py +315 -30
  13. fluxfem/core/weakform.py +821 -42
  14. fluxfem/helpers_wf.py +49 -0
  15. fluxfem/mesh/__init__.py +54 -2
  16. fluxfem/mesh/base.py +318 -9
  17. fluxfem/mesh/contact.py +841 -0
  18. fluxfem/mesh/dtypes.py +12 -0
  19. fluxfem/mesh/hex.py +17 -16
  20. fluxfem/mesh/io.py +9 -6
  21. fluxfem/mesh/mortar.py +3970 -0
  22. fluxfem/mesh/supermesh.py +318 -0
  23. fluxfem/mesh/surface.py +104 -26
  24. fluxfem/mesh/tet.py +16 -7
  25. fluxfem/physics/diffusion.py +3 -0
  26. fluxfem/physics/elasticity/hyperelastic.py +35 -3
  27. fluxfem/physics/elasticity/linear.py +22 -4
  28. fluxfem/physics/elasticity/stress.py +9 -5
  29. fluxfem/physics/operators.py +12 -5
  30. fluxfem/physics/postprocess.py +29 -3
  31. fluxfem/solver/__init__.py +47 -2
  32. fluxfem/solver/bc.py +38 -2
  33. fluxfem/solver/block_matrix.py +284 -0
  34. fluxfem/solver/block_system.py +477 -0
  35. fluxfem/solver/cg.py +150 -55
  36. fluxfem/solver/dirichlet.py +358 -5
  37. fluxfem/solver/history.py +15 -3
  38. fluxfem/solver/newton.py +260 -70
  39. fluxfem/solver/petsc.py +445 -0
  40. fluxfem/solver/preconditioner.py +109 -0
  41. fluxfem/solver/result.py +18 -0
  42. fluxfem/solver/solve_runner.py +208 -23
  43. fluxfem/solver/solver.py +35 -12
  44. fluxfem/solver/sparse.py +149 -15
  45. fluxfem/tools/jit.py +19 -7
  46. fluxfem/tools/timer.py +14 -12
  47. fluxfem/tools/visualizer.py +16 -4
  48. fluxfem-0.2.1.dist-info/METADATA +314 -0
  49. fluxfem-0.2.1.dist-info/RECORD +59 -0
  50. fluxfem-0.1.4.dist-info/METADATA +0 -127
  51. fluxfem-0.1.4.dist-info/RECORD +0 -48
  52. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  53. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/core/assembly.py CHANGED
@@ -1,21 +1,362 @@
  from __future__ import annotations
- from typing import Callable, Protocol, TypeVar, Optional
+ from typing import Any, Callable, Literal, Mapping, Optional, Protocol, TYPE_CHECKING, TypeAlias, TypeVar, Union, cast
  import numpy as np
  import jax
  import jax.numpy as jnp

  from ..mesh import HexMesh, StructuredHexBox
+ from .dtypes import INDEX_DTYPE
  from .forms import FormContext
  from .space import FESpaceBase

  # Shared call signatures for kernels/forms
- Array = jnp.ndarray
+ Array: TypeAlias = jnp.ndarray
  P = TypeVar("P")

- Kernel = Callable[[FormContext, P], Array]
+ FormKernel: TypeAlias = Callable[[FormContext, P], Array]
+ # Form kernels return integrands; element kernels return integrated element arrays.
+ Kernel: TypeAlias = Callable[[FormContext, P], Array]
+ ResidualInput: TypeAlias = Array | Mapping[str, Array]
+ ResidualValue: TypeAlias = Array | Mapping[str, Array]
  ResidualForm = Callable[[FormContext, Array, P], Array]
+ ResidualFormLike = Callable[[FormContext, ResidualInput, P], ResidualValue]
  ElementDofMapper = Callable[[Array], Array]

+ if TYPE_CHECKING:
+     from ..solver import FluxSparseMatrix, SparsityPattern
+ else:
+     FluxSparseMatrix = Any
+     SparsityPattern = Any
+
+ SparseCOO: TypeAlias = tuple[Array, Array, Array, int]
+ LinearCOO: TypeAlias = tuple[Array, Array, int]
+ JacobianReturn: TypeAlias = Union[Array, FluxSparseMatrix, SparseCOO]
+ BilinearReturn: TypeAlias = Union[Array, FluxSparseMatrix, SparseCOO]
+ LinearReturn: TypeAlias = Union[Array, LinearCOO]
+ MassReturn: TypeAlias = Union[FluxSparseMatrix, Array]
+
+
+ class ElementBilinearKernel(Protocol):
+     def __call__(self, ctx: FormContext) -> Array: ...
+
+
+ class ElementLinearKernel(Protocol):
+     def __call__(self, ctx: FormContext) -> Array: ...
+
+
+ class ElementResidualKernel(Protocol):
+     def __call__(self, ctx: FormContext, u_elem: Array) -> Array: ...
+
+
+ class ElementJacobianKernel(Protocol):
+     def __call__(self, u_elem: Array, ctx: FormContext) -> Array: ...
+
+ ElementKernel: TypeAlias = (
+     ElementBilinearKernel
+     | ElementLinearKernel
+     | ElementResidualKernel
+     | ElementJacobianKernel
+ )
+ def _get_pattern(space: SpaceLike, *, with_idx: bool) -> SparsityPattern | None:
+     if hasattr(space, "get_sparsity_pattern"):
+         return space.get_sparsity_pattern(with_idx=with_idx)
+     return None
+
+
+ def _get_elem_rows(space: SpaceLike) -> Array:
+     if hasattr(space, "get_elem_rows"):
+         return space.get_elem_rows()
+     return space.elem_dofs.reshape(-1)
+
+
+ def chunk_pad_stats(n_elems: int, n_chunks: Optional[int]) -> dict[str, int | float | None]:
+     """
+     Compute padding overhead for chunked assembly.
+     Returns dict with chunk_size, pad, n_pad, and pad_ratio.
+     """
+     n_elems = int(n_elems)
+     if n_chunks is None or n_elems <= 0:
+         return {"chunk_size": None, "pad": 0, "n_pad": n_elems, "pad_ratio": 0.0}
+     n_chunks = min(int(n_chunks), n_elems)
+     chunk_size = (n_elems + n_chunks - 1) // n_chunks
+     pad = (-n_elems) % chunk_size
+     n_pad = n_elems + pad
+     pad_ratio = float(pad) / float(n_elems) if n_elems else 0.0
+     return {"chunk_size": int(chunk_size), "pad": int(pad), "n_pad": int(n_pad), "pad_ratio": pad_ratio}
+
+
+ def _maybe_trace_pad(
+     stats: dict[str, int | float | None], *, n_chunks: Optional[int], pad_trace: bool
+ ) -> None:
+     if not pad_trace or not jax.core.trace_ctx.is_top_level():
+         return
+     if n_chunks is None:
+         return
+     print(
+         "[pad]",
+         f"n_chunks={int(n_chunks)}",
+         f"chunk_size={stats['chunk_size']}",
+         f"pad={stats['pad']}",
+         f"pad_ratio={stats['pad_ratio']:.4f}",
+         flush=True,
+     )
+
+
+ class BatchedAssembler:
+     """
+     Assemble on a fixed space with optional masking to keep shapes static.
+
+     Use `mask` to zero padded elements while keeping input shapes fixed.
+     """
+
+     def __init__(
+         self,
+         space: SpaceLike,
+         elem_data: FormContext,
+         elem_dofs: Array,
+         *,
+         pattern: SparsityPattern | None = None,
+     ) -> None:
+         self.space = space
+         self.elem_data = elem_data
+         self.elem_dofs = elem_dofs
+         self.n_elems = int(elem_dofs.shape[0])
+         self.n_ldofs = int(space.n_ldofs)
+         self.n_dofs = int(space.n_dofs)
+         self.pattern = pattern
+         self._rows: Array | None = None
+         self._cols: Array | None = None
+
+     @classmethod
+     def from_space(
+         cls,
+         space: SpaceLike,
+         *,
+         dep: jnp.ndarray | None = None,
+         pattern: SparsityPattern | None = None,
+     ) -> "BatchedAssembler":
+         elem_data = space.build_form_contexts(dep=dep)
+         return cls(space, elem_data, space.elem_dofs, pattern=pattern)
+
+     def make_mask(self, n_active: int) -> Array:
+         n_active = max(0, min(int(n_active), self.n_elems))
+         mask: np.ndarray = np.zeros((self.n_elems,), dtype=float)
+         if n_active:
+             mask[:n_active] = 1.0
+         return jnp.asarray(mask)
+
+     def slice(self, n_active: int) -> "BatchedAssembler":
+         n_active = max(0, min(int(n_active), self.n_elems))
+         elem_data = jax.tree_util.tree_map(lambda x: x[:n_active], self.elem_data)
+         elem_dofs = self.elem_dofs[:n_active]
+         return BatchedAssembler(self.space, elem_data, elem_dofs, pattern=None)
+
+     def _rows_cols(self) -> tuple[Array, Array]:
+         if self.pattern is not None:
+             return self.pattern.rows, self.pattern.cols
+         if self._rows is None or self._cols is None:
+             elem_dofs = self.elem_dofs
+             n_ldofs = int(elem_dofs.shape[1])
+             rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
+             cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
+             self._rows = rows
+             self._cols = cols
+         return self._rows, self._cols
+
+     def assemble_bilinear_with_kernel(
+         self, kernel: ElementBilinearKernel, *, mask: Array | None = None
+     ) -> FluxSparseMatrix:
+         """
+         kernel(ctx) -> (n_ldofs, n_ldofs)
+         """
+         from ..solver import FluxSparseMatrix
+
+         Ke = jax.vmap(kernel)(self.elem_data)
+         if mask is not None:
+             Ke = Ke * jnp.asarray(mask)[:, None, None]
+         data = Ke.reshape(-1)
+         if self.pattern is not None:
+             return FluxSparseMatrix(self.pattern, data)
+         rows, cols = self._rows_cols()
+         return FluxSparseMatrix(rows, cols, data, n_dofs=self.n_dofs)
+
+     def assemble_bilinear(
+         self,
+         form: FormKernel[P],
+         params: P,
+         *,
+         mask: Array | None = None,
+         kernel: ElementBilinearKernel | None = None,
+         jit: bool = True,
+     ) -> FluxSparseMatrix:
+         if kernel is None:
+             kernel = make_element_bilinear_kernel(form, params, jit=jit)
+         return self.assemble_bilinear_with_kernel(kernel, mask=mask)
+
+     def assemble_linear_with_kernel(
+         self,
+         kernel: ElementLinearKernel,
+         *,
+         mask: Array | None = None,
+         dep: jnp.ndarray | None = None,
+     ) -> Array:
+         """
+         kernel(ctx) -> (n_ldofs,)
+         """
+         elem_data = self.elem_data if dep is None else self.space.build_form_contexts(dep=dep)
+         Fe = jax.vmap(kernel)(elem_data)
+         if mask is not None:
+             Fe = Fe * jnp.asarray(mask)[:, None]
+         rows = self.elem_dofs.reshape(-1)
+         data = Fe.reshape(-1)
+         return jax.ops.segment_sum(data, rows, self.n_dofs)
+
+     def assemble_linear(
+         self,
+         form: FormKernel[P],
+         params: P,
+         *,
+         mask: Array | None = None,
+         dep: jnp.ndarray | None = None,
+         kernel: ElementLinearKernel | None = None,
+     ) -> Array:
+         if kernel is not None:
+             return self.assemble_linear_with_kernel(kernel, mask=mask, dep=dep)
+         elem_data = self.elem_data if dep is None else self.space.build_form_contexts(dep=dep)
+         includes_measure = getattr(form, "_includes_measure", False)
+
+         def per_element(ctx: FormContext):
+             integrand = form(ctx, params)
+             if includes_measure:
+                 return integrand.sum(axis=0)
+             wJ = ctx.w * ctx.test.detJ
+             return (integrand * wJ[:, None]).sum(axis=0)
+
+         Fe = jax.vmap(per_element)(elem_data)
+         if mask is not None:
+             Fe = Fe * jnp.asarray(mask)[:, None]
+         rows = self.elem_dofs.reshape(-1)
+         data = Fe.reshape(-1)
+         return jax.ops.segment_sum(data, rows, self.n_dofs)
+
+     def assemble_mass_matrix(
+         self, *, mask: Array | None = None, lumped: bool = False
+     ) -> MassReturn:
+         from ..solver import FluxSparseMatrix
+
+         n_ldofs = self.n_ldofs
+
+         def per_element(ctx: FormContext):
+             N = ctx.test.N
+             base = jnp.einsum("qa,qb->qab", N, N)
+             if hasattr(ctx.test, "value_dim"):
+                 vd = int(ctx.test.value_dim)
+                 I = jnp.eye(vd, dtype=N.dtype)
+                 base = base[:, :, :, None, None] * I[None, None, None, :, :]
+                 base = base.reshape(base.shape[0], n_ldofs, n_ldofs)
+             wJ = ctx.w * ctx.test.detJ
+             return jnp.einsum("qab,q->ab", base, wJ)
+
+         Me = jax.vmap(per_element)(self.elem_data)
+         if mask is not None:
+             Me = Me * jnp.asarray(mask)[:, None, None]
+         data = Me.reshape(-1)
+         rows, cols = self._rows_cols()
+
+         if lumped:
+             M = jnp.zeros((self.n_dofs,), dtype=data.dtype)
+             M = M.at[rows].add(data)
+             return M
+
+         return FluxSparseMatrix(rows, cols, data, n_dofs=self.n_dofs)
+
+     def assemble_residual_with_kernel(
+         self, kernel: ElementResidualKernel, u: Array, *, mask: Array | None = None
+     ) -> Array:
+         """
+         kernel(ctx, u_elem) -> (n_ldofs,)
+         """
+         u_elems = jnp.asarray(u)[self.elem_dofs]
+         elem_res = jax.vmap(kernel)(self.elem_data, u_elems)
+         if mask is not None:
+             elem_res = elem_res * jnp.asarray(mask)[:, None]
+         rows = self.elem_dofs.reshape(-1)
+         data = elem_res.reshape(-1)
+         return jax.ops.segment_sum(data, rows, self.n_dofs)
+
+     def assemble_residual(
+         self,
+         res_form: ResidualForm[P],
+         u: Array,
+         params: P,
+         *,
+         mask: Array | None = None,
+         kernel: ElementResidualKernel | None = None,
+     ) -> Array:
+         if kernel is None:
+             kernel = make_element_residual_kernel(res_form, params)
+         return self.assemble_residual_with_kernel(kernel, u, mask=mask)
+
+     def assemble_jacobian_with_kernel(
+         self,
+         kernel: ElementJacobianKernel,
+         u: Array,
+         *,
+         mask: Array | None = None,
+         sparse: bool = True,
+         return_flux_matrix: bool = False,
+     ) -> JacobianReturn:
+         """
+         kernel(u_elem, ctx) -> (n_ldofs, n_ldofs)
+         """
+         from ..solver import FluxSparseMatrix  # local import to avoid circular
+
+         u_elems = jnp.asarray(u)[self.elem_dofs]
+         J_e = jax.vmap(kernel)(u_elems, self.elem_data)
+         if mask is not None:
+             J_e = J_e * jnp.asarray(mask)[:, None, None]
+         data = J_e.reshape(-1)
+         if sparse:
+             if self.pattern is not None:
+                 if return_flux_matrix:
+                     return FluxSparseMatrix(self.pattern, data)
+                 return self.pattern.rows, self.pattern.cols, data, self.n_dofs
+             rows, cols = self._rows_cols()
+             if return_flux_matrix:
+                 return FluxSparseMatrix(rows, cols, data, n_dofs=self.n_dofs)
+             return rows, cols, data, self.n_dofs
+         rows, cols = self._rows_cols()
+         idx = (rows.astype(jnp.int64) * int(self.n_dofs) + cols.astype(jnp.int64)).astype(INDEX_DTYPE)
+         n_entries = self.n_dofs * self.n_dofs
+         sdn = jax.lax.ScatterDimensionNumbers(
+             update_window_dims=(),
+             inserted_window_dims=(0,),
+             scatter_dims_to_operand_dims=(0,),
+         )
+         K_flat = jnp.zeros(n_entries, dtype=data.dtype)
+         K_flat = jax.lax.scatter_add(K_flat, idx[:, None], data, sdn)
+         return K_flat.reshape(self.n_dofs, self.n_dofs)
+
+     def assemble_jacobian(
+         self,
+         res_form: ResidualForm[P],
+         u: Array,
+         params: P,
+         *,
+         mask: Array | None = None,
+         kernel: ElementJacobianKernel | None = None,
+         sparse: bool = True,
+         return_flux_matrix: bool = False,
+     ) -> JacobianReturn:
+         if kernel is None:
+             kernel = make_element_jacobian_kernel(res_form, params)
+         return self.assemble_jacobian_with_kernel(
+             kernel,
+             u,
+             mask=mask,
+             sparse=sparse,
+             return_flux_matrix=return_flux_matrix,
+         )

  class SpaceLike(FESpaceBase, Protocol):
      pass
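A minimal usage sketch of the BatchedAssembler added above (illustrative only, not part of the diff). It assumes an existing FE space `space` providing `build_form_contexts()` and `elem_dofs`, and uses only context fields that appear in this file (`ctx.test.N`, `ctx.w`, `ctx.test.detJ`):

    import jax.numpy as jnp
    from fluxfem.core.assembly import BatchedAssembler

    def mass_like_form(ctx, scale):
        # form(ctx, params) -> (n_q, n_ldofs, n_ldofs); the w*detJ weighting is applied by the element kernel
        return scale * jnp.einsum("qa,qb->qab", ctx.test.N, ctx.test.N)

    asm = BatchedAssembler.from_space(space)       # contexts built once, shapes stay fixed
    mask = asm.make_mask(n_active=asm.n_elems)     # 1.0 for active elements, 0.0 for padded ones
    K = asm.assemble_bilinear(mass_like_form, 2.0, mask=mask)   # FluxSparseMatrix
    M_lumped = asm.assemble_mass_matrix(lumped=True)            # lumped mass vector of length n_dofs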
@@ -23,12 +364,12 @@ class SpaceLike(FESpaceBase, Protocol):

  def assemble_bilinear_dense(
      space: SpaceLike,
-     kernel: Kernel[P],
+     kernel: FormKernel[P],
      params: P,
      *,
      sparse: bool = False,
      return_flux_matrix: bool = False,
- ):
+ ) -> BilinearReturn:
      """
      Similar to scikit-fem's asm(biform, basis).
      kernel: FormContext, params -> (n_ldofs, n_ldofs)
@@ -47,11 +388,15 @@ def assemble_bilinear_dense(

      # ---- scatter into COO format ----
      # row/col indices (n_elems, n_ldofs, n_ldofs)
-     rows = jnp.repeat(elem_dofs, n_ldofs, axis=1)  # (n_elems, n_ldofs*n_ldofs)
-     cols = jnp.tile(elem_dofs, (1, n_ldofs))  # (n_elems, n_ldofs*n_ldofs)
-
-     rows = rows.reshape(-1)
-     cols = cols.reshape(-1)
+     pat = _get_pattern(space, with_idx=False)
+     if pat is None:
+         rows = jnp.repeat(elem_dofs, n_ldofs, axis=1)  # (n_elems, n_ldofs*n_ldofs)
+         cols = jnp.tile(elem_dofs, (1, n_ldofs))  # (n_elems, n_ldofs*n_ldofs)
+         rows = rows.reshape(-1)
+         cols = cols.reshape(-1)
+     else:
+         rows = pat.rows
+         cols = pat.cols
      data = K_e_all.reshape(-1)

      # Flatten indices for segment_sum via (row * n_dofs + col)
@@ -71,18 +416,22 @@


  def assemble_bilinear_form(
-     space,
-     form,
-     params,
+     space: SpaceLike,
+     form: FormKernel[P],
+     params: P,
      *,
-     pattern=None,
-     chunk_size: Optional[int] = None,  # None -> no-chunk (old behavior)
+     pattern: SparsityPattern | None = None,
+     n_chunks: Optional[int] = None,  # None -> no chunking
      dep: jnp.ndarray | None = None,
- ):
+     kernel: ElementBilinearKernel | None = None,
+     jit: bool = True,
+     pad_trace: bool = False,
+ ) -> FluxSparseMatrix:
      """
      Assemble a sparse bilinear form into a FluxSparseMatrix.

      Expects form(ctx, params) -> (n_q, n_ldofs, n_ldofs).
+     If kernel is provided: kernel(ctx) -> (n_ldofs, n_ldofs).
      """
      from ..solver import FluxSparseMatrix

@@ -104,18 +453,27 @@ def assemble_bilinear_form(
          wJ = ctx.w * ctx.test.detJ  # (n_q,)
          return (integrand * wJ[:, None, None]).sum(axis=0)  # (m, m)

+     if kernel is None:
+         kernel = make_element_bilinear_kernel(form, params, jit=jit)
+
      # --- no-chunk path (your current implementation) ---
-     if chunk_size is None:
-         K_e_all = jax.vmap(per_element)(elem_data)  # (n_elems, m, m)
+     if n_chunks is None:
+         K_e_all = jax.vmap(kernel)(elem_data)  # (n_elems, m, m)
          data = K_e_all.reshape(-1)
          return FluxSparseMatrix(pat, data)

      # --- chunked path ---
      n_elems = space.elem_dofs.shape[0]
+     if n_chunks <= 0:
+         raise ValueError("n_chunks must be a positive integer.")
+     n_chunks = min(int(n_chunks), int(n_elems))
+     chunk_size = (n_elems + n_chunks - 1) // n_chunks
+     stats = chunk_pad_stats(n_elems, n_chunks)
+     _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
      # Ideally get m from pat (otherwise infer from one element).
      m = getattr(pat, "n_ldofs", None)
      if m is None:
-         m = per_element(jax.tree_util.tree_map(lambda x: x[0], elem_data)).shape[0]
+         m = kernel(jax.tree_util.tree_map(lambda x: x[0], elem_data)).shape[0]

      # Pad to fixed-size chunks for JIT stability.
      pad = (-n_elems) % chunk_size
@@ -141,7 +499,7 @@
              lambda x: _slice_first_dim(x, start, chunk_size),
              elem_data_pad,
          )
-         Ke = jax.vmap(per_element)(ctx_chunk)  # (chunk, m, m)
+         Ke = jax.vmap(kernel)(ctx_chunk)  # (chunk, m, m)
          return Ke.reshape(-1)  # (chunk*m*m,)

      data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
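For reference, this is how n_chunks translates into padding in the chunked path above; the values are computed with chunk_pad_stats as defined earlier in this file (illustrative call, not part of the diff):

    from fluxfem.core.assembly import chunk_pad_stats

    chunk_pad_stats(1000, 8)   # -> {"chunk_size": 125, "pad": 0, "n_pad": 1000, "pad_ratio": 0.0}
    chunk_pad_stats(1001, 8)   # -> {"chunk_size": 126, "pad": 7, "n_pad": 1008, "pad_ratio": ~0.007}

With pad_trace=True, _maybe_trace_pad prints the same numbers when assembly runs outside a jit trace.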
@@ -149,7 +507,13 @@
      return FluxSparseMatrix(pat, data)


- def assemble_mass_matrix(space: SpaceLike, *, lumped: bool = False, chunk_size: Optional[int] = None):
+ def assemble_mass_matrix(
+     space: SpaceLike,
+     *,
+     lumped: bool = False,
+     n_chunks: Optional[int] = None,
+     pad_trace: bool = False,
+ ) -> MassReturn:
      """
      Assemble mass matrix M_ij = ∫ N_i N_j dΩ.
      Supports scalar and vector spaces. If lumped=True, rows are summed to diagonal.
@@ -170,11 +534,17 @@ def assemble_mass_matrix(space: SpaceLike, *, lumped: bool = False, chunk_size:
          wJ = ctx.w * ctx.test.detJ
          return jnp.einsum("qab,q->ab", base, wJ)

-     if chunk_size is None:
+     if n_chunks is None:
          M_e_all = jax.vmap(per_element)(ctxs)  # (n_elems, n_ldofs, n_ldofs)
          data = M_e_all.reshape(-1)
      else:
          n_elems = space.elem_dofs.shape[0]
+         if n_chunks <= 0:
+             raise ValueError("n_chunks must be a positive integer.")
+         n_chunks = min(int(n_chunks), int(n_elems))
+         chunk_size = (n_elems + n_chunks - 1) // n_chunks
+         stats = chunk_pad_stats(n_elems, n_chunks)
+         _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
          pad = (-n_elems) % chunk_size
          if pad:
              ctxs_pad = jax.tree_util.tree_map(
@@ -205,8 +575,13 @@ def assemble_mass_matrix(space: SpaceLike, *, lumped: bool = False, chunk_size:
          data = data_chunks.reshape(-1)[: n_elems * n_ldofs * n_ldofs]

      elem_dofs = space.elem_dofs
-     rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
-     cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
+     pat = _get_pattern(space, with_idx=False)
+     if pat is None:
+         rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
+         cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
+     else:
+         rows = pat.rows
+         cols = pat.cols

      if lumped:
          n_dofs = space.n_dofs
@@ -219,15 +594,18 @@ def assemble_mass_matrix(space: SpaceLike, *, lumped: bool = False, chunk_size:

  def assemble_linear_form(
      space: SpaceLike,
-     form: Kernel[P],
+     form: FormKernel[P],
      params: P,
      *,
+     kernel: ElementLinearKernel | None = None,
      sparse: bool = False,
-     chunk_size: Optional[int] = None,
+     n_chunks: Optional[int] = None,
      dep: jnp.ndarray | None = None,
- ) -> jnp.ndarray:
+     pad_trace: bool = False,
+ ) -> LinearReturn:
      """
      Expects form(ctx, params) -> (n_q, n_ldofs) and integrates Σ_q form * wJ for RHS.
+     If kernel is provided: kernel(ctx) -> (n_ldofs,).
      """
      elem_dofs = space.elem_dofs
      n_dofs = space.n_dofs
@@ -237,19 +615,28 @@ def assemble_linear_form(

      includes_measure = getattr(form, "_includes_measure", False)

-     def per_element(ctx: FormContext):
-         integrand = form(ctx, params)  # (n_q, m)
-         if includes_measure:
-             return integrand.sum(axis=0)
-         wJ = ctx.w * ctx.test.detJ  # (n_q,)
-         return (integrand * wJ[:, None]).sum(axis=0)  # (m,)
+     if kernel is None:
+         def per_element(ctx: FormContext):
+             integrand = form(ctx, params)  # (n_q, m)
+             if includes_measure:
+                 return integrand.sum(axis=0)
+             wJ = ctx.w * ctx.test.detJ  # (n_q,)
+             return (integrand * wJ[:, None]).sum(axis=0)  # (m,)
+     else:
+         per_element = kernel

-     if chunk_size is None:
+     if n_chunks is None:
          F_e_all = jax.vmap(per_element)(elem_data)  # (n_elems, m)
          data = F_e_all.reshape(-1)
      else:
          n_elems = space.elem_dofs.shape[0]
          m = n_ldofs
+         if n_chunks <= 0:
+             raise ValueError("n_chunks must be a positive integer.")
+         n_chunks = min(int(n_chunks), int(n_elems))
+         chunk_size = (n_elems + n_chunks - 1) // n_chunks
+         stats = chunk_pad_stats(n_elems, n_chunks)
+         _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
          pad = (-n_elems) % chunk_size
          if pad:
              elem_data_pad = jax.tree_util.tree_map(
@@ -279,7 +666,7 @@ def assemble_linear_form(
          data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
          data = data_chunks.reshape(-1)[: n_elems * m]

-     rows = elem_dofs.reshape(-1)
+     rows = _get_elem_rows(space)

      if sparse:
          return rows, data, n_dofs
@@ -288,7 +675,7 @@ def assemble_linear_form(
      return F


- def assemble_functional(space: SpaceLike, form: Kernel[P], params: P) -> jnp.ndarray:
+ def assemble_functional(space: SpaceLike, form: FormKernel[P], params: P) -> jnp.ndarray:
      """
      Assemble scalar functional J = ∫ form(ctx, params) dΩ.
      Expects form(ctx, params) -> (n_q,) or (n_q, 1).
@@ -318,7 +705,7 @@ def assemble_jacobian_global(
      *,
      sparse: bool = False,
      return_flux_matrix: bool = False,
- ):
+ ) -> JacobianReturn:
      """
      Assemble Jacobian (dR/du) from element residual res_form.
      res_form(ctx, u_elem, params) -> (n_q, n_ldofs)
@@ -339,11 +726,16 @@ def assemble_jacobian_global(
      jac_fun = jax.jacrev(fe_fun, argnums=0)

      u_elems = u[elem_dofs]  # (n_elems, n_ldofs)
-     elem_ids = jnp.arange(elem_dofs.shape[0], dtype=jnp.int32)
+     elem_ids = jnp.arange(elem_dofs.shape[0], dtype=INDEX_DTYPE)
      J_e_all = jax.vmap(jac_fun)(u_elems, elem_data, elem_ids)  # (n_elems, m, m)

-     rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
-     cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
+     pat = _get_pattern(space, with_idx=False)
+     if pat is None:
+         rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
+         cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
+     else:
+         rows = pat.rows
+         cols = pat.cols
      data = J_e_all.reshape(-1)

      if sparse:
@@ -358,7 +750,7 @@ def assemble_jacobian_global(
      return K_flat.reshape(n_dofs, n_dofs)


- def assemble_jacobian_elementwise_xla(
+ def assemble_jacobian_elementwise(
      space: SpaceLike,
      res_form: ResidualForm[P],
      u: jnp.ndarray,
@@ -366,9 +758,9 @@ def assemble_jacobian_elementwise_xla(
      *,
      sparse: bool = False,
      return_flux_matrix: bool = False,
- ):
+ ) -> JacobianReturn:
      """
-     Assemble Jacobian with element kernels in XLA (vmap + scatter_add).
+     Assemble Jacobian with element kernels via vmap + scatter_add.
      Recompiles if n_dofs changes, but independent of element count.
      """
      from ..solver import FluxSparseMatrix  # local import to avoid circular
@@ -388,8 +780,13 @@ def assemble_jacobian_elementwise_xla(
      u_elems = u[elem_dofs]
      J_e_all = jax.vmap(jac_fun)(u_elems, ctxs)  # (n_elems, m, m)

-     rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
-     cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
+     pat = _get_pattern(space, with_idx=False)
+     if pat is None:
+         rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1)
+         cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1)
+     else:
+         rows = pat.rows
+         cols = pat.cols
      data = J_e_all.reshape(-1)

      if sparse:
@@ -406,7 +803,7 @@ def assemble_jacobian_elementwise_xla(
      )
      K_flat = jnp.zeros(n_entries, dtype=data.dtype)
      K_flat = jax.lax.scatter_add(K_flat, idx[:, None], data, sdn)
-     return K_flat.reshape(pat.n_dofs, pat.n_dofs)
+     return K_flat.reshape(n_dofs, n_dofs)


  def assemble_residual_global(
@@ -416,7 +813,7 @@ def assemble_residual_global(
      params: P,
      *,
      sparse: bool = False
- ):
+ ) -> LinearReturn:
      """
      Assemble residual vector that depends on u.
      form(ctx, u_elem, params) -> (n_q, n_ldofs)
@@ -435,10 +832,10 @@ def assemble_residual_global(
          fe = (integrand * wJ[:, None]).sum(axis=0)
          return fe

-     elem_ids = jnp.arange(elem_dofs.shape[0], dtype=jnp.int32)
+     elem_ids = jnp.arange(elem_dofs.shape[0], dtype=INDEX_DTYPE)
      F_e_all = jax.vmap(per_element)(elem_data, elem_dofs, elem_ids)  # (n_elems, m)

-     rows = elem_dofs.reshape(-1)
+     rows = _get_elem_rows(space)
      data = F_e_all.reshape(-1)

      if sparse:
@@ -448,16 +845,16 @@ def assemble_residual_global(
      return F


- def assemble_residual_elementwise_xla(
+ def assemble_residual_elementwise(
      space: SpaceLike,
      res_form: ResidualForm[P],
      u: jnp.ndarray,
      params: P,
      *,
      sparse: bool = False,
- ):
+ ) -> LinearReturn:
      """
-     Assemble residual using element kernels fully in XLA (vmap + scatter_add).
+     Assemble residual using element kernels via vmap + scatter_add.
      Recompiles if n_dofs changes, but independent of element count.
      """
      elem_dofs = space.elem_dofs
@@ -471,7 +868,7 @@ def assemble_residual_elementwise_xla(

      u_elems = u[elem_dofs]
      F_e_all = jax.vmap(per_element)(ctxs, u_elems)  # (n_elems, m)
-     rows = elem_dofs.reshape(-1)
+     rows = _get_elem_rows(space)
      data = F_e_all.reshape(-1)

      if sparse:
@@ -487,7 +884,44 @@ def assemble_residual_elementwise_xla(
      return F


- def make_element_residual_kernel(res_form: ResidualForm[P], params: P):
+ # Backward compatibility aliases (prefer assemble_*_elementwise).
+ assemble_jacobian_elementwise_xla = assemble_jacobian_elementwise
+ assemble_residual_elementwise_xla = assemble_residual_elementwise
+
+
+ def make_element_bilinear_kernel(
+     form: FormKernel[P], params: P, *, jit: bool = True
+ ) -> ElementBilinearKernel:
+     """Element kernel: (ctx) -> Ke."""
+
+     def per_element(ctx: FormContext):
+         integrand = form(ctx, params)
+         if getattr(form, "_includes_measure", False):
+             return integrand.sum(axis=0)
+         wJ = ctx.w * ctx.test.detJ
+         return (integrand * wJ[:, None, None]).sum(axis=0)
+
+     return jax.jit(per_element) if jit else per_element
+
+
+ def make_element_linear_kernel(
+     form: FormKernel[P], params: P, *, jit: bool = True
+ ) -> ElementLinearKernel:
+     """Element kernel: (ctx) -> fe."""
+
+     def per_element(ctx: FormContext):
+         integrand = form(ctx, params)
+         if getattr(form, "_includes_measure", False):
+             return integrand.sum(axis=0)
+         wJ = ctx.w * ctx.test.detJ
+         return (integrand * wJ[:, None]).sum(axis=0)
+
+     return jax.jit(per_element) if jit else per_element
+
+
+ def make_element_residual_kernel(
+     res_form: ResidualForm[P], params: P
+ ) -> ElementResidualKernel:
      """Jitted element residual kernel: (ctx, u_elem) -> fe."""

      def per_element(ctx: FormContext, u_elem: jnp.ndarray):
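A sketch of reusing a prebuilt element kernel across assembly calls (illustrative; `space` and `mass_like_form` are the same assumed objects as in the earlier sketch):

    from fluxfem.core.assembly import assemble_bilinear_form, make_element_bilinear_kernel

    ker = make_element_bilinear_kernel(mass_like_form, 2.0)    # jitted kernel(ctx) -> (n_ldofs, n_ldofs)
    K_a = assemble_bilinear_form(space, mass_like_form, 2.0, kernel=ker)
    K_b = assemble_bilinear_form(space, mass_like_form, 2.0, kernel=ker, n_chunks=4)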
@@ -500,7 +934,9 @@ def make_element_residual_kernel(res_form: ResidualForm[P], params: P):
      return jax.jit(per_element)


- def make_element_jacobian_kernel(res_form: ResidualForm[P], params: P):
+ def make_element_jacobian_kernel(
+     res_form: ResidualForm[P], params: P
+ ) -> ElementJacobianKernel:
      """Jitted element Jacobian kernel: (ctx, u_elem) -> Ke."""

      def fe_fun(u_elem, ctx: FormContext):
@@ -513,7 +949,9 @@ def make_element_jacobian_kernel(res_form: ResidualForm[P], params: P):
      return jax.jit(jax.jacrev(fe_fun, argnums=0))


- def element_residual(res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.ndarray, params: P):
+ def element_residual(
+     res_form: ResidualFormLike[P], ctx: FormContext, u_elem: ResidualInput, params: P
+ ) -> ResidualValue:
      """
      Element residual vector r_e(u_e) = sum_q w_q * detJ_q * res_form(ctx, u_e, params).
      Returns shape (n_ldofs,).
@@ -538,7 +976,9 @@ def element_residual(res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.nd
      return jax.tree_util.tree_map(lambda x: jnp.einsum("qa,q->a", x, ctx.w * ctx.test.detJ), integrand)


- def element_jacobian(res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.ndarray, params: P):
+ def element_jacobian(
+     res_form: ResidualFormLike[P], ctx: FormContext, u_elem: ResidualInput, params: P
+ ) -> ResidualValue:
      """
      Element Jacobian K_e = d r_e / d u_e (AD via jacfwd), shape (n_ldofs, n_ldofs).
      """
@@ -548,7 +988,46 @@ def element_jacobian(res_form: ResidualForm[P], ctx: FormContext, u_elem: jnp.nd
      return jax.jacfwd(_r_elem)(u_elem)


- def make_sparsity_pattern(space: SpaceLike, *, with_idx: bool = True):
+ def make_element_kernel(
+     form: FormKernel[P] | ResidualForm[P],
+     params: P,
+     *,
+     kind: Literal["bilinear", "linear", "residual", "jacobian"],
+     jit: bool = True,
+ ) -> ElementKernel:
+     """
+     Unified entry point for element kernels.
+
+     kind:
+     - "bilinear": kernel(ctx) -> (n_ldofs, n_ldofs)
+     - "linear": kernel(ctx) -> (n_ldofs,)
+     - "residual": kernel(ctx, u_elem) -> (n_ldofs,)
+     - "jacobian": kernel(u_elem, ctx) -> (n_ldofs, n_ldofs)
+     """
+     kind = cast(Literal["bilinear", "linear", "residual", "jacobian"], kind.lower())
+     if kind == "bilinear":
+         form_bilinear = cast(FormKernel[P], form)
+         return make_element_bilinear_kernel(form_bilinear, params, jit=jit)
+     if kind == "linear":
+         form_linear = cast(FormKernel[P], form)
+         def per_element(ctx: FormContext):
+             integrand = form_linear(ctx, params)
+             if getattr(form_linear, "_includes_measure", False):
+                 return integrand.sum(axis=0)
+             wJ = ctx.w * ctx.test.detJ
+             return (integrand * wJ[:, None]).sum(axis=0)
+
+         return jax.jit(per_element) if jit else per_element
+     if kind == "residual":
+         form_residual = cast(ResidualForm[P], form)
+         return make_element_residual_kernel(form_residual, params)
+     if kind == "jacobian":
+         form_residual = cast(ResidualForm[P], form)
+         return make_element_jacobian_kernel(form_residual, params)
+     raise ValueError(f"Unknown kernel kind: {kind}")
+
+
+ def make_sparsity_pattern(space: SpaceLike, *, with_idx: bool = True) -> SparsityPattern:
      """
      Build a SparsityPattern (rows/cols[/idx]) that is independent of the solution.
      NOTE: rows/cols ordering matches assemble_jacobian_values(...).reshape(-1)
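A short illustration of the unified factory above (hypothetical `res_form` and `params` from user code); the argument order of the returned kernels follows the kind table in the docstring:

    from fluxfem.core.assembly import make_element_kernel

    res_ker = make_element_kernel(res_form, params, kind="residual")   # kernel(ctx, u_elem) -> (n_ldofs,)
    jac_ker = make_element_kernel(res_form, params, kind="jacobian")   # kernel(u_elem, ctx) -> (n_ldofs, n_ldofs)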
@@ -557,24 +1036,24 @@ def make_sparsity_pattern(space: SpaceLike, *, with_idx: bool = True):
      """
      from ..solver import SparsityPattern  # local import to avoid circular

-     elem_dofs = jnp.asarray(space.elem_dofs, dtype=jnp.int32)
+     elem_dofs = jnp.asarray(space.elem_dofs, dtype=INDEX_DTYPE)
      n_dofs = int(space.n_dofs)
      n_ldofs = int(space.n_ldofs)

-     rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1).astype(jnp.int32)
-     cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1).astype(jnp.int32)
+     rows = jnp.repeat(elem_dofs, n_ldofs, axis=1).reshape(-1).astype(INDEX_DTYPE)
+     cols = jnp.tile(elem_dofs, (1, n_ldofs)).reshape(-1).astype(INDEX_DTYPE)

      key = rows.astype(jnp.int64) * jnp.int64(n_dofs) + cols.astype(jnp.int64)
-     order = jnp.argsort(key).astype(jnp.int32)
+     order = jnp.argsort(key).astype(INDEX_DTYPE)
      rows_sorted = rows[order]
      cols_sorted = cols[order]
-     counts = jnp.bincount(rows_sorted, length=n_dofs).astype(jnp.int32)
-     indptr_j = jnp.concatenate([jnp.array([0], dtype=jnp.int32), jnp.cumsum(counts)])
-     indices_j = cols_sorted.astype(jnp.int32)
+     counts = jnp.bincount(rows_sorted, length=n_dofs).astype(INDEX_DTYPE)
+     indptr_j = jnp.concatenate([jnp.array([0], dtype=INDEX_DTYPE), jnp.cumsum(counts)])
+     indices_j = cols_sorted.astype(INDEX_DTYPE)
      perm = order

      if with_idx:
-         idx = (rows.astype(jnp.int64) * jnp.int64(n_dofs) + cols.astype(jnp.int64)).astype(jnp.int32)
+         idx = (rows.astype(jnp.int64) * jnp.int64(n_dofs) + cols.astype(jnp.int64)).astype(INDEX_DTYPE)
          return SparsityPattern(
              rows=rows,
              cols=cols,
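A sketch of the pattern reuse this function enables (illustrative; `space`, `res_form`, `u`, `params` assumed): build the SparsityPattern once, then pass it to repeated Jacobian assemblies so rows/cols are not regenerated.

    from fluxfem.core.assembly import (
        assemble_jacobian_scatter, assemble_jacobian_values, make_sparsity_pattern,
    )

    pat = make_sparsity_pattern(space, with_idx=True)
    vals = assemble_jacobian_values(space, res_form, u, params)   # ordering matches pat.rows/pat.cols
    J = assemble_jacobian_scatter(space, res_form, u, params,
                                  pattern=pat, sparse=True, return_flux_matrix=True)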
@@ -601,8 +1080,10 @@ def assemble_jacobian_values(
      u: jnp.ndarray,
      params: P,
      *,
-     kernel=None,
- ):
+     kernel: ElementJacobianKernel | None = None,
+     n_chunks: Optional[int] = None,
+     pad_trace: bool = False,
+ ) -> Array:
      """
      Assemble only the numeric values for the Jacobian (pattern-free).
      """
@@ -610,8 +1091,49 @@ def assemble_jacobian_values(
      ker = kernel if kernel is not None else make_element_jacobian_kernel(res_form, params)

      u_elems = u[space.elem_dofs]
-     J_e_all = jax.vmap(ker)(u_elems, ctxs)  # (n_elem, m, m)
-     return J_e_all.reshape(-1)
+     if n_chunks is None:
+         J_e_all = jax.vmap(ker)(u_elems, ctxs)  # (n_elem, m, m)
+         return J_e_all.reshape(-1)
+
+     n_elems = int(u_elems.shape[0])
+     if n_chunks <= 0:
+         raise ValueError("n_chunks must be a positive integer.")
+     n_chunks = min(int(n_chunks), int(n_elems))
+     chunk_size = (n_elems + n_chunks - 1) // n_chunks
+     stats = chunk_pad_stats(n_elems, n_chunks)
+     _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
+     pad = (-n_elems) % chunk_size
+     if pad:
+         ctxs_pad = jax.tree_util.tree_map(
+             lambda x: jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0),
+             ctxs,
+         )
+         u_elems_pad = jnp.concatenate([u_elems, jnp.repeat(u_elems[-1:], pad, axis=0)], axis=0)
+     else:
+         ctxs_pad = ctxs
+         u_elems_pad = u_elems
+
+     n_pad = n_elems + pad
+     n_chunks = n_pad // chunk_size
+     m = int(space.n_ldofs)
+
+     def _slice_first_dim(x, start, size):
+         start_idx = (start,) + (0,) * (x.ndim - 1)
+         slice_sizes = (size,) + x.shape[1:]
+         return jax.lax.dynamic_slice(x, start_idx, slice_sizes)
+
+     def chunk_fn(i):
+         start = i * chunk_size
+         ctx_chunk = jax.tree_util.tree_map(
+             lambda x: _slice_first_dim(x, start, chunk_size),
+             ctxs_pad,
+         )
+         u_chunk = _slice_first_dim(u_elems_pad, start, chunk_size)
+         J_e = jax.vmap(ker)(u_chunk, ctx_chunk)
+         return J_e.reshape(-1)
+
+     data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
+     return data_chunks.reshape(-1)[: n_elems * m * m]


  def assemble_residual_scatter(
@@ -620,9 +1142,11 @@ def assemble_residual_scatter(
      u: jnp.ndarray,
      params: P,
      *,
-     kernel=None,
+     kernel: ElementResidualKernel | None = None,
      sparse: bool = False,
- ):
+     n_chunks: Optional[int] = None,
+     pad_trace: bool = False,
+ ) -> LinearReturn:
      """
      Assemble residual using jitted element kernel + vmap + scatter_add.
      Avoids Python loops; good for JIT stability.
@@ -633,20 +1157,62 @@ def assemble_residual_scatter(
      """
      elem_dofs = space.elem_dofs
      n_dofs = space.n_dofs
-     if np.max(elem_dofs) >= n_dofs:
-         raise ValueError("elem_dofs contains index outside n_dofs")
-     if np.min(elem_dofs) < 0:
-         raise ValueError("elem_dofs contains negative index")
+     if jax.core.trace_ctx.is_top_level():
+         if np.max(elem_dofs) >= n_dofs:
+             raise ValueError("elem_dofs contains index outside n_dofs")
+         if np.min(elem_dofs) < 0:
+             raise ValueError("elem_dofs contains negative index")
      ctxs = space.build_form_contexts()
      ker = kernel if kernel is not None else make_element_residual_kernel(res_form, params)

      u_elems = u[elem_dofs]
-     elem_res = jax.vmap(ker)(ctxs, u_elems)  # (n_elem, n_ldofs)
-     if not bool(jax.block_until_ready(jnp.all(jnp.isfinite(elem_res)))):
-         bad = int(jnp.count_nonzero(~jnp.isfinite(elem_res)))
-         raise RuntimeError(f"[assemble_residual_scatter] elem_res nonfinite: {bad}")
+     if n_chunks is None:
+         elem_res = jax.vmap(ker)(ctxs, u_elems)  # (n_elem, n_ldofs)
+     else:
+         n_elems = int(u_elems.shape[0])
+         if n_chunks <= 0:
+             raise ValueError("n_chunks must be a positive integer.")
+         n_chunks = min(int(n_chunks), int(n_elems))
+         chunk_size = (n_elems + n_chunks - 1) // n_chunks
+         stats = chunk_pad_stats(n_elems, n_chunks)
+         _maybe_trace_pad(stats, n_chunks=n_chunks, pad_trace=pad_trace)
+         pad = (-n_elems) % chunk_size
+         if pad:
+             ctxs_pad = jax.tree_util.tree_map(
+                 lambda x: jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0),
+                 ctxs,
+             )
+             u_elems_pad = jnp.concatenate([u_elems, jnp.repeat(u_elems[-1:], pad, axis=0)], axis=0)
+         else:
+             ctxs_pad = ctxs
+             u_elems_pad = u_elems
+
+         n_pad = n_elems + pad
+         n_chunks = n_pad // chunk_size
+
+         def _slice_first_dim(x, start, size):
+             start_idx = (start,) + (0,) * (x.ndim - 1)
+             slice_sizes = (size,) + x.shape[1:]
+             return jax.lax.dynamic_slice(x, start_idx, slice_sizes)

-     rows = elem_dofs.reshape(-1)
+         def chunk_fn(i):
+             start = i * chunk_size
+             ctx_chunk = jax.tree_util.tree_map(
+                 lambda x: _slice_first_dim(x, start, chunk_size),
+                 ctxs_pad,
+             )
+             u_chunk = _slice_first_dim(u_elems_pad, start, chunk_size)
+             res_chunk = jax.vmap(ker)(ctx_chunk, u_chunk)
+             return res_chunk.reshape(-1)
+
+         data_chunks = jax.vmap(chunk_fn)(jnp.arange(n_chunks))
+         elem_res = data_chunks.reshape(-1)[: n_elems * int(space.n_ldofs)].reshape(n_elems, -1)
+     if jax.core.trace_ctx.is_top_level():
+         if not bool(jax.block_until_ready(jnp.all(jnp.isfinite(elem_res)))):
+             bad = int(jnp.count_nonzero(~jnp.isfinite(elem_res)))
+             raise RuntimeError(f"[assemble_residual_scatter] elem_res nonfinite: {bad}")
+
+     rows = _get_elem_rows(space)
      data = elem_res.reshape(-1)

      if sparse:
@@ -668,11 +1234,13 @@ def assemble_jacobian_scatter(
      u: jnp.ndarray,
      params: P,
      *,
-     kernel=None,
+     kernel: ElementJacobianKernel | None = None,
      sparse: bool = False,
      return_flux_matrix: bool = False,
-     pattern=None,
- ):
+     pattern: SparsityPattern | None = None,
+     n_chunks: Optional[int] = None,
+     pad_trace: bool = False,
+ ) -> JacobianReturn:
      """
      Assemble Jacobian using jitted element kernel + vmap + scatter_add.
      If a SparsityPattern is provided, rows/cols are reused without regeneration.
@@ -682,7 +1250,9 @@ def assemble_jacobian_scatter(
      from ..solver import FluxSparseMatrix  # local import to avoid circular

      pat = pattern if pattern is not None else make_sparsity_pattern(space, with_idx=not sparse)
-     data = assemble_jacobian_values(space, res_form, u, params, kernel=kernel)
+     data = assemble_jacobian_values(
+         space, res_form, u, params, kernel=kernel, n_chunks=n_chunks, pad_trace=pad_trace
+     )

      if sparse:
          if return_flux_matrix:
@@ -691,7 +1261,7 @@ def assemble_jacobian_scatter(

      idx = pat.idx
      if idx is None:
-         idx = (pat.rows.astype(jnp.int64) * int(pat.n_dofs) + pat.cols.astype(jnp.int64)).astype(jnp.int32)
+         idx = (pat.rows.astype(jnp.int64) * int(pat.n_dofs) + pat.cols.astype(jnp.int64)).astype(INDEX_DTYPE)

      n_entries = pat.n_dofs * pat.n_dofs
      sdn = jax.lax.ScatterDimensionNumbers(
@@ -708,12 +1278,21 @@ def assemble_residual(
  def assemble_residual(
      space: SpaceLike,
      form: ResidualForm[P],
-     u: jnp.ndarray, params: P,
+     u: jnp.ndarray,
+     params: P,
      *,
-     sparse: bool = False
- ):
-     """Assemble the global residual vector (scatter-based)."""
-     return assemble_residual_scatter(space, form, u, params, sparse=sparse)
+     kernel: ElementResidualKernel | None = None,
+     sparse: bool = False,
+     n_chunks: Optional[int] = None,
+     pad_trace: bool = False,
+ ) -> LinearReturn:
+     """
+     Assemble the global residual vector (scatter-based).
+     If kernel is provided: kernel(ctx, u_elem) -> (n_ldofs,).
+     """
+     return assemble_residual_scatter(
+         space, form, u, params, kernel=kernel, sparse=sparse, n_chunks=n_chunks, pad_trace=pad_trace
+     )


  def assemble_jacobian(
@@ -722,19 +1301,28 @@ def assemble_jacobian(
      u: jnp.ndarray,
      params: P,
      *,
+     kernel: ElementJacobianKernel | None = None,
      sparse: bool = True,
      return_flux_matrix: bool = False,
-     pattern=None,
- ):
-     """Assemble the global Jacobian (scatter-based)."""
+     pattern: SparsityPattern | None = None,
+     n_chunks: Optional[int] = None,
+     pad_trace: bool = False,
+ ) -> JacobianReturn:
+     """
+     Assemble the global Jacobian (scatter-based).
+     If kernel is provided: kernel(u_elem, ctx) -> (n_ldofs, n_ldofs).
+     """
      return assemble_jacobian_scatter(
          space,
          res_form,
          u,
          params,
+         kernel=kernel,
          sparse=sparse,
          return_flux_matrix=return_flux_matrix,
          pattern=pattern,
+         n_chunks=n_chunks,
+         pad_trace=pad_trace,
      )

@@ -748,13 +1336,19 @@ def scalar_body_force_form(ctx: FormContext, load: float) -> jnp.ndarray:
      return load * ctx.test.N  # (n_q, n_ldofs)


- def make_scalar_body_force_form(body_force):
+ scalar_body_force_form._ff_kind = "linear"  # type: ignore[attr-defined]
+ scalar_body_force_form._ff_domain = "volume"  # type: ignore[attr-defined]
+
+
+ def make_scalar_body_force_form(body_force: Callable[[Array], Array]) -> FormKernel[Any]:
      """
      Build a scalar linear form from a callable f(x_q) -> (n_q,).
      """
      def _form(ctx: FormContext, _params):
          f_q = body_force(ctx.x_q)
          return f_q[..., None] * ctx.test.N
+     _form._ff_kind = "linear"  # type: ignore[attr-defined]
+     _form._ff_domain = "volume"  # type: ignore[attr-defined]
      return _form


@@ -762,7 +1356,7 @@ def make_scalar_body_force_form(body_force):
  constant_body_force_form = scalar_body_force_form


- def _check_structured_box_connectivity():
+ def _check_structured_box_connectivity() -> None:
      """Quick connectivity check for nx=2, ny=1, nz=1 (non-structured order)."""
      box = StructuredHexBox(nx=2, ny=1, nz=1, lx=2.0, ly=1.0, lz=1.0)
      mesh = box.build()
@@ -775,7 +1369,7 @@ def _check_structured_box_connectivity():
              [0, 1, 4, 3, 6, 7, 10, 9],  # element at i=0
              [1, 2, 5, 4, 7, 8, 11, 10],  # element at i=1
          ],
-         dtype=jnp.int32,
+         dtype=INDEX_DTYPE,
      )
      max_diff = int(jnp.max(jnp.abs(mesh.conn - expected_conn)))
      print("StructuredHexBox nx=2,ny=1,nz=1 conn matches expected:", max_diff == 0)