fluxfem 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. fluxfem/__init__.py +1 -13
  2. fluxfem/core/__init__.py +53 -71
  3. fluxfem/core/assembly.py +41 -32
  4. fluxfem/core/basis.py +2 -2
  5. fluxfem/core/context_types.py +36 -12
  6. fluxfem/core/mixed_space.py +42 -8
  7. fluxfem/core/mixed_weakform.py +1 -1
  8. fluxfem/core/space.py +68 -28
  9. fluxfem/core/weakform.py +95 -77
  10. fluxfem/mesh/base.py +3 -3
  11. fluxfem/mesh/contact.py +33 -17
  12. fluxfem/mesh/io.py +3 -2
  13. fluxfem/mesh/mortar.py +106 -43
  14. fluxfem/mesh/supermesh.py +2 -0
  15. fluxfem/mesh/surface.py +82 -22
  16. fluxfem/mesh/tet.py +7 -4
  17. fluxfem/physics/elasticity/hyperelastic.py +32 -3
  18. fluxfem/physics/elasticity/linear.py +13 -2
  19. fluxfem/physics/elasticity/stress.py +9 -5
  20. fluxfem/physics/operators.py +12 -5
  21. fluxfem/physics/postprocess.py +29 -3
  22. fluxfem/solver/__init__.py +6 -1
  23. fluxfem/solver/block_matrix.py +165 -13
  24. fluxfem/solver/block_system.py +52 -29
  25. fluxfem/solver/cg.py +43 -30
  26. fluxfem/solver/dirichlet.py +35 -12
  27. fluxfem/solver/history.py +15 -3
  28. fluxfem/solver/newton.py +25 -12
  29. fluxfem/solver/petsc.py +13 -7
  30. fluxfem/solver/preconditioner.py +7 -4
  31. fluxfem/solver/solve_runner.py +42 -24
  32. fluxfem/solver/solver.py +23 -11
  33. fluxfem/solver/sparse.py +32 -13
  34. fluxfem/tools/jit.py +19 -7
  35. fluxfem/tools/timer.py +14 -12
  36. fluxfem/tools/visualizer.py +16 -4
  37. {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/METADATA +18 -7
  38. fluxfem-0.2.1.dist-info/RECORD +59 -0
  39. fluxfem-0.2.0.dist-info/RECORD +0 -59
  40. {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
  41. {fluxfem-0.2.0.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/core/space.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
  import operator
3
3
  from dataclasses import dataclass, field
4
- from typing import Any, Callable, Protocol, TYPE_CHECKING, TypeVar
4
+ from typing import Any, Callable, Protocol, TYPE_CHECKING, TypeAlias, TypeVar, cast
5
5
  import jax
6
6
  import jax.numpy as jnp
7
7
  import numpy as np
@@ -16,15 +16,39 @@ if TYPE_CHECKING:
16
16
  ElementLinearKernel,
17
17
  ElementResidualKernel,
18
18
  Kernel,
19
+ FormKernel,
20
+ BilinearReturn,
21
+ JacobianReturn,
22
+ LinearReturn,
23
+ MassReturn,
19
24
  ResidualForm,
20
25
  )
26
+ from .weakform import BilinearForm, LinearForm
27
+ from ..solver import FluxSparseMatrix, SparsityPattern
21
28
  else:
22
29
  Kernel = Callable[..., Any]
30
+ FormKernel = Callable[..., Any]
23
31
  ResidualForm = Callable[..., Any]
24
32
  ElementBilinearKernel = Callable[..., Any]
25
33
  ElementLinearKernel = Callable[..., Any]
26
34
  ElementResidualKernel = Callable[..., Any]
27
35
  ElementJacobianKernel = Callable[..., Any]
36
+ BilinearReturn = Any
37
+ JacobianReturn = Any
38
+ LinearReturn = Any
39
+ MassReturn = Any
40
+ BilinearForm = Any
41
+ LinearForm = Any
42
+ FluxSparseMatrix = Any
43
+ SparsityPattern = Any
44
+
45
+ KernelCacheKey: TypeAlias = tuple[str, int, int, bool]
46
+ KernelCache: TypeAlias = dict[
47
+ KernelCacheKey,
48
+ ElementBilinearKernel | ElementLinearKernel | ElementResidualKernel | ElementJacobianKernel,
49
+ ]
50
+ PatternCache: TypeAlias = dict[bool, SparsityPattern]
51
+ PatternLike: TypeAlias = str | SparsityPattern | None
28
52
 
29
53
  _WARNED_UNTAGGED_KERNELS: set[int] = set()
30
54
 
@@ -93,8 +117,12 @@ class FESpaceBase(Protocol):
93
117
  """
94
118
  elem_dofs: jnp.ndarray
95
119
  value_dim: int
96
- n_dofs: int
97
- n_ldofs: int
120
+
121
+ @property
122
+ def n_dofs(self) -> int: ...
123
+
124
+ @property
125
+ def n_ldofs(self) -> int: ...
98
126
 
99
127
  def build_form_contexts(self, dep: jnp.ndarray | None = None) -> FormContext: ...
100
128
 
@@ -120,8 +148,8 @@ class FESpaceClosure:
120
148
  _n_dofs: int | None = None
121
149
  _n_ldofs: int | None = None
122
150
  data: SpaceData | None = None
123
- _pattern_cache: dict[bool, object] = field(default_factory=dict, repr=False)
124
- _kernel_cache: dict[tuple, object] = field(default_factory=dict, repr=False)
151
+ _pattern_cache: PatternCache = field(default_factory=dict, repr=False)
152
+ _kernel_cache: KernelCache = field(default_factory=dict, repr=False)
125
153
  _elem_rows_cache: jnp.ndarray | None = field(default=None, repr=False)
126
154
 
127
155
  def __post_init__(self):
@@ -232,7 +260,7 @@ class FESpaceClosure:
232
260
  block_sizes=block_sizes,
233
261
  )
234
262
 
235
- def _kernel_cache_key(self, kind: str, form, params, *, jit: bool):
263
+ def _kernel_cache_key(self, kind: str, form, params, *, jit: bool) -> KernelCacheKey | None:
236
264
  if not jit:
237
265
  return None
238
266
  form_key = getattr(form, "__wrapped__", form)
@@ -256,15 +284,15 @@ class FESpaceClosure:
256
284
 
257
285
  def assemble(
258
286
  self,
259
- form,
260
- params=None,
287
+ form: FormKernel[P] | BilinearForm | LinearForm,
288
+ params: P | None = None,
261
289
  *,
262
290
  kind: str | None = None,
263
291
  n_chunks: int | None = None,
264
292
  dep: jnp.ndarray | None = None,
265
293
  jit: bool = True,
266
- pattern: str | object | None = "auto",
267
- kernel=None,
294
+ pattern: PatternLike = "auto",
295
+ kernel: ElementBilinearKernel | ElementLinearKernel | None = None,
268
296
  **kwargs,
269
297
  ):
270
298
  """
@@ -324,13 +352,14 @@ class FESpaceClosure:
324
352
  jit=jit,
325
353
  maker=make_element_bilinear_kernel,
326
354
  )
355
+ form_kernel = cast(FormKernel, form)
327
356
  pattern_use = None
328
357
  if pattern == "auto":
329
358
  pattern_use = self.get_sparsity_pattern(with_idx=True)
330
359
  else:
331
360
  pattern_use = pattern
332
361
  return self.assemble_bilinear_form(
333
- form,
362
+ form_kernel,
334
363
  params,
335
364
  n_chunks=n_chunks,
336
365
  dep=dep,
@@ -347,8 +376,9 @@ class FESpaceClosure:
347
376
  jit=jit,
348
377
  maker=make_element_linear_kernel,
349
378
  )
379
+ form_kernel = cast(FormKernel, form)
350
380
  return self.assemble_linear_form(
351
- form,
381
+ form_kernel,
352
382
  params,
353
383
  n_chunks=n_chunks,
354
384
  dep=dep,
@@ -360,14 +390,14 @@ class FESpaceClosure:
360
390
  # --- Thin wrappers over functional assembly APIs (kept functional for JAX friendliness) ---
361
391
  def assemble_bilinear_form(
362
392
  self,
363
- form: Kernel[P],
393
+ form: FormKernel[P],
364
394
  params: P,
365
395
  *,
366
396
  n_chunks: int | None = None,
367
397
  dep: jnp.ndarray | None = None,
368
398
  kernel: ElementBilinearKernel | None = None,
369
399
  **kwargs,
370
- ):
400
+ ) -> FluxSparseMatrix:
371
401
  """Assemble bilinear form; kernel(ctx) -> (n_ldofs, n_ldofs) if provided."""
372
402
  from .assembly import assemble_bilinear_form
373
403
  if "pattern" not in kwargs or kwargs.get("pattern") is None:
@@ -378,29 +408,39 @@ class FESpaceClosure:
378
408
 
379
409
  def assemble_linear_form(
380
410
  self,
381
- form: Kernel[P],
411
+ form: FormKernel[P],
382
412
  params: P,
383
413
  *,
384
414
  n_chunks: int | None = None,
385
415
  dep: jnp.ndarray | None = None,
386
416
  kernel: ElementLinearKernel | None = None,
387
417
  **kwargs,
388
- ):
418
+ ) -> LinearReturn:
389
419
  """Assemble linear form; kernel(ctx) -> (n_ldofs,) if provided."""
390
420
  from .assembly import assemble_linear_form
391
421
  return assemble_linear_form(
392
422
  self, form, params, n_chunks=n_chunks, dep=dep, kernel=kernel, **kwargs
393
423
  )
394
424
 
395
- def assemble_functional(self, form, params):
425
+ def assemble_functional(self, form: FormKernel[P], params: P) -> jnp.ndarray:
396
426
  from .assembly import assemble_functional
397
427
  return assemble_functional(self, form, params)
398
428
 
399
- def assemble_mass_matrix(self, *, n_chunks=None, **kwargs):
429
+ def assemble_mass_matrix(
430
+ self,
431
+ *,
432
+ n_chunks: int | None = None,
433
+ **kwargs,
434
+ ) -> MassReturn:
400
435
  from .assembly import assemble_mass_matrix
401
436
  return assemble_mass_matrix(self, n_chunks=n_chunks, **kwargs)
402
437
 
403
- def assemble_bilinear_dense(self, kernel, params, **kwargs):
438
+ def assemble_bilinear_dense(
439
+ self,
440
+ kernel: FormKernel[P],
441
+ params: P,
442
+ **kwargs,
443
+ ) -> BilinearReturn:
404
444
  from .assembly import assemble_bilinear_dense
405
445
  return assemble_bilinear_dense(self, kernel, params, **kwargs)
406
446
 
@@ -412,7 +452,7 @@ class FESpaceClosure:
412
452
  *,
413
453
  kernel: ElementResidualKernel | None = None,
414
454
  **kwargs,
415
- ):
455
+ ) -> LinearReturn:
416
456
  """Assemble residual; kernel(ctx, u_elem) -> (n_ldofs,) if provided."""
417
457
  from .assembly import assemble_residual
418
458
  return assemble_residual(self, res_form, u, params, kernel=kernel, **kwargs)
@@ -425,7 +465,7 @@ class FESpaceClosure:
425
465
  *,
426
466
  kernel: ElementJacobianKernel | None = None,
427
467
  **kwargs,
428
- ):
468
+ ) -> JacobianReturn:
429
469
  """Assemble Jacobian; kernel(u_elem, ctx) -> (n_ldofs, n_ldofs) if provided."""
430
470
  from .assembly import assemble_jacobian
431
471
  return assemble_jacobian(self, res_form, u, params, kernel=kernel, **kwargs)
@@ -585,7 +625,7 @@ def make_tet10_space_pytree(
585
625
  """Create a pytree quadratic tet space (10-node elements)."""
586
626
  basis = make_tet10_basis_pytree(intorder)
587
627
  element = None if dim == 1 else ElementVector(dim)
588
- return make_space_pytree(mesh, basis, element)
628
+ return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)
589
629
 
590
630
 
591
631
  def make_hex_space(mesh: HexMesh, dim: int = 1, intorder: int = 2) -> FESpace:
@@ -601,7 +641,7 @@ def make_hex_space_pytree(
601
641
  """Create a pytree trilinear hex space (8-node elements)."""
602
642
  basis = make_hex_basis_pytree(intorder)
603
643
  element = None if dim == 1 else ElementVector(dim)
604
- return make_space_pytree(mesh, basis, element)
644
+ return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)
605
645
 
606
646
 
607
647
  def make_hex20_space(
@@ -619,7 +659,7 @@ def make_hex20_space_pytree(
619
659
  """Create a pytree serendipity hex space (20-node elements)."""
620
660
  basis = make_hex20_basis_pytree(intorder)
621
661
  element = None if dim == 1 else ElementVector(dim)
622
- return make_space_pytree(mesh, basis, element)
662
+ return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)
623
663
 
624
664
 
625
665
  def make_hex27_space(
@@ -637,14 +677,14 @@ def make_hex27_space_pytree(
637
677
  """Create a pytree triquadratic hex space (27-node elements)."""
638
678
  basis = make_hex27_basis_pytree(intorder)
639
679
  element = None if dim == 1 else ElementVector(dim)
640
- return make_space_pytree(mesh, basis, element)
680
+ return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)
641
681
 
642
682
 
643
683
  def make_tet_space(mesh: TetMesh, dim: int = 1, intorder: int = 2) -> FESpace:
644
684
  """Create a linear or quadratic tet space based on mesh nodes."""
645
685
  n_nodes = mesh.conn.shape[1]
646
686
  if n_nodes == 10:
647
- basis = make_tet10_basis(intorder if intorder > 1 else 2)
687
+ basis: Basis3D = make_tet10_basis(intorder if intorder > 1 else 2)
648
688
  else:
649
689
  basis = make_tet_basis(intorder)
650
690
  element = None if dim == 1 else ElementVector(dim)
@@ -657,8 +697,8 @@ def make_tet_space_pytree(
657
697
  """Create a pytree linear or quadratic tet space based on mesh nodes."""
658
698
  n_nodes = mesh.conn.shape[1]
659
699
  if n_nodes == 10:
660
- basis = make_tet10_basis_pytree(intorder if intorder > 1 else 2)
700
+ basis: Basis3D = make_tet10_basis_pytree(intorder if intorder > 1 else 2)
661
701
  else:
662
702
  basis = make_tet_basis_pytree(intorder)
663
703
  element = None if dim == 1 else ElementVector(dim)
664
- return make_space_pytree(mesh, basis, element)
704
+ return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)