fluxfem 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +136 -161
- fluxfem/core/__init__.py +172 -41
- fluxfem/core/assembly.py +676 -91
- fluxfem/core/basis.py +73 -52
- fluxfem/core/context_types.py +36 -0
- fluxfem/core/dtypes.py +9 -1
- fluxfem/core/forms.py +15 -1
- fluxfem/core/mixed_assembly.py +263 -0
- fluxfem/core/mixed_space.py +348 -0
- fluxfem/core/mixed_weakform.py +97 -0
- fluxfem/core/solver.py +2 -0
- fluxfem/core/space.py +262 -17
- fluxfem/core/weakform.py +1503 -312
- fluxfem/helpers_wf.py +53 -0
- fluxfem/mesh/__init__.py +54 -2
- fluxfem/mesh/base.py +322 -8
- fluxfem/mesh/contact.py +825 -0
- fluxfem/mesh/dtypes.py +12 -0
- fluxfem/mesh/hex.py +18 -16
- fluxfem/mesh/io.py +8 -4
- fluxfem/mesh/mortar.py +3907 -0
- fluxfem/mesh/supermesh.py +316 -0
- fluxfem/mesh/surface.py +22 -4
- fluxfem/mesh/tet.py +10 -4
- fluxfem/physics/diffusion.py +3 -0
- fluxfem/physics/elasticity/hyperelastic.py +3 -0
- fluxfem/physics/elasticity/linear.py +9 -2
- fluxfem/solver/__init__.py +42 -2
- fluxfem/solver/bc.py +38 -2
- fluxfem/solver/block_matrix.py +132 -0
- fluxfem/solver/block_system.py +454 -0
- fluxfem/solver/cg.py +115 -33
- fluxfem/solver/dirichlet.py +334 -4
- fluxfem/solver/newton.py +237 -60
- fluxfem/solver/petsc.py +439 -0
- fluxfem/solver/preconditioner.py +106 -0
- fluxfem/solver/result.py +18 -0
- fluxfem/solver/solve_runner.py +168 -1
- fluxfem/solver/solver.py +12 -1
- fluxfem/solver/sparse.py +124 -9
- fluxfem-0.2.0.dist-info/METADATA +303 -0
- fluxfem-0.2.0.dist-info/RECORD +59 -0
- fluxfem-0.1.3.dist-info/METADATA +0 -125
- fluxfem-0.1.3.dist-info/RECORD +0 -47
- {fluxfem-0.1.3.dist-info → fluxfem-0.2.0.dist-info}/LICENSE +0 -0
- {fluxfem-0.1.3.dist-info → fluxfem-0.2.0.dist-info}/WHEEL +0 -0
fluxfem/core/space.py
CHANGED
|
@@ -1,11 +1,47 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
import operator
|
|
3
3
|
from dataclasses import dataclass, field
|
|
4
|
-
from typing import Protocol
|
|
4
|
+
from typing import Any, Callable, Protocol, TYPE_CHECKING, TypeVar
|
|
5
5
|
import jax
|
|
6
6
|
import jax.numpy as jnp
|
|
7
7
|
import numpy as np
|
|
8
|
+
import warnings
|
|
9
|
+
|
|
10
|
+
# Generic parameter type threaded through the form/kernel signatures below.
P = TypeVar("P")

if TYPE_CHECKING:
    # Precise kernel types are only needed by static type checkers; this keeps
    # .assembly off the runtime import path (the methods below also import
    # .assembly lazily, inside their bodies).
    from .assembly import (
        ElementBilinearKernel,
        ElementJacobianKernel,
        ElementLinearKernel,
        ElementResidualKernel,
        Kernel,
        ResidualForm,
    )
else:
    # At runtime every kernel alias is just a generic callable.
    Kernel = ResidualForm = Callable[..., Any]
    ElementBilinearKernel = ElementLinearKernel = Callable[..., Any]
    ElementResidualKernel = ElementJacobianKernel = Callable[..., Any]

# Ids of raw kernels already reported by _warn_untagged_kernel.
_WARNED_UNTAGGED_KERNELS: set[int] = set()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _warn_untagged_kernel(form) -> None:
    """Warn once that a raw kernel carries no ``_ff_kind`` metadata.

    Deduplication uses ``id(form)`` recorded in the module-level
    ``_WARNED_UNTAGGED_KERNELS`` set.

    NOTE: ids are only unique while the object is alive. A new callable that
    reuses the id of a collected one will not be warned about.
    """
    form_id = id(form)
    if form_id not in _WARNED_UNTAGGED_KERNELS:
        _WARNED_UNTAGGED_KERNELS.add(form_id)
        message = (
            "Raw kernel has no _ff_kind metadata; prefer tagging with ff.kernel(...) or "
            "set _ff_kind/_ff_domain on the callable."
        )
        # stacklevel=3 skips this helper and its caller, so the warning points
        # at whoever called the assembling method.
        warnings.warn(message, category=UserWarning, stacklevel=3)
|
|
8
43
|
|
|
44
|
+
from .dtypes import INDEX_DTYPE
|
|
9
45
|
from ..mesh import (
|
|
10
46
|
BaseMesh,
|
|
11
47
|
BaseMeshPytree,
|
|
@@ -79,12 +115,14 @@ class FESpaceClosure:
|
|
|
79
115
|
"""
|
|
80
116
|
mesh: BaseMesh
|
|
81
117
|
basis: Basis3D
|
|
82
|
-
elem_dofs: jnp.ndarray # (n_elems, n_ldofs)
|
|
118
|
+
elem_dofs: jnp.ndarray # (n_elems, n_ldofs) int64
|
|
83
119
|
value_dim: int = 1 # 1=scalar, 3=vector, etc.
|
|
84
120
|
_n_dofs: int | None = None
|
|
85
121
|
_n_ldofs: int | None = None
|
|
86
122
|
data: SpaceData | None = None
|
|
87
123
|
_pattern_cache: dict[bool, object] = field(default_factory=dict, repr=False)
|
|
124
|
+
_kernel_cache: dict[tuple, object] = field(default_factory=dict, repr=False)
|
|
125
|
+
_elem_rows_cache: jnp.ndarray | None = field(default=None, repr=False)
|
|
88
126
|
|
|
89
127
|
def __post_init__(self):
|
|
90
128
|
# Ensure value_dim is a Python int (avoid tracers).
|
|
@@ -116,6 +154,15 @@ class FESpaceClosure:
|
|
|
116
154
|
assert self._n_ldofs is not None
|
|
117
155
|
return self._n_ldofs
|
|
118
156
|
|
|
157
|
+
def get_elem_rows(self) -> jnp.ndarray:
    """Return ``elem_dofs`` flattened to 1-D.

    The flattened array is memoized in ``_elem_rows_cache`` only when called
    outside any JAX trace, so traced values are never stored on the space.
    """
    if self._elem_rows_cache is None:
        flat = self.elem_dofs.reshape(-1)
        if not jax.core.trace_ctx.is_top_level():
            # Under a trace: return the traced value without caching it.
            return flat
        self._elem_rows_cache = flat
    return self._elem_rows_cache
|
|
165
|
+
|
|
119
166
|
def build_form_contexts(self, dep: jnp.ndarray | None = None) -> FormContext:
|
|
120
167
|
def _tie_in(x, y):
|
|
121
168
|
if x is None:
|
|
@@ -148,43 +195,240 @@ class FESpaceClosure:
|
|
|
148
195
|
)
|
|
149
196
|
|
|
150
197
|
test = jax.vmap(make_field)(elem_coords)
|
|
151
|
-
trial
|
|
198
|
+
# Test/trial share the same field data for single-space bilinear forms.
|
|
199
|
+
trial = test
|
|
152
200
|
|
|
153
201
|
return FormContext(
|
|
154
202
|
test=test, trial=trial, x_q=x_q,
|
|
155
203
|
w=w, elem_id=jnp.arange(elem_coords.shape[0])
|
|
156
204
|
)
|
|
157
205
|
|
|
206
|
+
def make_batched_assembler(self, *, dep: jnp.ndarray | None = None, pattern=None):
    """Create a ``BatchedAssembler`` bound to this space.

    If *pattern* is omitted, the space's sparsity pattern (with index data)
    from ``get_sparsity_pattern(with_idx=True)`` is used.
    """
    from .assembly import BatchedAssembler

    sparsity = self.get_sparsity_pattern(with_idx=True) if pattern is None else pattern
    return BatchedAssembler.from_space(self, dep=dep, pattern=sparsity)
|
|
211
|
+
|
|
212
|
+
def build_cg_operator(
    self,
    A,
    *,
    matvec: str = "flux",
    preconditioner=None,
    solver: str = "cg",
    dof_per_node: int | None = None,
    block_sizes=None,
):
    """Build a CG-family operator for *A* using ``..solver.cg.build_cg_operator``.

    ``dof_per_node`` defaults to ``int(self.value_dim)``. Every other argument
    is forwarded unchanged.
    """
    from ..solver.cg import build_cg_operator as _build_operator

    options = dict(
        matvec=matvec,
        preconditioner=preconditioner,
        solver=solver,
        dof_per_node=int(self.value_dim) if dof_per_node is None else dof_per_node,
        block_sizes=block_sizes,
    )
    return _build_operator(A, **options)
|
|
234
|
+
|
|
235
|
+
def _kernel_cache_key(self, kind: str, form, params, *, jit: bool):
    """Build the ``_kernel_cache`` key for a jitted element kernel.

    Args:
        kind: Assembly kind label (e.g. ``"bilinear"`` or ``"linear"``).
        form: Form callable. A ``__wrapped__`` attribute (as set by
            ``functools.wraps``) is unwrapped, so decorated wrappers share the
            entry of the underlying function.
        params: Parameters handed to the kernel maker alongside *form*.
        jit: Only jitted kernels are cached.

    Returns:
        A hashable tuple, or ``None`` when the kernel must not be cached
        (``jit`` is false, or *params* / the unwrapped form is unhashable).
    """
    if not jit:
        return None
    form_key = getattr(form, "__wrapped__", form)
    try:
        # Probe hashability up front; inserting the key into a dict would
        # raise otherwise. Best-effort: any failure just disables caching.
        hash(params)
        hash(form_key)
    except Exception:
        return None
    # Store the objects themselves, not just hash(params) / id(form_key):
    # - Unequal params may share a hash (e.g. hash(-1) == hash(-2) in CPython).
    #   A hash-only key would return a kernel built for different params.
    # - Holding a reference keeps form_key alive, so its id cannot be recycled
    #   by another callable while the entry exists.
    # id(form_key) precedes form_key, so unequal ids short-circuit the tuple
    # comparison and a custom __eq__ on the form object is never consulted.
    # Side effect: the cache keeps form_key and params alive.
    return (kind, id(form_key), form_key, params, True)
|
|
244
|
+
|
|
245
|
+
def _get_cached_kernel(self, kind: str, form, params, *, jit: bool, maker):
    """Return ``maker(form, params, jit=jit)``, memoized in ``_kernel_cache``.

    When ``_kernel_cache_key`` yields ``None`` the cache is bypassed entirely.
    Newly built kernels are stored only when called outside a JAX trace.
    """
    key = self._kernel_cache_key(kind, form, params, jit=jit)
    if key is not None:
        hit = self._kernel_cache.get(key)
        if hit is not None:
            return hit
    built = maker(form, params, jit=jit)
    if key is not None and jax.core.trace_ctx.is_top_level():
        self._kernel_cache[key] = built
    return built
|
|
256
|
+
|
|
257
|
+
def assemble(
    self,
    form,
    params=None,
    *,
    kind: str | None = None,
    n_chunks: int | None = None,
    dep: jnp.ndarray | None = None,
    jit: bool = True,
    pattern: str | object | None = "auto",
    kernel=None,
    **kwargs,
):
    """
    High-level assembly entry point with optional kernel caching.

    form: a BilinearForm/LinearForm or a raw kernel callable.
        BilinearForm/LinearForm objects are always compiled via get_compiled()
        before use, whether or not ``kind`` is given. Raw kernels may carry
        _ff_kind/_ff_domain metadata.
    kind: "bilinear" or "linear". If None, inferred from LinearForm/BilinearForm
        or compiled/kernels tagged with _ff_kind metadata. If both ``kind`` and
        the form's own kind are known, they must agree.
    jit: forwarded to the element-kernel maker; only jitted kernels are cached.
    pattern: "auto" to reuse cached sparsity pattern for bilinear assembly;
        any other value is forwarded as-is. Not used for linear assembly.
    kernel: a prebuilt element kernel; skips kernel construction and caching.

    Raises:
        ValueError: ``kind`` cannot be inferred, conflicts with the form's
            kind, or is unsupported; or the form's domain is not "volume".
    """
    from .weakform import BilinearForm, LinearForm
    from .assembly import make_element_bilinear_kernel, make_element_linear_kernel

    # Weak-form objects know their own kind. They are compiled regardless of
    # ``kind`` so an explicit kind never routes an uncompiled form object into
    # the kernel makers (or triggers the raw-kernel warning).
    form_obj_kind = None
    if isinstance(form, BilinearForm):
        form_obj_kind = "bilinear"
    elif isinstance(form, LinearForm):
        form_obj_kind = "linear"

    if form_obj_kind is not None:
        if kind is not None and kind != form_obj_kind:
            raise ValueError(
                f"assemble kind '{kind}' does not match form kind '{form_obj_kind}' "
                f"for {form!r}. "
                "Align kind= with the form type."
            )
        kind = form_obj_kind
        form = form.get_compiled()
    else:
        # Raw kernel: reconcile ``kind`` with optional _ff_* metadata.
        inferred_kind = getattr(form, "_ff_kind", None)
        inferred_domain = getattr(form, "_ff_domain", None)
        if inferred_kind is None:
            if kind is None:
                raise ValueError(
                    f"kind is required for raw kernels without metadata (got {form!r}). "
                    "Use @ff.kernel(kind=..., domain=...) or pass kind="
                )
            _warn_untagged_kernel(form)
        elif kind is None:
            kind = inferred_kind
        elif inferred_kind != kind:
            raise ValueError(
                f"assemble kind '{kind}' does not match form kind '{inferred_kind}' "
                f"for {form!r}. "
                "Align kind= with the kernel metadata (or retag the kernel)."
            )
        if inferred_domain not in (None, "volume"):
            raise ValueError(
                f"Unsupported form domain '{inferred_domain}' for Space.assemble. "
                "Use assemble_surface_linear_form or assemble_surface_bilinear_form."
            )

    if kind == "bilinear":
        if kernel is None:
            kernel = self._get_cached_kernel(
                "bilinear",
                form,
                params,
                jit=jit,
                maker=make_element_bilinear_kernel,
            )
        # isinstance guard: a caller-supplied pattern object (e.g. an array)
        # may overload ==, making a bare ``pattern == "auto"`` ambiguous.
        use_cached_pattern = isinstance(pattern, str) and pattern == "auto"
        pattern_use = (
            self.get_sparsity_pattern(with_idx=True) if use_cached_pattern else pattern
        )
        return self.assemble_bilinear_form(
            form,
            params,
            n_chunks=n_chunks,
            dep=dep,
            kernel=kernel,
            pattern=pattern_use,
            **kwargs,
        )
    if kind == "linear":
        if kernel is None:
            kernel = self._get_cached_kernel(
                "linear",
                form,
                params,
                jit=jit,
                maker=make_element_linear_kernel,
            )
        return self.assemble_linear_form(
            form,
            params,
            n_chunks=n_chunks,
            dep=dep,
            kernel=kernel,
            **kwargs,
        )
    raise ValueError(f"Unsupported assemble kind: {kind}")
|
|
359
|
+
|
|
158
360
|
# --- Thin wrappers over functional assembly APIs (kept functional for JAX friendliness) ---
|
|
159
|
-
def assemble_bilinear_form(
    self,
    form: Kernel[P],
    params: P,
    *,
    n_chunks: int | None = None,
    dep: jnp.ndarray | None = None,
    kernel: ElementBilinearKernel | None = None,
    **kwargs,
):
    """Assemble a bilinear form over this space.

    ``kernel``, if provided, is called as ``kernel(ctx)`` and returns an
    ``(n_ldofs, n_ldofs)`` array. A missing or ``None`` ``pattern`` keyword is
    filled with ``get_sparsity_pattern(with_idx=True)``.
    """
    from .assembly import assemble_bilinear_form as _assemble

    if kwargs.get("pattern") is None:
        kwargs["pattern"] = self.get_sparsity_pattern(with_idx=True)
    return _assemble(self, form, params, n_chunks=n_chunks, dep=dep, kernel=kernel, **kwargs)
|
|
164
378
|
|
|
165
|
-
def assemble_linear_form(
    self,
    form: Kernel[P],
    params: P,
    *,
    n_chunks: int | None = None,
    dep: jnp.ndarray | None = None,
    kernel: ElementLinearKernel | None = None,
    **kwargs,
):
    """Assemble a linear form over this space.

    ``kernel``, if provided, is called as ``kernel(ctx)`` and returns an
    ``(n_ldofs,)`` array. Remaining keywords are forwarded unchanged.
    """
    from .assembly import assemble_linear_form as _assemble

    forwarded = dict(n_chunks=n_chunks, dep=dep, kernel=kernel, **kwargs)
    return _assemble(self, form, params, **forwarded)
|
|
168
394
|
|
|
169
395
|
def assemble_functional(self, form, params):
    """Thin wrapper: forward to ``.assembly.assemble_functional`` with this space."""
    from .assembly import assemble_functional as _assemble_functional

    return _assemble_functional(self, form, params)
|
|
172
398
|
|
|
173
|
-
def assemble_mass_matrix(self, *, n_chunks=None, **kwargs):
    """Thin wrapper: forward to ``.assembly.assemble_mass_matrix`` with this space.

    ``n_chunks`` and any extra keywords are passed through unchanged.
    """
    from .assembly import assemble_mass_matrix as _assemble_mass

    kwargs["n_chunks"] = n_chunks
    return _assemble_mass(self, **kwargs)
|
|
176
402
|
|
|
177
403
|
def assemble_bilinear_dense(self, kernel, params, **kwargs):
    """Thin wrapper: forward to ``.assembly.assemble_bilinear_dense`` with this space."""
    from .assembly import assemble_bilinear_dense as _assemble_dense

    return _assemble_dense(self, kernel, params, **kwargs)
|
|
180
406
|
|
|
181
|
-
def assemble_residual(
    self,
    res_form: ResidualForm[P],
    u: jnp.ndarray,
    params: P,
    *,
    kernel: ElementResidualKernel | None = None,
    **kwargs,
):
    """Assemble the residual of *res_form* at state *u*.

    ``kernel``, if provided, is called as ``kernel(ctx, u_elem)`` and returns
    an ``(n_ldofs,)`` array. Other keywords go to ``.assembly.assemble_residual``.
    """
    from .assembly import assemble_residual as _assemble_residual

    return _assemble_residual(self, res_form, u, params, kernel=kernel, **kwargs)
|
|
419
|
+
|
|
420
|
+
def assemble_jacobian(
    self,
    res_form: ResidualForm[P],
    u: jnp.ndarray,
    params: P,
    *,
    kernel: ElementJacobianKernel | None = None,
    **kwargs,
):
    """Assemble the Jacobian of *res_form* at state *u*.

    ``kernel``, if provided, is called as ``kernel(u_elem, ctx)`` and returns
    an ``(n_ldofs, n_ldofs)`` array. Other keywords go to
    ``.assembly.assemble_jacobian``.

    NOTE(review): this argument order is the reverse of the residual kernel's
    documented ``kernel(ctx, u_elem)``. Confirm which order .assembly uses.
    """
    from .assembly import assemble_jacobian as _assemble_jacobian

    return _assemble_jacobian(self, res_form, u, params, kernel=kernel, **kwargs)
|
|
188
432
|
|
|
189
433
|
def get_sparsity_pattern(self, *, with_idx: bool = True):
|
|
190
434
|
cached = self._pattern_cache.get(with_idx)
|
|
@@ -192,7 +436,8 @@ class FESpaceClosure:
|
|
|
192
436
|
return cached
|
|
193
437
|
from .assembly import make_sparsity_pattern
|
|
194
438
|
pat = make_sparsity_pattern(self, with_idx=with_idx)
|
|
195
|
-
|
|
439
|
+
if jax.core.trace_ctx.is_top_level():
|
|
440
|
+
self._pattern_cache[with_idx] = pat
|
|
196
441
|
return pat
|
|
197
442
|
|
|
198
443
|
|
|
@@ -251,7 +496,7 @@ def make_space(
|
|
|
251
496
|
return FESpace(
|
|
252
497
|
mesh=mesh,
|
|
253
498
|
basis=basis,
|
|
254
|
-
elem_dofs=jnp.asarray(elem_dofs, dtype=
|
|
499
|
+
elem_dofs=jnp.asarray(elem_dofs, dtype=INDEX_DTYPE),
|
|
255
500
|
value_dim=value_dim
|
|
256
501
|
)
|
|
257
502
|
|
|
@@ -320,7 +565,7 @@ def make_space_pytree(
|
|
|
320
565
|
return FESpacePytree(
|
|
321
566
|
mesh=mesh_py,
|
|
322
567
|
basis=basis_py,
|
|
323
|
-
elem_dofs=jnp.asarray(elem_dofs, dtype=
|
|
568
|
+
elem_dofs=jnp.asarray(elem_dofs, dtype=INDEX_DTYPE),
|
|
324
569
|
value_dim=value_dim,
|
|
325
570
|
)
|
|
326
571
|
|