fluxfem 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluxfem/__init__.py +69 -13
- fluxfem/core/__init__.py +140 -53
- fluxfem/core/assembly.py +691 -97
- fluxfem/core/basis.py +75 -54
- fluxfem/core/context_types.py +36 -12
- fluxfem/core/dtypes.py +9 -1
- fluxfem/core/forms.py +10 -0
- fluxfem/core/mixed_assembly.py +263 -0
- fluxfem/core/mixed_space.py +382 -0
- fluxfem/core/mixed_weakform.py +97 -0
- fluxfem/core/solver.py +2 -0
- fluxfem/core/space.py +315 -30
- fluxfem/core/weakform.py +821 -42
- fluxfem/helpers_wf.py +49 -0
- fluxfem/mesh/__init__.py +54 -2
- fluxfem/mesh/base.py +318 -9
- fluxfem/mesh/contact.py +841 -0
- fluxfem/mesh/dtypes.py +12 -0
- fluxfem/mesh/hex.py +17 -16
- fluxfem/mesh/io.py +9 -6
- fluxfem/mesh/mortar.py +3970 -0
- fluxfem/mesh/supermesh.py +318 -0
- fluxfem/mesh/surface.py +104 -26
- fluxfem/mesh/tet.py +16 -7
- fluxfem/physics/diffusion.py +3 -0
- fluxfem/physics/elasticity/hyperelastic.py +35 -3
- fluxfem/physics/elasticity/linear.py +22 -4
- fluxfem/physics/elasticity/stress.py +9 -5
- fluxfem/physics/operators.py +12 -5
- fluxfem/physics/postprocess.py +29 -3
- fluxfem/solver/__init__.py +47 -2
- fluxfem/solver/bc.py +38 -2
- fluxfem/solver/block_matrix.py +284 -0
- fluxfem/solver/block_system.py +477 -0
- fluxfem/solver/cg.py +150 -55
- fluxfem/solver/dirichlet.py +358 -5
- fluxfem/solver/history.py +15 -3
- fluxfem/solver/newton.py +260 -70
- fluxfem/solver/petsc.py +445 -0
- fluxfem/solver/preconditioner.py +109 -0
- fluxfem/solver/result.py +18 -0
- fluxfem/solver/solve_runner.py +208 -23
- fluxfem/solver/solver.py +35 -12
- fluxfem/solver/sparse.py +149 -15
- fluxfem/tools/jit.py +19 -7
- fluxfem/tools/timer.py +14 -12
- fluxfem/tools/visualizer.py +16 -4
- fluxfem-0.2.1.dist-info/METADATA +314 -0
- fluxfem-0.2.1.dist-info/RECORD +59 -0
- fluxfem-0.1.4.dist-info/METADATA +0 -127
- fluxfem-0.1.4.dist-info/RECORD +0 -48
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/LICENSE +0 -0
- {fluxfem-0.1.4.dist-info → fluxfem-0.2.1.dist-info}/WHEEL +0 -0
fluxfem/core/space.py
CHANGED
|
@@ -1,11 +1,71 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
import operator
|
|
3
3
|
from dataclasses import dataclass, field
|
|
4
|
-
from typing import Protocol
|
|
4
|
+
from typing import Any, Callable, Protocol, TYPE_CHECKING, TypeAlias, TypeVar, cast
|
|
5
5
|
import jax
|
|
6
6
|
import jax.numpy as jnp
|
|
7
7
|
import numpy as np
|
|
8
|
+
import warnings
|
|
9
|
+
|
|
10
|
+
P = TypeVar("P")
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from .assembly import (
|
|
14
|
+
ElementBilinearKernel,
|
|
15
|
+
ElementJacobianKernel,
|
|
16
|
+
ElementLinearKernel,
|
|
17
|
+
ElementResidualKernel,
|
|
18
|
+
Kernel,
|
|
19
|
+
FormKernel,
|
|
20
|
+
BilinearReturn,
|
|
21
|
+
JacobianReturn,
|
|
22
|
+
LinearReturn,
|
|
23
|
+
MassReturn,
|
|
24
|
+
ResidualForm,
|
|
25
|
+
)
|
|
26
|
+
from .weakform import BilinearForm, LinearForm
|
|
27
|
+
from ..solver import FluxSparseMatrix, SparsityPattern
|
|
28
|
+
else:
|
|
29
|
+
Kernel = Callable[..., Any]
|
|
30
|
+
FormKernel = Callable[..., Any]
|
|
31
|
+
ResidualForm = Callable[..., Any]
|
|
32
|
+
ElementBilinearKernel = Callable[..., Any]
|
|
33
|
+
ElementLinearKernel = Callable[..., Any]
|
|
34
|
+
ElementResidualKernel = Callable[..., Any]
|
|
35
|
+
ElementJacobianKernel = Callable[..., Any]
|
|
36
|
+
BilinearReturn = Any
|
|
37
|
+
JacobianReturn = Any
|
|
38
|
+
LinearReturn = Any
|
|
39
|
+
MassReturn = Any
|
|
40
|
+
BilinearForm = Any
|
|
41
|
+
LinearForm = Any
|
|
42
|
+
FluxSparseMatrix = Any
|
|
43
|
+
SparsityPattern = Any
|
|
44
|
+
|
|
45
|
+
# Key type for a space's ``_kernel_cache``: (kind, id of the form or of its
# ``__wrapped__`` target, params slot, jit flag). The params slot is typed
# ``Any`` rather than ``int``. It may hold ``hash(params)``, but it should hold
# the (hashable) params object itself. Keying on a bare hash lets distinct
# params with colliding hashes (e.g. ``hash(-1.0) == hash(-2.0)`` in CPython)
# share one cached kernel.
KernelCacheKey: TypeAlias = tuple[str, int, Any, bool]
# Cached element kernels of any flavour, keyed by KernelCacheKey.
KernelCache: TypeAlias = dict[
    KernelCacheKey,
    ElementBilinearKernel | ElementLinearKernel | ElementResidualKernel | ElementJacobianKernel,
]
# Sparsity patterns cached per ``with_idx`` flag.
PatternCache: TypeAlias = dict[bool, SparsityPattern]
# Accepted ``pattern=`` values: a string mode (e.g. "auto"), an explicit pattern, or None.
PatternLike: TypeAlias = str | SparsityPattern | None
|
|
52
|
+
|
|
53
|
+
# id() values of raw kernels that have already triggered the untagged warning.
_WARNED_UNTAGGED_KERNELS: set[int] = set()


def _warn_untagged_kernel(form) -> None:
    """Warn (at most once per object identity) that ``form`` lacks ``_ff_kind`` metadata.

    Deduplication is keyed on ``id(form)``. ``stacklevel=3`` attributes the
    warning to the code that called the function which invoked this helper.
    """
    form_id = id(form)
    if form_id not in _WARNED_UNTAGGED_KERNELS:
        _WARNED_UNTAGGED_KERNELS.add(form_id)
        warnings.warn(
            "Raw kernel has no _ff_kind metadata; prefer tagging with ff.kernel(...) or "
            "set _ff_kind/_ff_domain on the callable.",
            category=UserWarning,
            stacklevel=3,
        )
|
|
8
67
|
|
|
68
|
+
from .dtypes import INDEX_DTYPE
|
|
9
69
|
from ..mesh import (
|
|
10
70
|
BaseMesh,
|
|
11
71
|
BaseMeshPytree,
|
|
@@ -57,8 +117,12 @@ class FESpaceBase(Protocol):
|
|
|
57
117
|
"""
|
|
58
118
|
elem_dofs: jnp.ndarray
|
|
59
119
|
value_dim: int
|
|
60
|
-
|
|
61
|
-
|
|
120
|
+
|
|
121
|
+
@property
def n_dofs(self) -> int:
    """Number of degrees of freedom in the space."""

@property
def n_ldofs(self) -> int:
    """Number of local degrees of freedom per element (columns of ``elem_dofs``)."""

def build_form_contexts(self, dep: jnp.ndarray | None = None) -> FormContext:
    """Build the ``FormContext`` consumed by form kernels; ``dep`` is optional."""
|
|
64
128
|
|
|
@@ -79,12 +143,14 @@ class FESpaceClosure:
|
|
|
79
143
|
"""
|
|
80
144
|
mesh: BaseMesh
|
|
81
145
|
basis: Basis3D
|
|
82
|
-
elem_dofs: jnp.ndarray # (n_elems, n_ldofs)
|
|
146
|
+
elem_dofs: jnp.ndarray # (n_elems, n_ldofs) int64
|
|
83
147
|
value_dim: int = 1 # 1=scalar, 3=vector, etc.
|
|
84
148
|
_n_dofs: int | None = None
|
|
85
149
|
_n_ldofs: int | None = None
|
|
86
150
|
data: SpaceData | None = None
|
|
87
|
-
_pattern_cache:
|
|
151
|
+
_pattern_cache: PatternCache = field(default_factory=dict, repr=False)
|
|
152
|
+
_kernel_cache: KernelCache = field(default_factory=dict, repr=False)
|
|
153
|
+
_elem_rows_cache: jnp.ndarray | None = field(default=None, repr=False)
|
|
88
154
|
|
|
89
155
|
def __post_init__(self):
|
|
90
156
|
# Ensure value_dim is a Python int (avoid tracers).
|
|
@@ -116,6 +182,15 @@ class FESpaceClosure:
|
|
|
116
182
|
assert self._n_ldofs is not None
|
|
117
183
|
return self._n_ldofs
|
|
118
184
|
|
|
185
|
+
def get_elem_rows(self) -> jnp.ndarray:
    """Return ``elem_dofs`` flattened to 1-D, memoized in ``_elem_rows_cache``.

    The flattened array is stored only when called outside any JAX trace
    (``jax.core.trace_ctx.is_top_level()``), so values produced under tracing
    are never written into the cache.
    """
    if self._elem_rows_cache is None:
        flat = self.elem_dofs.reshape(-1)
        if jax.core.trace_ctx.is_top_level():
            self._elem_rows_cache = flat
        return flat
    return self._elem_rows_cache
|
|
193
|
+
|
|
119
194
|
def build_form_contexts(self, dep: jnp.ndarray | None = None) -> FormContext:
|
|
120
195
|
def _tie_in(x, y):
|
|
121
196
|
if x is None:
|
|
@@ -148,43 +223,252 @@ class FESpaceClosure:
|
|
|
148
223
|
)
|
|
149
224
|
|
|
150
225
|
test = jax.vmap(make_field)(elem_coords)
|
|
151
|
-
trial
|
|
226
|
+
# Test/trial share the same field data for single-space bilinear forms.
|
|
227
|
+
trial = test
|
|
152
228
|
|
|
153
229
|
return FormContext(
|
|
154
230
|
test=test, trial=trial, x_q=x_q,
|
|
155
231
|
w=w, elem_id=jnp.arange(elem_coords.shape[0])
|
|
156
232
|
)
|
|
157
233
|
|
|
234
|
+
def make_batched_assembler(self, *, dep: jnp.ndarray | None = None, pattern=None):
    """Create a ``BatchedAssembler`` bound to this space.

    When ``pattern`` is None, the space's sparsity pattern with index data
    (``get_sparsity_pattern(with_idx=True)``) is used.
    """
    from .assembly import BatchedAssembler

    return BatchedAssembler.from_space(
        self,
        dep=dep,
        pattern=self.get_sparsity_pattern(with_idx=True) if pattern is None else pattern,
    )
|
|
239
|
+
|
|
240
|
+
def build_cg_operator(
    self,
    A,
    *,
    matvec: str = "flux",
    preconditioner=None,
    solver: str = "cg",
    dof_per_node: int | None = None,
    block_sizes=None,
):
    """Build a CG-style solver operator for ``A`` via ``solver.cg.build_cg_operator``.

    ``dof_per_node`` defaults to this space's ``value_dim`` (as a Python int);
    all other options are forwarded unchanged.
    """
    from ..solver.cg import build_cg_operator as _build_operator

    per_node = int(self.value_dim) if dof_per_node is None else dof_per_node
    return _build_operator(
        A,
        matvec=matvec,
        preconditioner=preconditioner,
        solver=solver,
        dof_per_node=per_node,
        block_sizes=block_sizes,
    )
|
|
262
|
+
|
|
263
|
+
def _kernel_cache_key(self, kind: str, form, params, *, jit: bool) -> KernelCacheKey | None:
|
|
264
|
+
if not jit:
|
|
265
|
+
return None
|
|
266
|
+
form_key = getattr(form, "__wrapped__", form)
|
|
267
|
+
try:
|
|
268
|
+
params_key = hash(params)
|
|
269
|
+
except Exception:
|
|
270
|
+
return None
|
|
271
|
+
return (kind, id(form_key), params_key, True)
|
|
272
|
+
|
|
273
|
+
def _get_cached_kernel(self, kind: str, form, params, *, jit: bool, maker):
    """Return an element kernel for ``form``, reusing ``self._kernel_cache`` when possible.

    On a miss the kernel is built with ``maker(form, params, jit=jit)``. Nothing
    is cached when ``_kernel_cache_key`` yields None. A newly built kernel is
    stored only when called outside a JAX trace.
    """
    cache_key = self._kernel_cache_key(kind, form, params, jit=jit)
    if cache_key is not None:
        hit = self._kernel_cache.get(cache_key)
        if hit is not None:
            return hit
    built = maker(form, params, jit=jit)
    if cache_key is not None and jax.core.trace_ctx.is_top_level():
        self._kernel_cache[cache_key] = built
    return built
|
|
284
|
+
|
|
285
|
+
def assemble(
    self,
    form: FormKernel[P] | BilinearForm | LinearForm,
    params: P | None = None,
    *,
    kind: str | None = None,
    n_chunks: int | None = None,
    dep: jnp.ndarray | None = None,
    jit: bool = True,
    pattern: PatternLike = "auto",
    kernel: ElementBilinearKernel | ElementLinearKernel | None = None,
    **kwargs,
):
    """
    High-level assembly entry point with optional kernel caching.

    form: a ``BilinearForm``/``LinearForm`` weak form (always compiled with
        ``get_compiled()`` before use) or a raw form kernel callable.
    params: forwarded to the form and to the element-kernel maker.
    kind: "bilinear" or "linear". If None, inferred from LinearForm/BilinearForm
        or compiled/kernels tagged with _ff_kind metadata. When given for a
        weak-form object it must agree with that object's type.
    n_chunks, dep, **kwargs: forwarded to assemble_bilinear_form /
        assemble_linear_form.
    jit: passed to the kernel maker; the kernel cache is only consulted when
        jit is True.
    pattern: "auto" to reuse cached sparsity pattern for bilinear assembly;
        any other value is passed through. Not used for linear assembly.
    kernel: prebuilt element kernel; when given, no kernel is built or cached.

    Raises ValueError when the kind cannot be inferred, conflicts with the
    form, is unsupported, or the form declares a non-volume domain.
    """
    from .weakform import BilinearForm, LinearForm
    from .assembly import make_element_bilinear_kernel, make_element_linear_kernel

    if isinstance(form, (BilinearForm, LinearForm)):
        # Weak-form objects carry their kind in their type. Compile them on
        # every path. Previously an explicit kind= skipped get_compiled(),
        # handed the raw weak-form object to the kernel maker, and emitted a
        # spurious "untagged kernel" warning.
        form_kind = "bilinear" if isinstance(form, BilinearForm) else "linear"
        if kind is None:
            kind = form_kind
        elif kind != form_kind:
            raise ValueError(
                f"assemble kind '{kind}' does not match form kind '{form_kind}' "
                f"for {form!r}. "
                f"Pass kind='{form_kind}' or omit kind= for weak-form objects."
            )
        form = form.get_compiled()
    else:
        # Raw callables: consult the optional _ff_kind/_ff_domain tags.
        inferred_kind = getattr(form, "_ff_kind", None)
        inferred_domain = getattr(form, "_ff_domain", None)
        if kind is None:
            if inferred_kind is None:
                raise ValueError(
                    f"kind is required for raw kernels without metadata (got {form!r}). "
                    "Use @ff.kernel(kind=..., domain=...) or pass kind=."
                )
            kind = inferred_kind
        elif inferred_kind is None:
            _warn_untagged_kernel(form)
        elif inferred_kind != kind:
            raise ValueError(
                f"assemble kind '{kind}' does not match form kind '{inferred_kind}' "
                f"for {form!r}. "
                "Align kind= with the kernel metadata (or retag the kernel)."
            )
        if inferred_domain not in (None, "volume"):
            raise ValueError(
                f"Unsupported form domain '{inferred_domain}' for Space.assemble. "
                "Use assemble_surface_linear_form or assemble_surface_bilinear_form."
            )

    if kind == "bilinear":
        if kernel is None:
            kernel = self._get_cached_kernel(
                "bilinear",
                form,
                params,
                jit=jit,
                maker=make_element_bilinear_kernel,
            )
        # Compare as a string only: a SparsityPattern object must never be
        # mistaken for the "auto" sentinel.
        if isinstance(pattern, str) and pattern == "auto":
            pattern_use = self.get_sparsity_pattern(with_idx=True)
        else:
            pattern_use = pattern
        return self.assemble_bilinear_form(
            cast(FormKernel, form),
            params,
            n_chunks=n_chunks,
            dep=dep,
            kernel=kernel,
            pattern=pattern_use,
            **kwargs,
        )
    if kind == "linear":
        if kernel is None:
            kernel = self._get_cached_kernel(
                "linear",
                form,
                params,
                jit=jit,
                maker=make_element_linear_kernel,
            )
        return self.assemble_linear_form(
            cast(FormKernel, form),
            params,
            n_chunks=n_chunks,
            dep=dep,
            kernel=kernel,
            **kwargs,
        )
    raise ValueError(f"Unsupported assemble kind: {kind}")
|
|
389
|
+
|
|
158
390
|
# --- Thin wrappers over functional assembly APIs (kept functional for JAX friendliness) ---
|
|
159
|
-
def assemble_bilinear_form(
    self,
    form: FormKernel[P],
    params: P,
    *,
    n_chunks: int | None = None,
    dep: jnp.ndarray | None = None,
    kernel: ElementBilinearKernel | None = None,
    **kwargs,
) -> FluxSparseMatrix:
    """Assemble a bilinear form over this space into a sparse matrix.

    Delegates to ``assembly.assemble_bilinear_form``. When ``kwargs`` carries no
    ``pattern`` (or ``pattern=None``), this space's sparsity pattern with index
    data is supplied. An explicit ``kernel`` has the shape
    kernel(ctx) -> (n_ldofs, n_ldofs).
    """
    from .assembly import assemble_bilinear_form as _assemble_bilinear

    if kwargs.get("pattern") is None:
        kwargs["pattern"] = self.get_sparsity_pattern(with_idx=True)
    return _assemble_bilinear(
        self, form, params, n_chunks=n_chunks, dep=dep, kernel=kernel, **kwargs
    )
|
|
164
408
|
|
|
165
|
-
def assemble_linear_form(
    self,
    form: FormKernel[P],
    params: P,
    *,
    n_chunks: int | None = None,
    dep: jnp.ndarray | None = None,
    kernel: ElementLinearKernel | None = None,
    **kwargs,
) -> LinearReturn:
    """Assemble a linear form over this space.

    Delegates to ``assembly.assemble_linear_form``. An explicit ``kernel`` has
    the shape kernel(ctx) -> (n_ldofs,).
    """
    from .assembly import assemble_linear_form as _assemble_linear

    return _assemble_linear(
        self, form, params, n_chunks=n_chunks, dep=dep, kernel=kernel, **kwargs
    )
|
|
168
424
|
|
|
169
|
-
def assemble_functional(self, form: FormKernel[P], params: P) -> jnp.ndarray:
    """Assemble ``form`` as a functional over this space.

    Delegates to ``assembly.assemble_functional``.
    """
    from .assembly import assemble_functional as _assemble_functional

    return _assemble_functional(self, form, params)
|
|
172
428
|
|
|
173
|
-
def assemble_mass_matrix(
    self,
    *,
    n_chunks: int | None = None,
    **kwargs,
) -> MassReturn:
    """Assemble this space's mass matrix via ``assembly.assemble_mass_matrix``."""
    from .assembly import assemble_mass_matrix as _assemble_mass

    return _assemble_mass(self, n_chunks=n_chunks, **kwargs)
|
|
437
|
+
|
|
438
|
+
def assemble_bilinear_dense(
    self,
    kernel: FormKernel[P],
    params: P,
    **kwargs,
) -> BilinearReturn:
    """Assemble a bilinear form densely via ``assembly.assemble_bilinear_dense``."""
    from .assembly import assemble_bilinear_dense as _assemble_dense

    return _assemble_dense(self, kernel, params, **kwargs)
|
|
180
446
|
|
|
181
|
-
def assemble_residual(
    self,
    res_form: ResidualForm[P],
    u: jnp.ndarray,
    params: P,
    *,
    kernel: ElementResidualKernel | None = None,
    **kwargs,
) -> LinearReturn:
    """Assemble the residual of ``res_form`` at ``u``.

    Delegates to ``assembly.assemble_residual``. An explicit ``kernel`` has the
    shape kernel(ctx, u_elem) -> (n_ldofs,).
    """
    from .assembly import assemble_residual as _assemble_residual

    return _assemble_residual(self, res_form, u, params, kernel=kernel, **kwargs)
|
|
459
|
+
|
|
460
|
+
def assemble_jacobian(
    self,
    res_form: ResidualForm[P],
    u: jnp.ndarray,
    params: P,
    *,
    kernel: ElementJacobianKernel | None = None,
    **kwargs,
) -> JacobianReturn:
    """Assemble the Jacobian of ``res_form`` at ``u``.

    Delegates to ``assembly.assemble_jacobian``. An explicit ``kernel`` has the
    shape kernel(u_elem, ctx) -> (n_ldofs, n_ldofs).
    """
    from .assembly import assemble_jacobian as _assemble_jacobian

    return _assemble_jacobian(self, res_form, u, params, kernel=kernel, **kwargs)
|
|
188
472
|
|
|
189
473
|
def get_sparsity_pattern(self, *, with_idx: bool = True):
|
|
190
474
|
cached = self._pattern_cache.get(with_idx)
|
|
@@ -192,7 +476,8 @@ class FESpaceClosure:
|
|
|
192
476
|
return cached
|
|
193
477
|
from .assembly import make_sparsity_pattern
|
|
194
478
|
pat = make_sparsity_pattern(self, with_idx=with_idx)
|
|
195
|
-
|
|
479
|
+
if jax.core.trace_ctx.is_top_level():
|
|
480
|
+
self._pattern_cache[with_idx] = pat
|
|
196
481
|
return pat
|
|
197
482
|
|
|
198
483
|
|
|
@@ -251,7 +536,7 @@ def make_space(
|
|
|
251
536
|
return FESpace(
|
|
252
537
|
mesh=mesh,
|
|
253
538
|
basis=basis,
|
|
254
|
-
elem_dofs=jnp.asarray(elem_dofs, dtype=
|
|
539
|
+
elem_dofs=jnp.asarray(elem_dofs, dtype=INDEX_DTYPE),
|
|
255
540
|
value_dim=value_dim
|
|
256
541
|
)
|
|
257
542
|
|
|
@@ -320,7 +605,7 @@ def make_space_pytree(
|
|
|
320
605
|
return FESpacePytree(
|
|
321
606
|
mesh=mesh_py,
|
|
322
607
|
basis=basis_py,
|
|
323
|
-
elem_dofs=jnp.asarray(elem_dofs, dtype=
|
|
608
|
+
elem_dofs=jnp.asarray(elem_dofs, dtype=INDEX_DTYPE),
|
|
324
609
|
value_dim=value_dim,
|
|
325
610
|
)
|
|
326
611
|
|
|
@@ -340,7 +625,7 @@ def make_tet10_space_pytree(
|
|
|
340
625
|
"""Create a pytree quadratic tet space (10-node elements)."""
|
|
341
626
|
basis = make_tet10_basis_pytree(intorder)
|
|
342
627
|
element = None if dim == 1 else ElementVector(dim)
|
|
343
|
-
return make_space_pytree(mesh, basis, element)
|
|
628
|
+
return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)
|
|
344
629
|
|
|
345
630
|
|
|
346
631
|
def make_hex_space(mesh: HexMesh, dim: int = 1, intorder: int = 2) -> FESpace:
|
|
@@ -356,7 +641,7 @@ def make_hex_space_pytree(
|
|
|
356
641
|
"""Create a pytree trilinear hex space (8-node elements)."""
|
|
357
642
|
basis = make_hex_basis_pytree(intorder)
|
|
358
643
|
element = None if dim == 1 else ElementVector(dim)
|
|
359
|
-
return make_space_pytree(mesh, basis, element)
|
|
644
|
+
return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)
|
|
360
645
|
|
|
361
646
|
|
|
362
647
|
def make_hex20_space(
|
|
@@ -374,7 +659,7 @@ def make_hex20_space_pytree(
|
|
|
374
659
|
"""Create a pytree serendipity hex space (20-node elements)."""
|
|
375
660
|
basis = make_hex20_basis_pytree(intorder)
|
|
376
661
|
element = None if dim == 1 else ElementVector(dim)
|
|
377
|
-
return make_space_pytree(mesh, basis, element)
|
|
662
|
+
return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)
|
|
378
663
|
|
|
379
664
|
|
|
380
665
|
def make_hex27_space(
|
|
@@ -392,14 +677,14 @@ def make_hex27_space_pytree(
|
|
|
392
677
|
"""Create a pytree triquadratic hex space (27-node elements)."""
|
|
393
678
|
basis = make_hex27_basis_pytree(intorder)
|
|
394
679
|
element = None if dim == 1 else ElementVector(dim)
|
|
395
|
-
return make_space_pytree(mesh, basis, element)
|
|
680
|
+
return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)
|
|
396
681
|
|
|
397
682
|
|
|
398
683
|
def make_tet_space(mesh: TetMesh, dim: int = 1, intorder: int = 2) -> FESpace:
|
|
399
684
|
"""Create a linear or quadratic tet space based on mesh nodes."""
|
|
400
685
|
n_nodes = mesh.conn.shape[1]
|
|
401
686
|
if n_nodes == 10:
|
|
402
|
-
basis = make_tet10_basis(intorder if intorder > 1 else 2)
|
|
687
|
+
basis: Basis3D = make_tet10_basis(intorder if intorder > 1 else 2)
|
|
403
688
|
else:
|
|
404
689
|
basis = make_tet_basis(intorder)
|
|
405
690
|
element = None if dim == 1 else ElementVector(dim)
|
|
@@ -412,8 +697,8 @@ def make_tet_space_pytree(
|
|
|
412
697
|
"""Create a pytree linear or quadratic tet space based on mesh nodes."""
|
|
413
698
|
n_nodes = mesh.conn.shape[1]
|
|
414
699
|
if n_nodes == 10:
|
|
415
|
-
basis = make_tet10_basis_pytree(intorder if intorder > 1 else 2)
|
|
700
|
+
basis: Basis3D = make_tet10_basis_pytree(intorder if intorder > 1 else 2)
|
|
416
701
|
else:
|
|
417
702
|
basis = make_tet_basis_pytree(intorder)
|
|
418
703
|
element = None if dim == 1 else ElementVector(dim)
|
|
419
|
-
return make_space_pytree(mesh, basis, element)
|
|
704
|
+
return make_space_pytree(cast(BaseMeshPytree, mesh), basis, element)
|