fluxfem-0.1.4-py3-none-any.whl → fluxfem-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (45)
  1. fluxfem/__init__.py +68 -0
  2. fluxfem/core/__init__.py +115 -10
  3. fluxfem/core/assembly.py +676 -91
  4. fluxfem/core/basis.py +73 -52
  5. fluxfem/core/dtypes.py +9 -1
  6. fluxfem/core/forms.py +10 -0
  7. fluxfem/core/mixed_assembly.py +263 -0
  8. fluxfem/core/mixed_space.py +348 -0
  9. fluxfem/core/mixed_weakform.py +97 -0
  10. fluxfem/core/solver.py +2 -0
  11. fluxfem/core/space.py +262 -17
  12. fluxfem/core/weakform.py +768 -7
  13. fluxfem/helpers_wf.py +49 -0
  14. fluxfem/mesh/__init__.py +54 -2
  15. fluxfem/mesh/base.py +316 -7
  16. fluxfem/mesh/contact.py +825 -0
  17. fluxfem/mesh/dtypes.py +12 -0
  18. fluxfem/mesh/hex.py +17 -16
  19. fluxfem/mesh/io.py +6 -4
  20. fluxfem/mesh/mortar.py +3907 -0
  21. fluxfem/mesh/supermesh.py +316 -0
  22. fluxfem/mesh/surface.py +22 -4
  23. fluxfem/mesh/tet.py +10 -4
  24. fluxfem/physics/diffusion.py +3 -0
  25. fluxfem/physics/elasticity/hyperelastic.py +3 -0
  26. fluxfem/physics/elasticity/linear.py +9 -2
  27. fluxfem/solver/__init__.py +42 -2
  28. fluxfem/solver/bc.py +38 -2
  29. fluxfem/solver/block_matrix.py +132 -0
  30. fluxfem/solver/block_system.py +454 -0
  31. fluxfem/solver/cg.py +115 -33
  32. fluxfem/solver/dirichlet.py +334 -4
  33. fluxfem/solver/newton.py +237 -60
  34. fluxfem/solver/petsc.py +439 -0
  35. fluxfem/solver/preconditioner.py +106 -0
  36. fluxfem/solver/result.py +18 -0
  37. fluxfem/solver/solve_runner.py +168 -1
  38. fluxfem/solver/solver.py +12 -1
  39. fluxfem/solver/sparse.py +124 -9
  40. fluxfem-0.2.0.dist-info/METADATA +303 -0
  41. fluxfem-0.2.0.dist-info/RECORD +59 -0
  42. fluxfem-0.1.4.dist-info/METADATA +0 -127
  43. fluxfem-0.1.4.dist-info/RECORD +0 -48
  44. {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/LICENSE +0 -0
  45. {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/WHEEL +0 -0
fluxfem/core/space.py CHANGED
@@ -1,11 +1,47 @@
1
1
  from __future__ import annotations
2
2
  import operator
3
3
  from dataclasses import dataclass, field
4
- from typing import Protocol
4
+ from typing import Any, Callable, Protocol, TYPE_CHECKING, TypeVar
5
5
  import jax
6
6
  import jax.numpy as jnp
7
7
  import numpy as np
8
+ import warnings
9
+
10
+ P = TypeVar("P")
11
+
12
+ if TYPE_CHECKING:
13
+ from .assembly import (
14
+ ElementBilinearKernel,
15
+ ElementJacobianKernel,
16
+ ElementLinearKernel,
17
+ ElementResidualKernel,
18
+ Kernel,
19
+ ResidualForm,
20
+ )
21
+ else:
22
+ Kernel = Callable[..., Any]
23
+ ResidualForm = Callable[..., Any]
24
+ ElementBilinearKernel = Callable[..., Any]
25
+ ElementLinearKernel = Callable[..., Any]
26
+ ElementResidualKernel = Callable[..., Any]
27
+ ElementJacobianKernel = Callable[..., Any]
28
+
29
+ _WARNED_UNTAGGED_KERNELS: set[int] = set()
30
+
31
+
32
+ def _warn_untagged_kernel(form) -> None:
33
+ key = id(form)
34
+ if key in _WARNED_UNTAGGED_KERNELS:
35
+ return
36
+ _WARNED_UNTAGGED_KERNELS.add(key)
37
+ warnings.warn(
38
+ "Raw kernel has no _ff_kind metadata; prefer tagging with ff.kernel(...) or "
39
+ "set _ff_kind/_ff_domain on the callable.",
40
+ category=UserWarning,
41
+ stacklevel=3,
42
+ )
8
43
 
44
+ from .dtypes import INDEX_DTYPE
9
45
  from ..mesh import (
10
46
  BaseMesh,
11
47
  BaseMeshPytree,
@@ -79,12 +115,14 @@ class FESpaceClosure:
79
115
  """
80
116
  mesh: BaseMesh
81
117
  basis: Basis3D
82
- elem_dofs: jnp.ndarray # (n_elems, n_ldofs) int32
118
+ elem_dofs: jnp.ndarray # (n_elems, n_ldofs) int64
83
119
  value_dim: int = 1 # 1=scalar, 3=vector, etc.
84
120
  _n_dofs: int | None = None
85
121
  _n_ldofs: int | None = None
86
122
  data: SpaceData | None = None
87
123
  _pattern_cache: dict[bool, object] = field(default_factory=dict, repr=False)
124
+ _kernel_cache: dict[tuple, object] = field(default_factory=dict, repr=False)
125
+ _elem_rows_cache: jnp.ndarray | None = field(default=None, repr=False)
88
126
 
89
127
  def __post_init__(self):
90
128
  # Ensure value_dim is a Python int (avoid tracers).
@@ -116,6 +154,15 @@ class FESpaceClosure:
116
154
  assert self._n_ldofs is not None
117
155
  return self._n_ldofs
118
156
 
157
+ def get_elem_rows(self) -> jnp.ndarray:
158
+ cached = self._elem_rows_cache
159
+ if cached is not None:
160
+ return cached
161
+ rows = self.elem_dofs.reshape(-1)
162
+ if jax.core.trace_ctx.is_top_level():
163
+ self._elem_rows_cache = rows
164
+ return rows
165
+
119
166
  def build_form_contexts(self, dep: jnp.ndarray | None = None) -> FormContext:
120
167
  def _tie_in(x, y):
121
168
  if x is None:
@@ -148,43 +195,240 @@ class FESpaceClosure:
148
195
  )
149
196
 
150
197
  test = jax.vmap(make_field)(elem_coords)
151
- trial = jax.vmap(make_field)(elem_coords)
198
+ # Test/trial share the same field data for single-space bilinear forms.
199
+ trial = test
152
200
 
153
201
  return FormContext(
154
202
  test=test, trial=trial, x_q=x_q,
155
203
  w=w, elem_id=jnp.arange(elem_coords.shape[0])
156
204
  )
157
205
 
206
+ def make_batched_assembler(self, *, dep: jnp.ndarray | None = None, pattern=None):
207
+ from .assembly import BatchedAssembler
208
+ if pattern is None:
209
+ pattern = self.get_sparsity_pattern(with_idx=True)
210
+ return BatchedAssembler.from_space(self, dep=dep, pattern=pattern)
211
+
212
+ def build_cg_operator(
213
+ self,
214
+ A,
215
+ *,
216
+ matvec: str = "flux",
217
+ preconditioner=None,
218
+ solver: str = "cg",
219
+ dof_per_node: int | None = None,
220
+ block_sizes=None,
221
+ ):
222
+ from ..solver.cg import build_cg_operator
223
+
224
+ if dof_per_node is None:
225
+ dof_per_node = int(self.value_dim)
226
+ return build_cg_operator(
227
+ A,
228
+ matvec=matvec,
229
+ preconditioner=preconditioner,
230
+ solver=solver,
231
+ dof_per_node=dof_per_node,
232
+ block_sizes=block_sizes,
233
+ )
234
+
235
+ def _kernel_cache_key(self, kind: str, form, params, *, jit: bool):
236
+ if not jit:
237
+ return None
238
+ form_key = getattr(form, "__wrapped__", form)
239
+ try:
240
+ params_key = hash(params)
241
+ except Exception:
242
+ return None
243
+ return (kind, id(form_key), params_key, True)
244
+
245
+ def _get_cached_kernel(self, kind: str, form, params, *, jit: bool, maker):
246
+ key = self._kernel_cache_key(kind, form, params, jit=jit)
247
+ if key is None:
248
+ return maker(form, params, jit=jit)
249
+ cached = self._kernel_cache.get(key)
250
+ if cached is not None:
251
+ return cached
252
+ kernel = maker(form, params, jit=jit)
253
+ if jax.core.trace_ctx.is_top_level():
254
+ self._kernel_cache[key] = kernel
255
+ return kernel
256
+
257
+ def assemble(
258
+ self,
259
+ form,
260
+ params=None,
261
+ *,
262
+ kind: str | None = None,
263
+ n_chunks: int | None = None,
264
+ dep: jnp.ndarray | None = None,
265
+ jit: bool = True,
266
+ pattern: str | object | None = "auto",
267
+ kernel=None,
268
+ **kwargs,
269
+ ):
270
+ """
271
+ High-level assembly entry point with optional kernel caching.
272
+
273
+ kind: "bilinear" or "linear". If None, inferred from LinearForm/BilinearForm
274
+ or compiled/kernels tagged with _ff_kind metadata.
275
+ pattern: "auto" to reuse cached sparsity pattern for bilinear assembly.
276
+ """
277
+ from .weakform import BilinearForm, LinearForm
278
+ from .assembly import make_element_bilinear_kernel, make_element_linear_kernel
279
+
280
+ if kind is None:
281
+ if isinstance(form, BilinearForm):
282
+ kind = "bilinear"
283
+ form = form.get_compiled()
284
+ elif isinstance(form, LinearForm):
285
+ kind = "linear"
286
+ form = form.get_compiled()
287
+ else:
288
+ inferred_kind = getattr(form, "_ff_kind", None)
289
+ inferred_domain = getattr(form, "_ff_domain", None)
290
+ if inferred_kind is None:
291
+ raise ValueError(
292
+ f"kind is required for raw kernels without metadata (got {form!r}). "
293
+ "Use @ff.kernel(kind=..., domain=...) or pass kind=."
294
+ )
295
+ if inferred_domain not in (None, "volume"):
296
+ raise ValueError(
297
+ f"Unsupported form domain '{inferred_domain}' for Space.assemble. "
298
+ "Use assemble_surface_linear_form or assemble_surface_bilinear_form."
299
+ )
300
+ kind = inferred_kind
301
+ else:
302
+ inferred_kind = getattr(form, "_ff_kind", None)
303
+ inferred_domain = getattr(form, "_ff_domain", None)
304
+ if inferred_kind is None:
305
+ _warn_untagged_kernel(form)
306
+ if inferred_kind is not None and inferred_kind != kind:
307
+ raise ValueError(
308
+ f"assemble kind '{kind}' does not match form kind '{inferred_kind}' "
309
+ f"for {form!r}. "
310
+ "Align kind= with the kernel metadata (or retag the kernel)."
311
+ )
312
+ if inferred_domain not in (None, "volume"):
313
+ raise ValueError(
314
+ f"Unsupported form domain '{inferred_domain}' for Space.assemble. "
315
+ "Use assemble_surface_linear_form or assemble_surface_bilinear_form."
316
+ )
317
+
318
+ if kind == "bilinear":
319
+ if kernel is None:
320
+ kernel = self._get_cached_kernel(
321
+ "bilinear",
322
+ form,
323
+ params,
324
+ jit=jit,
325
+ maker=make_element_bilinear_kernel,
326
+ )
327
+ pattern_use = None
328
+ if pattern == "auto":
329
+ pattern_use = self.get_sparsity_pattern(with_idx=True)
330
+ else:
331
+ pattern_use = pattern
332
+ return self.assemble_bilinear_form(
333
+ form,
334
+ params,
335
+ n_chunks=n_chunks,
336
+ dep=dep,
337
+ kernel=kernel,
338
+ pattern=pattern_use,
339
+ **kwargs,
340
+ )
341
+ if kind == "linear":
342
+ if kernel is None:
343
+ kernel = self._get_cached_kernel(
344
+ "linear",
345
+ form,
346
+ params,
347
+ jit=jit,
348
+ maker=make_element_linear_kernel,
349
+ )
350
+ return self.assemble_linear_form(
351
+ form,
352
+ params,
353
+ n_chunks=n_chunks,
354
+ dep=dep,
355
+ kernel=kernel,
356
+ **kwargs,
357
+ )
358
+ raise ValueError(f"Unsupported assemble kind: {kind}")
359
+
158
360
  # --- Thin wrappers over functional assembly APIs (kept functional for JAX friendliness) ---
159
- def assemble_bilinear_form(self, form, params, *, chunk_size=None, dep=None, **kwargs):
361
+ def assemble_bilinear_form(
362
+ self,
363
+ form: Kernel[P],
364
+ params: P,
365
+ *,
366
+ n_chunks: int | None = None,
367
+ dep: jnp.ndarray | None = None,
368
+ kernel: ElementBilinearKernel | None = None,
369
+ **kwargs,
370
+ ):
371
+ """Assemble bilinear form; kernel(ctx) -> (n_ldofs, n_ldofs) if provided."""
160
372
  from .assembly import assemble_bilinear_form
161
373
  if "pattern" not in kwargs or kwargs.get("pattern") is None:
162
374
  kwargs["pattern"] = self.get_sparsity_pattern(with_idx=True)
163
- return assemble_bilinear_form(self, form, params, chunk_size=chunk_size, dep=dep, **kwargs)
375
+ return assemble_bilinear_form(
376
+ self, form, params, n_chunks=n_chunks, dep=dep, kernel=kernel, **kwargs
377
+ )
164
378
 
165
- def assemble_linear_form(self, form, params, *, chunk_size=None, dep=None, **kwargs):
379
+ def assemble_linear_form(
380
+ self,
381
+ form: Kernel[P],
382
+ params: P,
383
+ *,
384
+ n_chunks: int | None = None,
385
+ dep: jnp.ndarray | None = None,
386
+ kernel: ElementLinearKernel | None = None,
387
+ **kwargs,
388
+ ):
389
+ """Assemble linear form; kernel(ctx) -> (n_ldofs,) if provided."""
166
390
  from .assembly import assemble_linear_form
167
- return assemble_linear_form(self, form, params, chunk_size=chunk_size, dep=dep, **kwargs)
391
+ return assemble_linear_form(
392
+ self, form, params, n_chunks=n_chunks, dep=dep, kernel=kernel, **kwargs
393
+ )
168
394
 
169
395
  def assemble_functional(self, form, params):
170
396
  from .assembly import assemble_functional
171
397
  return assemble_functional(self, form, params)
172
398
 
173
- def assemble_mass_matrix(self, *, chunk_size=None, **kwargs):
399
+ def assemble_mass_matrix(self, *, n_chunks=None, **kwargs):
174
400
  from .assembly import assemble_mass_matrix
175
- return assemble_mass_matrix(self, chunk_size=chunk_size, **kwargs)
401
+ return assemble_mass_matrix(self, n_chunks=n_chunks, **kwargs)
176
402
 
177
403
  def assemble_bilinear_dense(self, kernel, params, **kwargs):
178
404
  from .assembly import assemble_bilinear_dense
179
405
  return assemble_bilinear_dense(self, kernel, params, **kwargs)
180
406
 
181
- def assemble_residual(self, res_form, u, params, **kwargs):
407
+ def assemble_residual(
408
+ self,
409
+ res_form: ResidualForm[P],
410
+ u: jnp.ndarray,
411
+ params: P,
412
+ *,
413
+ kernel: ElementResidualKernel | None = None,
414
+ **kwargs,
415
+ ):
416
+ """Assemble residual; kernel(ctx, u_elem) -> (n_ldofs,) if provided."""
182
417
  from .assembly import assemble_residual
183
- return assemble_residual(self, res_form, u, params, **kwargs)
184
-
185
- def assemble_jacobian(self, res_form, u, params, **kwargs):
418
+ return assemble_residual(self, res_form, u, params, kernel=kernel, **kwargs)
419
+
420
+ def assemble_jacobian(
421
+ self,
422
+ res_form: ResidualForm[P],
423
+ u: jnp.ndarray,
424
+ params: P,
425
+ *,
426
+ kernel: ElementJacobianKernel | None = None,
427
+ **kwargs,
428
+ ):
429
+ """Assemble Jacobian; kernel(u_elem, ctx) -> (n_ldofs, n_ldofs) if provided."""
186
430
  from .assembly import assemble_jacobian
187
- return assemble_jacobian(self, res_form, u, params, **kwargs)
431
+ return assemble_jacobian(self, res_form, u, params, kernel=kernel, **kwargs)
188
432
 
189
433
  def get_sparsity_pattern(self, *, with_idx: bool = True):
190
434
  cached = self._pattern_cache.get(with_idx)
@@ -192,7 +436,8 @@ class FESpaceClosure:
192
436
  return cached
193
437
  from .assembly import make_sparsity_pattern
194
438
  pat = make_sparsity_pattern(self, with_idx=with_idx)
195
- self._pattern_cache[with_idx] = pat
439
+ if jax.core.trace_ctx.is_top_level():
440
+ self._pattern_cache[with_idx] = pat
196
441
  return pat
197
442
 
198
443
 
@@ -251,7 +496,7 @@ def make_space(
251
496
  return FESpace(
252
497
  mesh=mesh,
253
498
  basis=basis,
254
- elem_dofs=jnp.asarray(elem_dofs, dtype=jnp.int32),
499
+ elem_dofs=jnp.asarray(elem_dofs, dtype=INDEX_DTYPE),
255
500
  value_dim=value_dim
256
501
  )
257
502
 
@@ -320,7 +565,7 @@ def make_space_pytree(
320
565
  return FESpacePytree(
321
566
  mesh=mesh_py,
322
567
  basis=basis_py,
323
- elem_dofs=jnp.asarray(elem_dofs, dtype=jnp.int32),
568
+ elem_dofs=jnp.asarray(elem_dofs, dtype=INDEX_DTYPE),
324
569
  value_dim=value_dim,
325
570
  )
326
571