fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (53):
  1. fluxfem/__init__.py +69 -13
  2. fluxfem/core/__init__.py +140 -53
  3. fluxfem/core/assembly.py +691 -97
  4. fluxfem/core/basis.py +75 -54
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/dtypes.py +9 -1
  7. fluxfem/core/forms.py +10 -0
  8. fluxfem/core/mixed_assembly.py +263 -0
  9. fluxfem/core/mixed_space.py +382 -0
  10. fluxfem/core/mixed_weakform.py +97 -0
  11. fluxfem/core/solver.py +2 -0
  12. fluxfem/core/space.py +315 -30
  13. fluxfem/core/weakform.py +821 -42
  14. fluxfem/helpers_wf.py +49 -0
  15. fluxfem/mesh/__init__.py +54 -2
  16. fluxfem/mesh/base.py +318 -9
  17. fluxfem/mesh/contact.py +841 -0
  18. fluxfem/mesh/dtypes.py +12 -0
  19. fluxfem/mesh/hex.py +17 -16
  20. fluxfem/mesh/io.py +9 -6
  21. fluxfem/mesh/mortar.py +3970 -0
  22. fluxfem/mesh/supermesh.py +318 -0
  23. fluxfem/mesh/surface.py +104 -26
  24. fluxfem/mesh/tet.py +16 -7
  25. fluxfem/physics/diffusion.py +3 -0
  26. fluxfem/physics/elasticity/hyperelastic.py +35 -3
  27. fluxfem/physics/elasticity/linear.py +22 -4
  28. fluxfem/physics/elasticity/stress.py +9 -5
  29. fluxfem/physics/operators.py +12 -5
  30. fluxfem/physics/postprocess.py +29 -3
  31. fluxfem/solver/__init__.py +47 -2
  32. fluxfem/solver/bc.py +38 -2
  33. fluxfem/solver/block_matrix.py +284 -0
  34. fluxfem/solver/block_system.py +477 -0
  35. fluxfem/solver/cg.py +150 -55
  36. fluxfem/solver/dirichlet.py +358 -5
  37. fluxfem/solver/history.py +15 -3
  38. fluxfem/solver/newton.py +260 -70
  39. fluxfem/solver/petsc.py +445 -0
  40. fluxfem/solver/preconditioner.py +109 -0
  41. fluxfem/solver/result.py +18 -0
  42. fluxfem/solver/solve_runner.py +208 -23
  43. fluxfem/solver/solver.py +35 -12
  44. fluxfem/solver/sparse.py +149 -15
  45. fluxfem/tools/jit.py +19 -7
  46. fluxfem/tools/timer.py +14 -12
  47. fluxfem/tools/visualizer.py +16 -4
  48. fluxfem-0.2.1.dist-info/METADATA +314 -0
  49. fluxfem-0.2.1.dist-info/RECORD +59 -0
  50. fluxfem-0.1.4.dist-info/METADATA +0 -127
  51. fluxfem-0.1.4.dist-info/RECORD +0 -48
  52. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  53. {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/core/space.py CHANGED
@@ -1,11 +1,71 @@
1
1
  from __future__ import annotations
2
2
  import operator
3
3
  from dataclasses import dataclass, field
4
- from typing import Protocol
4
+ from typing import Any, Callable, Protocol, TYPE_CHECKING, TypeAlias, TypeVar, cast
5
5
  import jax
6
6
  import jax.numpy as jnp
7
7
  import numpy as np
8
+ import warnings
9
+
10
+ P = TypeVar("P")
11
+
12
+ if TYPE_CHECKING:
13
+ from .assembly import (
14
+ ElementBilinearKernel,
15
+ ElementJacobianKernel,
16
+ ElementLinearKernel,
17
+ ElementResidualKernel,
18
+ Kernel,
19
+ FormKernel,
20
+ BilinearReturn,
21
+ JacobianReturn,
22
+ LinearReturn,
23
+ MassReturn,
24
+ ResidualForm,
25
+ )
26
+ from .weakform import BilinearForm, LinearForm
27
+ from ..solver import FluxSparseMatrix, SparsityPattern
28
+ else:
29
+ Kernel = Callable[..., Any]
30
+ FormKernel = Callable[..., Any]
31
+ ResidualForm = Callable[..., Any]
32
+ ElementBilinearKernel = Callable[..., Any]
33
+ ElementLinearKernel = Callable[..., Any]
34
+ ElementResidualKernel = Callable[..., Any]
35
+ ElementJacobianKernel = Callable[..., Any]
36
+ BilinearReturn = Any
37
+ JacobianReturn = Any
38
+ LinearReturn = Any
39
+ MassReturn = Any
40
+ BilinearForm = Any
41
+ LinearForm = Any
42
+ FluxSparseMatrix = Any
43
+ SparsityPattern = Any
44
+
45
+ KernelCacheKey: TypeAlias = tuple[str, int, int, bool]
46
+ KernelCache: TypeAlias = dict[
47
+ KernelCacheKey,
48
+ ElementBilinearKernel | ElementLinearKernel | ElementResidualKernel | ElementJacobianKernel,
49
+ ]
50
+ PatternCache: TypeAlias = dict[bool, SparsityPattern]
51
+ PatternLike: TypeAlias = str | SparsityPattern | None
52
+
53
# ids of raw kernels that have already been reported, so each is warned about once.
_WARNED_UNTAGGED_KERNELS: set[int] = set()


def _warn_untagged_kernel(form) -> None:
    """Warn (once per ``id(form)``) that a raw kernel lacks ``_ff_kind`` metadata.

    Deduplication is by object id only. NOTE(review): an id can be reused
    after the original object is garbage-collected, which would suppress the
    warning for a different callable.
    """
    form_id = id(form)
    if form_id not in _WARNED_UNTAGGED_KERNELS:
        _WARNED_UNTAGGED_KERNELS.add(form_id)
        # stacklevel=3 attributes the warning to the caller of the function
        # that invoked this helper, rather than to this module.
        warnings.warn(
            "Raw kernel has no _ff_kind metadata; prefer tagging with ff.kernel(...) or "
            "set _ff_kind/_ff_domain on the callable.",
            category=UserWarning,
            stacklevel=3,
        )
8
67
 
68
+ from .dtypes import INDEX_DTYPE
9
69
  from ..mesh import (
10
70
  BaseMesh,
11
71
  BaseMeshPytree,
@@ -57,8 +117,12 @@ class FESpaceBase(Protocol):
57
117
  """
58
118
  elem_dofs: jnp.ndarray
59
119
  value_dim: int
60
- n_dofs: int
61
- n_ldofs: int
120
+
121
+ @property
122
+ def n_dofs(self) -> int: ...
123
+
124
+ @property
125
+ def n_ldofs(self) -> int: ...
62
126
 
63
127
  def build_form_contexts(self, dep: jnp.ndarray | None = None) -> FormContext: ...
64
128
 
@@ -79,12 +143,14 @@ class FESpaceClosure:
79
143
  """
80
144
  mesh: BaseMesh
81
145
  basis: Basis3D
82
- elem_dofs: jnp.ndarray # (n_elems, n_ldofs) int32
146
+ elem_dofs: jnp.ndarray # (n_elems, n_ldofs) int64
83
147
  value_dim: int = 1 # 1=scalar, 3=vector, etc.
84
148
  _n_dofs: int | None = None
85
149
  _n_ldofs: int | None = None
86
150
  data: SpaceData | None = None
87
- _pattern_cache: dict[bool, object] = field(default_factory=dict, repr=False)
151
+ _pattern_cache: PatternCache = field(default_factory=dict, repr=False)
152
+ _kernel_cache: KernelCache = field(default_factory=dict, repr=False)
153
+ _elem_rows_cache: jnp.ndarray | None = field(default=None, repr=False)
88
154
 
89
155
  def __post_init__(self):
90
156
  # Ensure value_dim is a Python int (avoid tracers).
@@ -116,6 +182,15 @@ class FESpaceClosure:
116
182
  assert self._n_ldofs is not None
117
183
  return self._n_ldofs
118
184
 
185
def get_elem_rows(self) -> jnp.ndarray:
    """Return ``elem_dofs`` flattened to 1-D, memoized on the space.

    The flattened array is stored in ``_elem_rows_cache`` only when called
    outside any JAX trace (``jax.core.trace_ctx.is_top_level()``), so arrays
    produced while tracing are not kept on the space.
    """
    if self._elem_rows_cache is None:
        flat = self.elem_dofs.reshape(-1)
        if not jax.core.trace_ctx.is_top_level():
            return flat
        self._elem_rows_cache = flat
    return self._elem_rows_cache
193
+
119
194
  def build_form_contexts(self, dep: jnp.ndarray | None = None) -> FormContext:
120
195
  def _tie_in(x, y):
121
196
  if x is None:
@@ -148,43 +223,252 @@ class FESpaceClosure:
148
223
  )
149
224
 
150
225
  test = jax.vmap(make_field)(elem_coords)
151
- trial = jax.vmap(make_field)(elem_coords)
226
+ # Test/trial share the same field data for single-space bilinear forms.
227
+ trial = test
152
228
 
153
229
  return FormContext(
154
230
  test=test, trial=trial, x_q=x_q,
155
231
  w=w, elem_id=jnp.arange(elem_coords.shape[0])
156
232
  )
157
233
 
234
def make_batched_assembler(self, *, dep: jnp.ndarray | None = None, pattern=None):
    """Build a ``BatchedAssembler`` bound to this space.

    When ``pattern`` is None, the space's sparsity pattern with index data
    (``get_sparsity_pattern(with_idx=True)``) is used.
    """
    from .assembly import BatchedAssembler

    pattern_use = self.get_sparsity_pattern(with_idx=True) if pattern is None else pattern
    return BatchedAssembler.from_space(self, dep=dep, pattern=pattern_use)
239
+
240
def build_cg_operator(
    self,
    A,
    *,
    matvec: str = "flux",
    preconditioner=None,
    solver: str = "cg",
    dof_per_node: int | None = None,
    block_sizes=None,
):
    """Delegate to ``fluxfem.solver.cg.build_cg_operator`` for the matrix ``A``.

    ``dof_per_node`` defaults to ``int(self.value_dim)``; all other options
    are forwarded unchanged.
    """
    from ..solver.cg import build_cg_operator as _build

    options = dict(
        matvec=matvec,
        preconditioner=preconditioner,
        solver=solver,
        dof_per_node=int(self.value_dim) if dof_per_node is None else dof_per_node,
        block_sizes=block_sizes,
    )
    return _build(A, **options)
262
+
263
def _kernel_cache_key(self, kind: str, form, params, *, jit: bool) -> KernelCacheKey | None:
    """Return the kernel-cache key for ``(kind, form, params)``, or None to skip caching.

    Only jitted kernels are cached. A form is identified by the ``id`` of its
    ``__wrapped__`` target when present, otherwise by its own ``id``.

    ``params`` itself is stored in the key rather than ``hash(params)``: distinct
    params whose hashes collide (e.g. ``-1`` and ``-2`` in CPython) must not share
    a cached kernel, and a dict lookup on the real object falls back to ``==``
    after a hash match. Unhashable params (e.g. NumPy/JAX arrays) disable caching.

    Returns:
        ``(kind, id(form_key), params, True)`` or None.
    """
    if not jit:
        return None
    form_key = getattr(form, "__wrapped__", form)
    try:
        hash(params)
    except Exception:
        # Params that cannot be hashed cannot be dict keys; build uncached.
        return None
    return (kind, id(form_key), params, True)
272
+
273
def _get_cached_kernel(self, kind: str, form, params, *, jit: bool, maker):
    """Return ``maker(form, params, jit=jit)``, reusing this space's kernel cache.

    The cache is bypassed entirely when ``_kernel_cache_key`` returns None.
    On a miss, the freshly built kernel is stored only when called outside a
    JAX trace (``jax.core.trace_ctx.is_top_level()``).
    """
    key = self._kernel_cache_key(kind, form, params, jit=jit)
    if key is not None:
        hit = self._kernel_cache.get(key)
        if hit is not None:
            return hit
    built = maker(form, params, jit=jit)
    if key is not None and jax.core.trace_ctx.is_top_level():
        self._kernel_cache[key] = built
    return built
284
+
285
def assemble(
    self,
    form: FormKernel[P] | BilinearForm | LinearForm,
    params: P | None = None,
    *,
    kind: str | None = None,
    n_chunks: int | None = None,
    dep: jnp.ndarray | None = None,
    jit: bool = True,
    pattern: PatternLike = "auto",
    kernel: ElementBilinearKernel | ElementLinearKernel | None = None,
    **kwargs,
):
    """
    High-level assembly entry point with optional kernel caching.

    form: a ``BilinearForm`` / ``LinearForm`` object (always unwrapped with
        ``get_compiled()``) or a raw kernel callable.
    params: passed to the element-kernel maker and to the assembly routine.
    kind: "bilinear" or "linear". If None, inferred from LinearForm/BilinearForm
        or from kernels tagged with ``_ff_kind`` metadata. When given, it must
        agree with the form object's type or with the kernel's ``_ff_kind`` tag;
        an untagged raw kernel only triggers a one-time warning.
    n_chunks, dep, **kwargs: forwarded to ``assemble_bilinear_form`` /
        ``assemble_linear_form``.
    jit: passed to the kernel maker; only ``jit=True`` kernels are cached.
    pattern: "auto" to reuse the cached sparsity pattern for bilinear assembly;
        any other value is passed through. Not used for linear assembly.
    kernel: a prebuilt element kernel; skips kernel construction and caching.

    Raises:
        ValueError: if the kind is missing or contradicts the form, the form's
            ``_ff_domain`` is not a volume domain, or the kind is unsupported.
    """
    from .weakform import BilinearForm, LinearForm
    from .assembly import make_element_bilinear_kernel, make_element_linear_kernel

    def _require_volume_domain(domain) -> None:
        # Surface-domain forms have dedicated entry points.
        if domain not in (None, "volume"):
            raise ValueError(
                f"Unsupported form domain '{domain}' for Space.assemble. "
                "Use assemble_surface_linear_form or assemble_surface_bilinear_form."
            )

    if isinstance(form, (BilinearForm, LinearForm)):
        # Weak-form objects carry their own kind. Unwrap them whether or not
        # kind= was given, so both call styles assemble the same compiled form.
        form_kind = "bilinear" if isinstance(form, BilinearForm) else "linear"
        if kind is not None and kind != form_kind:
            raise ValueError(
                f"assemble kind '{kind}' does not match {type(form).__name__} "
                f"(kind '{form_kind}') for {form!r}."
            )
        kind = form_kind
        form = form.get_compiled()
    else:
        tagged_kind = getattr(form, "_ff_kind", None)
        tagged_domain = getattr(form, "_ff_domain", None)
        if kind is None:
            if tagged_kind is None:
                raise ValueError(
                    f"kind is required for raw kernels without metadata (got {form!r}). "
                    "Use @ff.kernel(kind=..., domain=...) or pass kind=."
                )
            kind = tagged_kind
        elif tagged_kind is None:
            _warn_untagged_kernel(form)
        elif tagged_kind != kind:
            raise ValueError(
                f"assemble kind '{kind}' does not match form kind '{tagged_kind}' "
                f"for {form!r}. "
                "Align kind= with the kernel metadata (or retag the kernel)."
            )
        _require_volume_domain(tagged_domain)

    if kind == "bilinear":
        if kernel is None:
            kernel = self._get_cached_kernel(
                "bilinear", form, params, jit=jit, maker=make_element_bilinear_kernel
            )
        # isinstance guard keeps the comparison safe for arbitrary pattern objects.
        if isinstance(pattern, str) and pattern == "auto":
            pattern_use = self.get_sparsity_pattern(with_idx=True)
        else:
            pattern_use = pattern
        return self.assemble_bilinear_form(
            cast(FormKernel, form),
            params,
            n_chunks=n_chunks,
            dep=dep,
            kernel=kernel,
            pattern=pattern_use,
            **kwargs,
        )
    if kind == "linear":
        if kernel is None:
            kernel = self._get_cached_kernel(
                "linear", form, params, jit=jit, maker=make_element_linear_kernel
            )
        return self.assemble_linear_form(
            cast(FormKernel, form),
            params,
            n_chunks=n_chunks,
            dep=dep,
            kernel=kernel,
            **kwargs,
        )
    raise ValueError(f"Unsupported assemble kind: {kind}")
389
+
158
390
  # --- Thin wrappers over functional assembly APIs (kept functional for JAX friendliness) ---
159
def assemble_bilinear_form(
    self,
    form: FormKernel[P],
    params: P,
    *,
    n_chunks: int | None = None,
    dep: jnp.ndarray | None = None,
    kernel: ElementBilinearKernel | None = None,
    **kwargs,
) -> FluxSparseMatrix:
    """Assemble a bilinear form via ``assembly.assemble_bilinear_form``.

    A supplied ``kernel`` is used as kernel(ctx) -> (n_ldofs, n_ldofs). A missing
    or None ``pattern`` keyword is replaced by this space's sparsity pattern
    with index data.
    """
    from .assembly import assemble_bilinear_form as _assemble

    if kwargs.get("pattern") is None:
        kwargs["pattern"] = self.get_sparsity_pattern(with_idx=True)
    return _assemble(self, form, params, n_chunks=n_chunks, dep=dep, kernel=kernel, **kwargs)
164
408
 
165
def assemble_linear_form(
    self,
    form: FormKernel[P],
    params: P,
    *,
    n_chunks: int | None = None,
    dep: jnp.ndarray | None = None,
    kernel: ElementLinearKernel | None = None,
    **kwargs,
) -> LinearReturn:
    """Assemble a linear form via ``assembly.assemble_linear_form``.

    A supplied ``kernel`` is used as kernel(ctx) -> (n_ldofs,).
    """
    from .assembly import assemble_linear_form as _assemble

    forwarded = dict(n_chunks=n_chunks, dep=dep, kernel=kernel, **kwargs)
    return _assemble(self, form, params, **forwarded)
168
424
 
169
def assemble_functional(self, form: FormKernel[P], params: P) -> jnp.ndarray:
    """Evaluate ``form`` over this space via ``assembly.assemble_functional``."""
    from .assembly import assemble_functional as _functional

    result = _functional(self, form, params)
    return result
172
428
 
173
def assemble_mass_matrix(
    self,
    *,
    n_chunks: int | None = None,
    **kwargs,
) -> MassReturn:
    """Assemble the mass matrix for this space via ``assembly.assemble_mass_matrix``."""
    from .assembly import assemble_mass_matrix as _mass

    result = _mass(self, n_chunks=n_chunks, **kwargs)
    return result
437
+
438
def assemble_bilinear_dense(
    self,
    kernel: FormKernel[P],
    params: P,
    **kwargs,
) -> BilinearReturn:
    """Forward to ``assembly.assemble_bilinear_dense`` for this space.

    NOTE(review): the name suggests a dense (non-sparse) result; confirm in
    ``assembly``.
    """
    from .assembly import assemble_bilinear_dense as _dense

    result = _dense(self, kernel, params, **kwargs)
    return result
180
446
 
181
def assemble_residual(
    self,
    res_form: ResidualForm[P],
    u: jnp.ndarray,
    params: P,
    *,
    kernel: ElementResidualKernel | None = None,
    **kwargs,
) -> LinearReturn:
    """Assemble the residual of ``res_form`` at ``u`` via ``assembly.assemble_residual``.

    A supplied ``kernel`` is used as kernel(ctx, u_elem) -> (n_ldofs,).
    """
    from .assembly import assemble_residual as _residual

    kwargs["kernel"] = kernel
    return _residual(self, res_form, u, params, **kwargs)
459
+
460
def assemble_jacobian(
    self,
    res_form: ResidualForm[P],
    u: jnp.ndarray,
    params: P,
    *,
    kernel: ElementJacobianKernel | None = None,
    **kwargs,
) -> JacobianReturn:
    """Assemble the Jacobian of ``res_form`` at ``u`` via ``assembly.assemble_jacobian``.

    A supplied ``kernel`` is used as kernel(u_elem, ctx) -> (n_ldofs, n_ldofs).
    """
    from .assembly import assemble_jacobian as _jacobian

    kwargs["kernel"] = kernel
    return _jacobian(self, res_form, u, params, **kwargs)
188
472
 
189
473
  def get_sparsity_pattern(self, *, with_idx: bool = True):
190
474
  cached = self._pattern_cache.get(with_idx)
@@ -192,7 +476,8 @@ class FESpaceClosure:
192
476
  return cached
193
477
  from .assembly import make_sparsity_pattern
194
478
  pat = make_sparsity_pattern(self, with_idx=with_idx)
195
- self._pattern_cache[with_idx] = pat
479
+ if jax.core.trace_ctx.is_top_level():
480
+ self._pattern_cache[with_idx] = pat
196
481
  return pat
197
482
 
198
483
 
@@ -251,7 +536,7 @@ def make_space(
251
536
  return FESpace(
252
537
  mesh=mesh,
253
538
  basis=basis,
254
- elem_dofs=jnp.asarray(elem_dofs, dtype=jnp.int32),
539
+ elem_dofs=jnp.asarray(elem_dofs, dtype=INDEX_DTYPE),
255
540
  value_dim=value_dim
256
541
  )
257
542
 
@@ -320,7 +605,7 @@ def make_space_pytree(
320
605
  return FESpacePytree(
321
606
  mesh=mesh_py,
322
607
  basis=basis_py,
323
- elem_dofs=jnp.asarray(elem_dofs, dtype=jnp.int32),
608
+ elem_dofs=jnp.asarray(elem_dofs, dtype=INDEX_DTYPE),
324
609
  value_dim=value_dim,
325
610
  )
326
611
 
@@ -340,7 +625,7 @@ def make_tet10_space_pytree(
340
625
  """Create a pytree quadratic tet space (10-node elements)."""
341
626
  basis = make_tet10_basis_pytree(intorder)
342
627
  element = None if dim == 1 else ElementVector(dim)
343
- return make_space_pytree(mesh, basis, element)
628
+ return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)
344
629
 
345
630
 
346
631
  def make_hex_space(mesh: HexMesh, dim: int = 1, intorder: int = 2) -> FESpace:
@@ -356,7 +641,7 @@ def make_hex_space_pytree(
356
641
  """Create a pytree trilinear hex space (8-node elements)."""
357
642
  basis = make_hex_basis_pytree(intorder)
358
643
  element = None if dim == 1 else ElementVector(dim)
359
- return make_space_pytree(mesh, basis, element)
644
+ return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)
360
645
 
361
646
 
362
647
  def make_hex20_space(
@@ -374,7 +659,7 @@ def make_hex20_space_pytree(
374
659
  """Create a pytree serendipity hex space (20-node elements)."""
375
660
  basis = make_hex20_basis_pytree(intorder)
376
661
  element = None if dim == 1 else ElementVector(dim)
377
- return make_space_pytree(mesh, basis, element)
662
+ return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)
378
663
 
379
664
 
380
665
  def make_hex27_space(
@@ -392,14 +677,14 @@ def make_hex27_space_pytree(
392
677
  """Create a pytree triquadratic hex space (27-node elements)."""
393
678
  basis = make_hex27_basis_pytree(intorder)
394
679
  element = None if dim == 1 else ElementVector(dim)
395
- return make_space_pytree(mesh, basis, element)
680
+ return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)
396
681
 
397
682
 
398
683
  def make_tet_space(mesh: TetMesh, dim: int = 1, intorder: int = 2) -> FESpace:
399
684
  """Create a linear or quadratic tet space based on mesh nodes."""
400
685
  n_nodes = mesh.conn.shape[1]
401
686
  if n_nodes == 10:
402
- basis = make_tet10_basis(intorder if intorder > 1 else 2)
687
+ basis: Basis3D = make_tet10_basis(intorder if intorder > 1 else 2)
403
688
  else:
404
689
  basis = make_tet_basis(intorder)
405
690
  element = None if dim == 1 else ElementVector(dim)
@@ -412,8 +697,8 @@ def make_tet_space_pytree(
412
697
  """Create a pytree linear or quadratic tet space based on mesh nodes."""
413
698
  n_nodes = mesh.conn.shape[1]
414
699
  if n_nodes == 10:
415
- basis = make_tet10_basis_pytree(intorder if intorder > 1 else 2)
700
+ basis: Basis3D = make_tet10_basis_pytree(intorder if intorder > 1 else 2)
416
701
  else:
417
702
  basis = make_tet_basis_pytree(intorder)
418
703
  element = None if dim == 1 else ElementVector(dim)
419
- return make_space_pytree(mesh, basis, element)
704
+ return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)