fluxfem 0.1.3a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fluxfem might be problematic. Click here for more details.
- fluxfem/__init__.py +343 -0
- fluxfem/core/__init__.py +318 -0
- fluxfem/core/assembly.py +788 -0
- fluxfem/core/basis.py +996 -0
- fluxfem/core/data.py +64 -0
- fluxfem/core/dtypes.py +4 -0
- fluxfem/core/forms.py +234 -0
- fluxfem/core/interp.py +55 -0
- fluxfem/core/solver.py +113 -0
- fluxfem/core/space.py +419 -0
- fluxfem/core/weakform.py +828 -0
- fluxfem/helpers_ts.py +11 -0
- fluxfem/helpers_wf.py +44 -0
- fluxfem/mesh/__init__.py +29 -0
- fluxfem/mesh/base.py +244 -0
- fluxfem/mesh/hex.py +327 -0
- fluxfem/mesh/io.py +87 -0
- fluxfem/mesh/predicate.py +45 -0
- fluxfem/mesh/surface.py +257 -0
- fluxfem/mesh/tet.py +246 -0
- fluxfem/physics/__init__.py +53 -0
- fluxfem/physics/diffusion.py +18 -0
- fluxfem/physics/elasticity/__init__.py +39 -0
- fluxfem/physics/elasticity/hyperelastic.py +99 -0
- fluxfem/physics/elasticity/linear.py +58 -0
- fluxfem/physics/elasticity/materials.py +32 -0
- fluxfem/physics/elasticity/stress.py +46 -0
- fluxfem/physics/operators.py +109 -0
- fluxfem/physics/postprocess.py +113 -0
- fluxfem/solver/__init__.py +47 -0
- fluxfem/solver/bc.py +439 -0
- fluxfem/solver/cg.py +326 -0
- fluxfem/solver/dirichlet.py +126 -0
- fluxfem/solver/history.py +31 -0
- fluxfem/solver/newton.py +400 -0
- fluxfem/solver/result.py +62 -0
- fluxfem/solver/solve_runner.py +534 -0
- fluxfem/solver/solver.py +148 -0
- fluxfem/solver/sparse.py +188 -0
- fluxfem/tools/__init__.py +7 -0
- fluxfem/tools/jit.py +51 -0
- fluxfem/tools/timer.py +659 -0
- fluxfem/tools/visualizer.py +101 -0
- fluxfem-0.1.3a0.dist-info/LICENSE +201 -0
- fluxfem-0.1.3a0.dist-info/METADATA +125 -0
- fluxfem-0.1.3a0.dist-info/RECORD +47 -0
- fluxfem-0.1.3a0.dist-info/WHEEL +4 -0
fluxfem/core/basis.py
ADDED
|
@@ -0,0 +1,996 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Protocol
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
import numpy as np
|
|
7
|
+
from .dtypes import DEFAULT_DTYPE as _FDTYPE
|
|
8
|
+
# from .dtypes import DEFAULT_DTYPE
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def build_B_matrices(dN_dx: jnp.ndarray) -> jnp.ndarray:
    """
    Build small-strain B matrices (Voigt) for all quadrature points.

    Args:
        dN_dx: (n_q, n_nodes, 3) spatial gradients of the shape functions.
            The node count is read from the array shape (it used to be
            hard-coded to 8, which silently produced wrong output for any
            other element).

    Returns:
        B: (n_q, 6, 3 * n_nodes). Voigt order [xx, yy, zz, xy, yz, zx];
        shear rows hold engineering strains (e.g. gamma_xy = du_x/dy + du_y/dx).
    """
    n_nodes = dN_dx.shape[1]
    dofs_per_node = 3
    n_dofs = n_nodes * dofs_per_node

    def B_single(dN):
        # dN: (n_nodes, 3) for one quadrature point
        B = jnp.zeros((6, n_dofs), dtype=dN.dtype)

        # traced loop over nodes (jax.lax.fori_loop, not a Python loop)
        def body_fun(i, B):
            dNdx = dN[i, 0]
            dNdy = dN[i, 1]
            dNdz = dN[i, 2]
            col = dofs_per_node * i

            # eps_xx, eps_yy, eps_zz
            B = B.at[0, col + 0].set(dNdx)
            B = B.at[1, col + 1].set(dNdy)
            B = B.at[2, col + 2].set(dNdz)

            # gamma_xy
            B = B.at[3, col + 0].set(dNdy)
            B = B.at[3, col + 1].set(dNdx)

            # gamma_yz
            B = B.at[4, col + 1].set(dNdz)
            B = B.at[4, col + 2].set(dNdy)

            # gamma_zx
            B = B.at[5, col + 0].set(dNdz)
            B = B.at[5, col + 2].set(dNdx)
            return B

        return jax.lax.fori_loop(0, n_nodes, body_fun, B)

    return jax.vmap(B_single)(dN_dx)  # (n_q, 6, 3 * n_nodes)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def build_B_matrices_finite(dN_dX: jnp.ndarray, F: jnp.ndarray) -> jnp.ndarray:
    """
    Build finite-strain (Total Lagrangian) B matrices (Voigt) at each quadrature point.

    Column ``3*a + k`` is the variation of the Green-Lagrange strain produced
    by a unit virtual displacement of node ``a`` in direction ``k``:

        G  = e_k (x) dN_a/dX          (gradient of the virtual displacement)
        dE = 0.5 * (F^T G + (F^T G)^T)

    Args:
        dN_dX: (n_q, n_nodes, 3) gradients of shape functions in reference config.
        F: (n_q, 3, 3) deformation gradient at each quadrature point; assumed
            F[i, J] = dx_i/dX_J (TODO confirm against the caller).

    Returns:
        B: (n_q, 6, n_dofs) where n_dofs = 3 * n_nodes
        Voigt order: [xx, yy, zz, xy, yz, zx]

    NOTE(review): the shear rows store the tensor components dE_xy, dE_yz,
    dE_zx (no factor 2), whereas ``build_B_matrices`` stores engineering
    shear; at F = I the shear rows here are half of the small-strain ones.
    Confirm the stress Voigt convention this B is paired with.
    """
    _, n_nodes, _ = dN_dX.shape
    dofs_per_node = 3
    n_dofs = n_nodes * dofs_per_node

    def B_single(dN, Fq):
        B = jnp.zeros((6, n_dofs), dtype=dN.dtype)

        def body_fun(i, B):
            dNa = dN[i, :]  # (3,)
            col = dofs_per_node * i

            ex = jnp.array([1.0, 0.0, 0.0], dtype=dN.dtype)
            ey = jnp.array([0.0, 1.0, 0.0], dtype=dN.dtype)
            ez = jnp.array([0.0, 0.0, 1.0], dtype=dN.dtype)
            # grads[k] = e_k (x) dN_a: row k holds dN_a, all other rows zero
            grads = jnp.stack(
                (
                    jnp.outer(ex, dNa),  # (3,3)
                    jnp.outer(ey, dNa),
                    jnp.outer(ez, dNa),
                ),
                axis=0,
            )  # (3,3,3)

            def fill_dof(k, B):
                grad_delta = grads[k]
                # Total Lagrange variation dE = sym(F^T G). The previous
                # sym(G F) only coincides with this for diagonal F.
                FtG = Fq.T @ grad_delta
                dE = 0.5 * (FtG + FtG.T)
                B = B.at[0, col + k].set(dE[0, 0])  # xx
                B = B.at[1, col + k].set(dE[1, 1])  # yy
                B = B.at[2, col + k].set(dE[2, 2])  # zz
                B = B.at[3, col + k].set(dE[0, 1])  # xy
                B = B.at[4, col + k].set(dE[1, 2])  # yz
                B = B.at[5, col + k].set(dE[2, 0])  # zx
                return B

            return jax.lax.fori_loop(0, dofs_per_node, fill_dof, B)

        return jax.lax.fori_loop(0, n_nodes, body_fun, B)

    return jax.vmap(B_single)(dN_dX, F)  # (n_q, 6, n_dofs)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class SmallStrainBMixin:
    """Mixin adding the linear (small-strain) strain-displacement matrix."""

    dofs_per_node: int = 3

    def B_small_strain(self, dN_dx: jnp.ndarray) -> jnp.ndarray:
        """
        Assemble the small-strain B matrix at every quadrature point.

        dN_dx: (n_q, n_nodes, 3) spatial shape-function gradients.
        returns: (n_q, 6, dofs_per_node * n_nodes), Voigt rows
        [xx, yy, zz, xy, yz, zx] with engineering shear in rows 3-5.
        """
        _, node_count, _ = dN_dx.shape
        width = self.dofs_per_node * node_count

        # (Voigt row, dof offset inside the node, gradient axis) for each
        # non-zero entry of one node's block, in write order.
        pattern = (
            (0, 0, 0),  # xx: du_x/dx
            (1, 1, 1),  # yy: du_y/dy
            (2, 2, 2),  # zz: du_z/dz
            (3, 0, 1),  # xy: du_x/dy
            (3, 1, 0),  # xy: du_y/dx
            (4, 1, 2),  # yz: du_y/dz
            (4, 2, 1),  # yz: du_z/dy
            (5, 0, 2),  # zx: du_x/dz
            (5, 2, 0),  # zx: du_z/dx
        )

        def one_point(grad_q):
            def write_node(node, mat):
                base = self.dofs_per_node * node
                for row, offset, axis in pattern:
                    mat = mat.at[row, base + offset].set(grad_q[node, axis])
                return mat

            start = jnp.zeros((6, width), dtype=grad_q.dtype)
            return jax.lax.fori_loop(0, node_count, write_node, start)

        return jax.vmap(one_point)(dN_dx)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class TotalLagrangeBMixin:
    """Mixin adding the Total Lagrangian (finite-strain) B matrix."""

    dofs_per_node: int = 3

    def B_total_lagrange(self, dN_dX: jnp.ndarray, F: jnp.ndarray) -> jnp.ndarray:
        """
        Variation of the Green-Lagrange strain w.r.t. nodal displacements.

        For node ``a`` and direction ``k`` the column is
        dE = 0.5 * (F^T G + (F^T G)^T) with G = e_k (x) dN_a/dX.

        dN_dX: (n_q, n_nodes, 3), F: (n_q, 3, 3), assumed F[i, J] = dx_i/dX_J
        (TODO confirm against the caller).
        returns: (n_q, 6, dofs_per_node*n_nodes) Voigt: [xx,yy,zz,xy,yz,zx]

        NOTE(review): shear rows store tensor components dE_ij (no factor 2),
        whereas ``SmallStrainBMixin.B_small_strain`` stores engineering shear;
        at F = I the shear rows here are half of the small-strain ones.
        """
        _, n_nodes, _ = dN_dX.shape
        n_dofs = self.dofs_per_node * n_nodes

        ex = jnp.array([1.0, 0.0, 0.0], dtype=dN_dX.dtype)
        ey = jnp.array([0.0, 1.0, 0.0], dtype=dN_dX.dtype)
        ez = jnp.array([0.0, 0.0, 1.0], dtype=dN_dX.dtype)

        def B_single(dN, Fq):
            B = jnp.zeros((6, n_dofs), dtype=dN.dtype)

            def node_fun(i, B):
                dNa = dN[i, :]
                col = self.dofs_per_node * i

                # grads[k] = e_k (x) dN_a (row k holds dN_a)
                grads = jnp.stack(
                    (jnp.outer(ex, dNa), jnp.outer(ey, dNa), jnp.outer(ez, dNa)),
                    axis=0,
                )  # (3,3,3)

                def dof_fun(k, B):
                    # dE = sym(F^T G); the previous sym(G F) only matched
                    # this for diagonal F.
                    FtG = Fq.T @ grads[k]
                    dE = 0.5 * (FtG + FtG.T)
                    B = B.at[0, col + k].set(dE[0, 0])
                    B = B.at[1, col + k].set(dE[1, 1])
                    B = B.at[2, col + k].set(dE[2, 2])
                    B = B.at[3, col + k].set(dE[0, 1])
                    B = B.at[4, col + k].set(dE[1, 2])
                    B = B.at[5, col + k].set(dE[2, 0])
                    return B

                return jax.lax.fori_loop(0, self.dofs_per_node, dof_fun, B)

            return jax.lax.fori_loop(0, n_nodes, node_fun, B)

        return jax.vmap(B_single)(dN_dX, F)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class Basis3D(Protocol):
    """Structural interface implemented by the 3D element bases in this module.

    Attributes:
        quad_points: (n_q, 3) reference-space quadrature points.
        quad_weights: (n_q,) quadrature weights.
        dofs_per_node: unknowns carried per node (usually 3 for vector mechanics).
    """

    quad_points: jnp.ndarray
    quad_weights: jnp.ndarray
    dofs_per_node: int

    @property
    def n_q(self) -> int: ...

    @property
    def n_nodes(self) -> int: ...

    def shape_functions(self) -> jnp.ndarray: ...

    def shape_grads_ref(self) -> jnp.ndarray: ...

    def spatial_grads_and_detJ(
        self, elem_coords: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]: ...

    def B_small_strain(self, dN_dx: jnp.ndarray) -> jnp.ndarray: ...

    def B_total_lagrange(
        self, dN_dX: jnp.ndarray, F: jnp.ndarray
    ) -> jnp.ndarray: ...
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _quadratic_1d(x: jnp.ndarray):
|
|
226
|
+
"""1D serendipity shape funcs (vertex, mid-edge) for [-1,1]."""
|
|
227
|
+
N1 = -0.5 * x * (1.0 - x) # at -1
|
|
228
|
+
N2 = 1.0 - x * x # at 0
|
|
229
|
+
N3 = 0.5 * x * (1.0 + x) # at 1
|
|
230
|
+
dN1 = 0.5 * (2.0 * x - 1.0)
|
|
231
|
+
dN2 = -2.0 * x
|
|
232
|
+
dN3 = 0.5 * (2.0 * x + 1.0)
|
|
233
|
+
return (N1, N2, N3), (dN1, dN2, dN3)
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def _quad1d_full(x: jnp.ndarray):
|
|
237
|
+
"""1D quadratic (Lagrange) shape funcs at nodes (-1, 0, 1)."""
|
|
238
|
+
N0 = 0.5 * x * (x - 1.0)
|
|
239
|
+
N1 = 1.0 - x * x
|
|
240
|
+
N2 = 0.5 * x * (x + 1.0)
|
|
241
|
+
dN0 = x - 0.5
|
|
242
|
+
dN1 = -2.0 * x
|
|
243
|
+
dN2 = x + 0.5
|
|
244
|
+
return (N0, N1, N2), (dN0, dN1, dN2)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
@dataclass(eq=False)
class TetLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
    """4-node linear tetra basis with a user-supplied quadrature rule.

    Reference element vertices: (0,0,0), (1,0,0), (0,1,0), (0,0,1). The shape
    functions are the barycentric coordinates, so their reference gradients
    are constant over the element.
    """

    quad_points: jnp.ndarray  # (n_q, 3)
    quad_weights: jnp.ndarray  # (n_q,)
    dofs_per_node: int = 3

    @property
    def n_nodes(self) -> int:
        return 4

    def tree_flatten(self):
        """Split into array leaves and hashable static aux data.

        ``dofs_per_node`` is carried in aux_data so that it survives a
        flatten/unflatten round trip. Previously it was dropped and silently
        reset to the default.
        """
        children = (self.quad_points, self.quad_weights)
        return children, (self.dofs_per_node,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        qp, qw = children
        # ``*aux_data`` also accepts an empty legacy aux ({}), in which case
        # the dataclass default for dofs_per_node applies.
        return cls(qp, qw, *aux_data)

    @property
    def n_q(self) -> int:
        return int(self.quad_points.shape[0])

    @property
    def ref_node_coords(self) -> jnp.ndarray:
        return jnp.array(
            [
                [0.0, 0.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 0.0, 1.0],
            ],
            dtype=_FDTYPE,
        )

    def _ref_grads_const(self) -> jnp.ndarray:
        """Constant (4, 3) reference gradients dN_i/d(xi, eta, zeta)."""
        return jnp.array(
            [
                [-1.0, -1.0, -1.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 0.0, 1.0],
            ],
            dtype=_FDTYPE,
        )

    def shape_functions(self) -> jnp.ndarray:
        """Shape function values at the quadrature points, (n_q, 4)."""
        qp = self.quad_points  # (n_q, 3)
        xi = qp[:, 0]
        eta = qp[:, 1]
        zeta = qp[:, 2]
        N1 = 1.0 - xi - eta - zeta
        N2 = xi
        N3 = eta
        N4 = zeta
        return jnp.stack([N1, N2, N3, N4], axis=1)  # (n_q,4)

    def shape_grads_ref(self) -> jnp.ndarray:
        """Reference gradients repeated per quadrature point, (n_q, 4, 3)."""
        dN = self._ref_grads_const()
        return jnp.tile(dN[None, :, :], (self.n_q, 1, 1))  # (n_q,4,3)

    def spatial_grads_and_detJ(
        self, elem_coords: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Spatial gradients (n_q, 4, 3) and detJ (n_q,) for one element.

        elem_coords: (4, 3) physical node coordinates. The mapping is affine,
        so one Jacobian is computed and broadcast to every quadrature point.
        """
        # Use the constant directly (no tile-then-index; also works for n_q == 0).
        dN_dxi = self._ref_grads_const()  # (4,3)
        J = jnp.einsum("ia,ik->ak", elem_coords, dN_dxi)  # (3,3), J[a,k] = dx_a/dxi_k
        J_inv = jnp.linalg.inv(J)
        detJ = jnp.linalg.det(J)
        dN_dx = jnp.einsum("ik,ka->ia", dN_dxi, J_inv)  # (4,3)
        dN_dx = jnp.tile(dN_dx[None, :, :], (self.n_q, 1, 1))
        detJ = jnp.full((self.n_q,), detJ, dtype=elem_coords.dtype)
        return dN_dx, detJ
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
@dataclass(eq=False)
class TetQuadraticBasis10(SmallStrainBMixin, TotalLagrangeBMixin):
    """10-node quadratic tetra basis (4 corners + 6 edge midpoints).

    Written in barycentric coordinates of the reference tetra
    (0,0,0)-(1,0,0)-(0,1,0)-(0,0,1): L1 = 1 - xi - eta - zeta, L2 = xi,
    L3 = eta, L4 = zeta.

    - Corner node i: N = Li (2 Li - 1).
    - Edge node on (a, b): N = 4 La Lb, with edges in node order (1,2), (2,3),
      (1,3), (1,4), (2,4), (3,4) (1-based corners).
    """

    quad_points: jnp.ndarray  # (n_q, 3)
    quad_weights: jnp.ndarray  # (n_q,)
    dofs_per_node: int = 3

    # 0-based corner pairs of the six edge nodes, in node order 4..9
    # (plain class attribute, not a dataclass field).
    _EDGE_PAIRS = ((0, 1), (1, 2), (0, 2), (0, 3), (1, 3), (2, 3))

    @property
    def n_nodes(self) -> int:
        return 10

    def tree_flatten(self):
        """Array leaves plus hashable aux data (keeps dofs_per_node on round trip)."""
        return (self.quad_points, self.quad_weights), (self.dofs_per_node,)

    @classmethod
    def tree_unflatten(cls, aux, children):
        qp, qw = children
        # ``*aux`` also accepts a legacy empty aux ({}) -> default dofs_per_node
        return cls(qp, qw, *aux)

    @property
    def n_q(self) -> int:
        return int(self.quad_points.shape[0])

    @staticmethod
    def _barycentric(qp: jnp.ndarray):
        """Barycentric coordinates (L1, L2, L3, L4), each of shape (n_q,)."""
        return (1.0 - qp[:, 0] - qp[:, 1] - qp[:, 2], qp[:, 0], qp[:, 1], qp[:, 2])

    def shape_functions(self) -> jnp.ndarray:
        """Shape function values at the quadrature points, (n_q, 10)."""
        L = self._barycentric(self.quad_points)
        corners = [La * (2 * La - 1) for La in L]
        edges = [4 * L[a] * L[b] for a, b in self._EDGE_PAIRS]
        return jnp.stack(corners + edges, axis=1)

    def shape_grads_ref(self) -> jnp.ndarray:
        """Reference gradients d/d(xi, eta, zeta), (n_q, 10, 3)."""
        L = self._barycentric(self.quad_points)
        # dL[i] = gradient of L_{i+1} w.r.t. (xi, eta, zeta)
        dL = jnp.array(
            [
                [-1.0, -1.0, -1.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 0.0, 1.0],
            ],
            dtype=_FDTYPE,
        )

        # d[L (2L - 1)] = (4L - 1) dL. The former (2L - 1) factor broke the
        # partition of unity (sum of all gradients must vanish).
        corners = [(4 * L[i] - 1)[:, None] * dL[i][None, :] for i in range(4)]
        # d[4 La Lb] = 4 (Lb dLa + La dLb)
        edges = [
            4 * (L[b][:, None] * dL[a][None, :] + L[a][:, None] * dL[b][None, :])
            for a, b in self._EDGE_PAIRS
        ]
        return jnp.stack(corners + edges, axis=1)  # (n_q, 10, 3)

    def spatial_grads_and_detJ(
        self, elem_coords: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Spatial gradients (n_q, 10, 3) and detJ (n_q,) for elem_coords (10, 3)."""
        dN_dxi = self.shape_grads_ref()  # (n_q,10,3)
        J = jnp.einsum("ia,qik->qak", elem_coords, dN_dxi)
        J_inv = jnp.linalg.inv(J)
        detJ = jnp.linalg.det(J)
        dN_dx = jnp.einsum("qik,qka->qia", dN_dxi, J_inv)
        return dN_dx, detJ
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
@dataclass(eq=False)
class HexTriLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
    """
    Trilinear 8-node hex element basis with a given quadrature rule.

    quad_points: (n_q, 3)  # (xi, eta, zeta) in [-1, 1]^3
    quad_weights: (n_q,)   # weights
    """

    quad_points: jnp.ndarray  # (n_q, 3)
    quad_weights: jnp.ndarray  # (n_q,)
    dofs_per_node: int = 3

    @property
    def n_nodes(self) -> int:
        return 8

    def tree_flatten(self):
        """Array leaves plus hashable aux data.

        ``dofs_per_node`` is kept in aux_data so it survives a
        flatten/unflatten round trip. Previously it was dropped.
        """
        children = (self.quad_points, self.quad_weights)
        aux_data = (self.dofs_per_node,)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        qp, qw = children
        # ``*aux_data`` also accepts a legacy empty aux ({}) -> default value
        return cls(qp, qw, *aux_data)

    @property
    def n_q(self) -> int:
        return int(self.quad_points.shape[0])

    @property
    def ref_node_signs(self) -> jnp.ndarray:
        """
        Node signs (8,3) for (-1,1)^3 reference hex.
        Node ordering:
            0: (-1,-1,-1)
            1: ( 1,-1,-1)
            2: ( 1, 1,-1)
            3: (-1, 1,-1)
            4: (-1,-1, 1)
            5: ( 1,-1, 1)
            6: ( 1, 1, 1)
            7: (-1, 1, 1)
        """
        return jnp.array(
            [
                [-1.0, -1.0, -1.0],
                [ 1.0, -1.0, -1.0],
                [ 1.0,  1.0, -1.0],
                [-1.0,  1.0, -1.0],
                [-1.0, -1.0,  1.0],
                [ 1.0, -1.0,  1.0],
                [ 1.0,  1.0,  1.0],
                [-1.0,  1.0,  1.0],
            ],
            dtype=_FDTYPE,
        )

    def _linear_factors(self):
        """Node signs s (8, 3) and the factors (1 + s_k * x_k), each (n_q, 8)."""
        qp = self.quad_points  # (n_q, 3)
        s = self.ref_node_signs  # (8, 3)
        # (n_q, 1) * (8,) broadcasts to (n_q, 8)
        f_xi = 1.0 + qp[:, 0:1] * s[:, 0]
        f_eta = 1.0 + qp[:, 1:2] * s[:, 1]
        f_zeta = 1.0 + qp[:, 2:3] * s[:, 2]
        return s, f_xi, f_eta, f_zeta

    # ---------- reference shape functions & gradients ----------
    def shape_functions(self) -> jnp.ndarray:
        """
        Evaluate shape functions at all quadrature points.
        Returns: (n_q, 8)
        """
        _, f_xi, f_eta, f_zeta = self._linear_factors()
        return 0.125 * f_xi * f_eta * f_zeta

    def shape_grads_ref(self) -> jnp.ndarray:
        """
        Gradients in reference coords (xi, eta, zeta) at all quad points.
        Returns: (n_q, 8, 3) [dN/dxi, dN/deta, dN/dzeta]
        """
        s, f_xi, f_eta, f_zeta = self._linear_factors()
        dN_dxi = 0.125 * s[:, 0] * f_eta * f_zeta  # (n_q, 8)
        dN_deta = 0.125 * s[:, 1] * f_xi * f_zeta
        dN_dzeta = 0.125 * s[:, 2] * f_xi * f_eta
        return jnp.stack([dN_dxi, dN_deta, dN_dzeta], axis=-1)  # (n_q, 8, 3)

    # ---------- mapping to physical element ----------
    def spatial_grads_and_detJ(
        self, elem_coords: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Compute spatial gradients and detJ for one element.

        Parameters
        ----------
        elem_coords : jnp.ndarray
            Element coordinates of shape (8, 3).

        Returns
        -------
        dN_dx : jnp.ndarray
            Spatial gradients of shape (n_q, 8, 3).
        detJ : jnp.ndarray
            Determinant of the Jacobian of shape (n_q,).
        """
        dN_dxi = self.shape_grads_ref()  # (n_q, 8, 3)

        # Jacobian: J(q)[a,k] = sum_i X_i[a] * dN_i/dxi_k
        J = jnp.einsum("ia,qik->qak", elem_coords, dN_dxi)  # (n_q, 3, 3)

        J_inv = jnp.linalg.inv(J)  # (n_q, 3, 3)
        detJ = jnp.linalg.det(J)  # (n_q,)

        # grad_x N = grad_xi N . J^{-1}
        dN_dx = jnp.einsum("qik,qka->qia", dN_dxi, J_inv)  # (n_q, 8, 3)
        return dN_dx, detJ
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
@dataclass(eq=False)
class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
    """
    20-node serendipity hex element basis (8 corners + 12 edge midpoints)
    on the reference cube [-1, 1]^3.

    NOTE(review): the edge-node order produced by ``shape_functions`` and
    ``shape_grads_ref`` (xi-edges at (eta, zeta) = (-1,-1), (1,-1), (1,1),
    (-1,1), then eta-edges at (xi, zeta), then zeta-edges at (xi, eta)) does
    NOT match ``ref_node_coords``. For example, slot 9 of the shape functions
    is the node at (0, 1, -1), while ref_node_coords[9] is (1, 0, -1). As a
    result N_i(ref_node_coords[j]) is not the identity. Confirm which ordering
    the hex20 mesh connectivity uses before relying on either.
    """

    quad_points: jnp.ndarray  # (n_q, 3)
    quad_weights: jnp.ndarray  # (n_q,)
    dofs_per_node: int = 3

    # Sign pairs of the two fixed coordinates for the four edges along each
    # axis (plain class attribute, not a dataclass field).
    _EDGE_SIGNS = ((-1, -1), (1, -1), (1, 1), (-1, 1))

    @property
    def n_nodes(self) -> int:
        return 20

    def tree_flatten(self):
        """Array leaves plus hashable aux data (keeps dofs_per_node on round trip)."""
        children = (self.quad_points, self.quad_weights)
        aux_data = (self.dofs_per_node,)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        qp, qw = children
        # ``*aux_data`` also accepts a legacy empty aux ({}) -> default value
        return cls(qp, qw, *aux_data)

    @property
    def n_q(self) -> int:
        return int(self.quad_points.shape[0])

    @property
    def ref_node_coords(self) -> jnp.ndarray:
        corners = jnp.array(
            [
                [-1.0, -1.0, -1.0],
                [ 1.0, -1.0, -1.0],
                [ 1.0,  1.0, -1.0],
                [-1.0,  1.0, -1.0],
                [-1.0, -1.0,  1.0],
                [ 1.0, -1.0,  1.0],
                [ 1.0,  1.0,  1.0],
                [-1.0,  1.0,  1.0],
            ],
            dtype=_FDTYPE,
        )
        edges = jnp.array(
            [
                [ 0.0, -1.0, -1.0],  # 0-1
                [ 1.0,  0.0, -1.0],  # 1-2
                [ 0.0,  1.0, -1.0],  # 2-3
                [-1.0,  0.0, -1.0],  # 3-0
                [ 0.0, -1.0,  1.0],  # 4-5
                [ 1.0,  0.0,  1.0],  # 5-6
                [ 0.0,  1.0,  1.0],  # 6-7
                [-1.0,  0.0,  1.0],  # 7-4
                [-1.0, -1.0,  0.0],  # 0-4
                [ 1.0, -1.0,  0.0],  # 1-5
                [ 1.0,  1.0,  0.0],  # 2-6
                [-1.0,  1.0,  0.0],  # 3-7
            ],
            dtype=_FDTYPE,
        )
        return jnp.concatenate([corners, edges], axis=0)  # (20,3)

    def shape_functions(self) -> jnp.ndarray:
        """Shape function values at the quadrature points, (n_q, 20)."""
        qp = self.quad_points  # (n_q, 3)
        xi = qp[:, 0:1]
        eta = qp[:, 1:2]
        zeta = qp[:, 2:3]

        # corners: same sign table as the first 8 rows of ref_node_coords
        s = self.ref_node_coords[:8]
        sx = s[:, 0]
        sy = s[:, 1]
        sz = s[:, 2]
        term = xi * sx + eta * sy + zeta * sz - 2.0  # (n_q,8)
        N_corner = 0.125 * (1 + sx * xi) * (1 + sy * eta) * (1 + sz * zeta) * term

        # edges (loop names a, b avoid shadowing the corner sign arrays)
        N_edges = []
        for a, b in self._EDGE_SIGNS:  # along xi: (eta, zeta) = (a, b)
            N_edges.append(0.25 * (1 - xi * xi) * (1 + a * eta) * (1 + b * zeta))
        for a, b in self._EDGE_SIGNS:  # along eta: (xi, zeta) = (a, b)
            N_edges.append(0.25 * (1 - eta * eta) * (1 + a * xi) * (1 + b * zeta))
        for a, b in self._EDGE_SIGNS:  # along zeta: (xi, eta) = (a, b)
            N_edges.append(0.25 * (1 - zeta * zeta) * (1 + a * xi) * (1 + b * eta))

        N_edges = jnp.concatenate(N_edges, axis=1)  # (n_q, 12)
        return jnp.concatenate([N_corner, N_edges], axis=1)  # (n_q,20)

    def shape_grads_ref(self) -> jnp.ndarray:
        """Reference gradients, (n_q, 20, 3), in the same node order as shape_functions."""
        qp = self.quad_points
        xi = qp[:, 0:1]
        eta = qp[:, 1:2]
        zeta = qp[:, 2:3]

        s = self.ref_node_coords[:8]
        sx = s[:, 0]
        sy = s[:, 1]
        sz = s[:, 2]
        term = xi * sx + eta * sy + zeta * sz - 2.0

        # product rule on (1+sx xi)(1+sy eta)(1+sz zeta) * term / 8
        dN_dxi_corner = (sx / 8.0) * (1 + sy * eta) * (1 + sz * zeta) * (term + (1 + sx * xi))
        dN_deta_corner = (sy / 8.0) * (1 + sx * xi) * (1 + sz * zeta) * (term + (1 + sy * eta))
        dN_dzeta_corner = (sz / 8.0) * (1 + sx * xi) * (1 + sy * eta) * (term + (1 + sz * zeta))
        d_corner = jnp.stack(
            [dN_dxi_corner, dN_deta_corner, dN_dzeta_corner], axis=2
        )  # (n_q,8,3)

        d_list = []
        for a, b in self._EDGE_SIGNS:  # along xi: (eta, zeta) = (a, b)
            dxi = -0.5 * xi * (1 + a * eta) * (1 + b * zeta)
            deta = 0.25 * (1 - xi * xi) * a * (1 + b * zeta)
            dzeta = 0.25 * (1 - xi * xi) * (1 + a * eta) * b
            d_list.append(jnp.stack([dxi, deta, dzeta], axis=2))
        for a, b in self._EDGE_SIGNS:  # along eta: (xi, zeta) = (a, b)
            dxi = 0.25 * (1 - eta * eta) * a * (1 + b * zeta)
            deta = -0.5 * eta * (1 + a * xi) * (1 + b * zeta)
            dzeta = 0.25 * (1 - eta * eta) * (1 + a * xi) * b
            d_list.append(jnp.stack([dxi, deta, dzeta], axis=2))
        for a, b in self._EDGE_SIGNS:  # along zeta: (xi, eta) = (a, b)
            dxi = 0.25 * (1 - zeta * zeta) * a * (1 + b * eta)
            deta = 0.25 * (1 - zeta * zeta) * (1 + a * xi) * b
            dzeta = -0.5 * zeta * (1 + a * xi) * (1 + b * eta)
            d_list.append(jnp.stack([dxi, deta, dzeta], axis=2))

        d_edges = jnp.concatenate(d_list, axis=1)  # (n_q,12,3)
        return jnp.concatenate([d_corner, d_edges], axis=1)  # (n_q,20,3)

    def spatial_grads_and_detJ(
        self, elem_coords: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Spatial gradients (n_q, 20, 3) and detJ (n_q,) for elem_coords (20, 3)."""
        dN_dxi = self.shape_grads_ref()  # (n_q, 20, 3)
        J = jnp.einsum("ia,qik->qak", elem_coords, dN_dxi)  # (n_q, 3, 3)
        J_inv = jnp.linalg.inv(J)
        detJ = jnp.linalg.det(J)
        dN_dx = jnp.einsum("qik,qka->qia", dN_dxi, J_inv)
        return dN_dx, detJ
|
|
740
|
+
|
|
741
|
+
|
|
742
|
+
@dataclass(eq=False)
class HexTriQuadraticBasis27(SmallStrainBMixin, TotalLagrangeBMixin):
    """27-node triquadratic hex (tensor product of 1D quadratic Lagrange).

    Node index = i + 3*j + 9*k, where i, j, k in {0, 1, 2} select the 1D node
    (-1, 0, +1) along xi, eta and zeta respectively (xi varies fastest).
    """

    quad_points: jnp.ndarray  # (n_q, 3)
    quad_weights: jnp.ndarray  # (n_q,)
    dofs_per_node: int = 3

    @property
    def n_nodes(self) -> int:
        return 27

    def tree_flatten(self):
        """Array leaves plus hashable aux data (keeps dofs_per_node on round trip)."""
        return (self.quad_points, self.quad_weights), (self.dofs_per_node,)

    @classmethod
    def tree_unflatten(cls, aux, children):
        qp, qw = children
        # ``*aux`` also accepts a legacy empty aux ({}) -> default dofs_per_node
        return cls(qp, qw, *aux)

    @property
    def n_q(self) -> int:
        return int(self.quad_points.shape[0])

    def shape_functions(self) -> jnp.ndarray:
        """Shape function values at the quadrature points, (n_q, 27)."""
        qp = self.quad_points
        Nx, _ = _quad1d_full(qp[:, 0])
        Ny, _ = _quad1d_full(qp[:, 1])
        Nz, _ = _quad1d_full(qp[:, 2])
        N = [
            Nx[i] * Ny[j] * Nz[k]
            for k in range(3)
            for j in range(3)
            for i in range(3)
        ]
        return jnp.stack(N, axis=1)  # (n_q, 27)

    def shape_grads_ref(self) -> jnp.ndarray:
        """Reference gradients d/d(xi, eta, zeta), (n_q, 27, 3)."""
        qp = self.quad_points
        Nx, dNx = _quad1d_full(qp[:, 0])
        Ny, dNy = _quad1d_full(qp[:, 1])
        Nz, dNz = _quad1d_full(qp[:, 2])
        grads = [
            jnp.stack(
                [
                    dNx[i] * Ny[j] * Nz[k],
                    Nx[i] * dNy[j] * Nz[k],
                    Nx[i] * Ny[j] * dNz[k],
                ],
                axis=1,
            )
            for k in range(3)
            for j in range(3)
            for i in range(3)
        ]
        return jnp.stack(grads, axis=1)  # (n_q, 27, 3)

    def spatial_grads_and_detJ(
        self, elem_coords: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Spatial gradients (n_q, 27, 3) and detJ (n_q,) for elem_coords (27, 3)."""
        dN_dxi = self.shape_grads_ref()  # (n_q, 27, 3)
        J = jnp.einsum("ia,qik->qak", elem_coords, dN_dxi)
        J_inv = jnp.linalg.inv(J)
        detJ = jnp.linalg.det(J)
        dN_dx = jnp.einsum("qik,qka->qia", dN_dxi, J_inv)
        return dN_dx, detJ
|
|
817
|
+
|
|
818
|
+
|
|
819
|
+
@jax.tree_util.register_pytree_node_class
class TetLinearBasisPytree(TetLinearBasis):
    """TetLinearBasis registered as a JAX pytree via its tree_flatten/tree_unflatten."""
|
|
822
|
+
|
|
823
|
+
|
|
824
|
+
@jax.tree_util.register_pytree_node_class
class TetQuadraticBasis10Pytree(TetQuadraticBasis10):
    """TetQuadraticBasis10 registered as a JAX pytree via its tree_flatten/tree_unflatten."""
|
|
827
|
+
|
|
828
|
+
|
|
829
|
+
@jax.tree_util.register_pytree_node_class
class HexTriLinearBasisPytree(HexTriLinearBasis):
    """HexTriLinearBasis registered as a JAX pytree via its tree_flatten/tree_unflatten."""
|
|
832
|
+
|
|
833
|
+
|
|
834
|
+
@jax.tree_util.register_pytree_node_class
class HexSerendipityBasis20Pytree(HexSerendipityBasis20):
    """HexSerendipityBasis20 registered as a JAX pytree via its tree_flatten/tree_unflatten."""
|
|
837
|
+
|
|
838
|
+
|
|
839
|
+
@jax.tree_util.register_pytree_node_class
class HexTriQuadraticBasis27Pytree(HexTriQuadraticBasis27):
    """HexTriQuadraticBasis27 registered as a JAX pytree via its tree_flatten/tree_unflatten."""
|
|
842
|
+
|
|
843
|
+
|
|
844
|
+
def _gauss_legendre_1d(order: int) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Return ``order``-point Gauss-Legendre nodes and weights on [-1, 1].

    Both arrays are converted to the module float dtype ``_FDTYPE``.

    Raises:
        ValueError: if ``order`` is not positive.
    """
    if order <= 0:
        raise ValueError("quadrature order must be positive")
    nodes, weights = np.polynomial.legendre.leggauss(order)
    return tuple(jnp.array(arr, dtype=_FDTYPE) for arr in (nodes, weights))
|
|
850
|
+
|
|
851
|
+
|
|
852
|
+
def _gl_points_for_degree(degree: int) -> int:
|
|
853
|
+
"""
|
|
854
|
+
Map polynomial exactness degree to Gauss-Legendre point count in 1D.
|
|
855
|
+
n points integrate degree (2n-1) exactly.
|
|
856
|
+
"""
|
|
857
|
+
if degree <= 0:
|
|
858
|
+
return 1
|
|
859
|
+
return int(np.ceil((degree + 1) / 2))
|
|
860
|
+
|
|
861
|
+
|
|
862
|
+
def _tet_collapsed_gauss(degree: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Collapsed-coordinate (Duffy) Gauss-Legendre rule on the reference tetra.

    The unit cube (u, v, w) is mapped onto {x, y, z >= 0, x + y + z <= 1} by
    x = u, y = (1 - u) v, z = (1 - u)(1 - v) w, whose Jacobian determinant is
    (1 - u)^2 (1 - v). A monomial of total degree <= ``degree`` becomes a
    polynomial of degree <= degree + 2 in u, degree + 1 in v and degree in w,
    so 1D Gauss rules of those exactness degrees make the product rule exact.
    All weights are positive; the rule is not symmetric.

    Returns numpy float64 arrays: points (n_q, 3) and weights (n_q,), with the
    weights summing to 1/6.
    """

    def gauss_01(exact_degree: int) -> tuple[np.ndarray, np.ndarray]:
        # n Gauss points integrate degree 2n - 1 exactly; map [-1, 1] -> [0, 1].
        n_pts = exact_degree // 2 + 1
        pts, wts = np.polynomial.legendre.leggauss(n_pts)
        return 0.5 * (pts + 1.0), 0.5 * wts

    u1, wu = gauss_01(degree + 2)
    v1, wv = gauss_01(degree + 1)
    w1, ww = gauss_01(degree)

    U, V, W = np.meshgrid(u1, v1, w1, indexing="ij")
    WU, WV, WW = np.meshgrid(wu, wv, ww, indexing="ij")

    x = U
    y = (1.0 - U) * V
    z = (1.0 - U) * (1.0 - V) * W
    weights = WU * WV * WW * (1.0 - U) ** 2 * (1.0 - V)

    points = np.stack([x, y, z], axis=-1).reshape(-1, 3)
    return points, weights.reshape(-1)


def _tet_quadrature(degree: int) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Degree-based quadrature rules for the reference tetra (volume = 1/6).

    ``degree`` is the total polynomial degree the rule must integrate exactly:

    * degree <= 1: 1-point centroid rule;
    * degree == 2: symmetric 4-point rule;
    * degree == 3: 5-point rule (centroid + 4 symmetric points, with a
      negative centroid weight);
    * degree >= 4: collapsed-coordinate Gauss-Legendre product rule exact for
      ``degree``, so higher requests are no longer served by a lower-order rule.

    Returns:
        (qp, qw): points (n_q, 3) in reference coordinates and weights (n_q,)
        summing to 1/6, both cast to ``_FDTYPE``.
    """
    if degree <= 1:
        qp = jnp.array([[0.25, 0.25, 0.25]], dtype=_FDTYPE)
        qw = jnp.array([1.0 / 6.0], dtype=_FDTYPE)
        return qp, qw
    if degree <= 2:
        # Symmetric 4-point rule: one barycentric coordinate equals a and the
        # other three equal b, with a = (5 + 3*sqrt(5))/20, b = (5 - sqrt(5))/20.
        # Computed exactly instead of hard-coding 8 digits, which limited the
        # rule's accuracy to ~1e-8 even in float64.
        sqrt5 = np.sqrt(5.0)
        a = (5.0 + 3.0 * sqrt5) / 20.0
        b = (5.0 - sqrt5) / 20.0
        qp = jnp.array(
            [
                [a, b, b],
                [b, a, b],
                [b, b, a],
                [b, b, b],
            ],
            dtype=_FDTYPE,
        )
        qw = jnp.full((4,), (1.0 / 24.0), dtype=_FDTYPE)
        return qp, qw
    if degree <= 3:
        # degree 3 rule: centroid + 4 symmetric points
        qp = jnp.array(
            [
                [0.25, 0.25, 0.25],
                [0.50, 1.0 / 6.0, 1.0 / 6.0],
                [1.0 / 6.0, 0.50, 1.0 / 6.0],
                [1.0 / 6.0, 1.0 / 6.0, 0.50],
                [1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0],
            ],
            dtype=_FDTYPE,
        )
        qw = jnp.array(
            [-2.0 / 15.0, 3.0 / 40.0, 3.0 / 40.0, 3.0 / 40.0, 3.0 / 40.0],
            dtype=_FDTYPE,
        )
        return qp, qw
    # Higher degrees: tensor Gauss rule on collapsed coordinates, exact for the
    # requested degree (ceil guards against non-integer requests).
    pts, wts = _tet_collapsed_gauss(int(np.ceil(degree)))
    return jnp.array(pts, dtype=_FDTYPE), jnp.array(wts, dtype=_FDTYPE)
|
|
899
|
+
|
|
900
|
+
|
|
901
|
+
def make_tet_basis(intorder: int = 2) -> TetLinearBasis:
    """
    Build a linear tetrahedral basis.

    ``intorder`` is the polynomial degree the quadrature must integrate
    exactly; it is forwarded unchanged to ``_tet_quadrature``.
    """
    points, weights = _tet_quadrature(intorder)
    basis = TetLinearBasis(points, weights)
    return basis
|
|
905
|
+
|
|
906
|
+
|
|
907
|
+
def make_tet_basis_pytree(intorder: int = 2) -> TetLinearBasisPytree:
    """
    Build the pytree-registered linear tetrahedral basis.

    ``intorder`` is the exactness degree handed to ``_tet_quadrature``.
    """
    return TetLinearBasisPytree(*_tet_quadrature(intorder))
|
|
911
|
+
|
|
912
|
+
|
|
913
|
+
def make_hex_basis(intorder: int = 2) -> HexTriLinearBasis:
    """
    Trilinear hex basis with tensor-product Gauss-Legendre quadrature.

    ``intorder`` is the polynomial exactness degree (scikit-fem style). It is
    converted to n = ceil((intorder + 1) / 2) points per axis:
    degree <= 1 -> 1x1x1, degree 2/3 -> 2x2x2, degree 4/5 -> 3x3x3, etc.
    The rule therefore has n**3 points (not intorder**3).
    """
    n_axis = _gl_points_for_degree(intorder)
    nodes_1d, weights_1d = _gauss_legendre_1d(n_axis)

    # Flatten the (n, n, n) grid in C order, so xi varies slowest.
    gx, gy, gz = jnp.meshgrid(nodes_1d, nodes_1d, nodes_1d, indexing="ij")
    qp = jnp.stack([gx.ravel(), gy.ravel(), gz.ravel()], axis=1)  # (n**3, 3)

    wgrid = jnp.meshgrid(weights_1d, weights_1d, weights_1d, indexing="ij")
    qw = jnp.prod(jnp.stack(wgrid, axis=-1), axis=-1).ravel()  # (n**3,)
    return HexTriLinearBasis(qp, qw)
|
|
927
|
+
|
|
928
|
+
|
|
929
|
+
def make_hex_basis_pytree(intorder: int = 2) -> HexTriLinearBasisPytree:
    """
    Pytree-registered trilinear hex basis with tensor-product Gauss quadrature.

    ``intorder`` is a polynomial exactness degree; each axis gets
    n = ceil((intorder + 1) / 2) Gauss-Legendre points, for n**3 in total.
    """
    n_axis = _gl_points_for_degree(intorder)
    nodes_1d, weights_1d = _gauss_legendre_1d(n_axis)

    gx, gy, gz = jnp.meshgrid(nodes_1d, nodes_1d, nodes_1d, indexing="ij")
    points = jnp.stack([gx.ravel(), gy.ravel(), gz.ravel()], axis=1)

    wgrid = jnp.meshgrid(weights_1d, weights_1d, weights_1d, indexing="ij")
    weights = jnp.prod(jnp.stack(wgrid, axis=-1), axis=-1).ravel()
    return HexTriLinearBasisPytree(points, weights)
|
|
939
|
+
|
|
940
|
+
|
|
941
|
+
def make_hex20_basis(intorder: int = 2) -> HexSerendipityBasis20:
    """
    Serendipity (20-node) hex basis with tensor-product Gauss quadrature.

    ``intorder`` is a polynomial exactness degree; each axis gets
    n = ceil((intorder + 1) / 2) Gauss-Legendre points. With the default
    ``intorder=2`` this is a 2x2x2 rule.
    """
    n_axis = _gl_points_for_degree(intorder)
    nodes_1d, weights_1d = _gauss_legendre_1d(n_axis)

    gx, gy, gz = jnp.meshgrid(nodes_1d, nodes_1d, nodes_1d, indexing="ij")
    points = jnp.stack([gx.ravel(), gy.ravel(), gz.ravel()], axis=1)  # (n**3, 3)

    wgrid = jnp.meshgrid(weights_1d, weights_1d, weights_1d, indexing="ij")
    weights = jnp.prod(jnp.stack(wgrid, axis=-1), axis=-1).ravel()  # (n**3,)
    return HexSerendipityBasis20(points, weights)
|
|
951
|
+
|
|
952
|
+
|
|
953
|
+
def make_hex20_basis_pytree(intorder: int = 2) -> HexSerendipityBasis20Pytree:
    """
    Pytree-registered serendipity (20-node) hex basis with tensor-product quadrature.

    ``intorder`` is a polynomial exactness degree; each axis gets
    n = ceil((intorder + 1) / 2) Gauss-Legendre points, for n**3 in total.
    """
    n_axis = _gl_points_for_degree(intorder)
    nodes_1d, weights_1d = _gauss_legendre_1d(n_axis)

    gx, gy, gz = jnp.meshgrid(nodes_1d, nodes_1d, nodes_1d, indexing="ij")
    points = jnp.stack([gx.ravel(), gy.ravel(), gz.ravel()], axis=1)

    wgrid = jnp.meshgrid(weights_1d, weights_1d, weights_1d, indexing="ij")
    weights = jnp.prod(jnp.stack(wgrid, axis=-1), axis=-1).ravel()
    return HexSerendipityBasis20Pytree(points, weights)
|
|
963
|
+
|
|
964
|
+
|
|
965
|
+
def make_hex27_basis(intorder: int = 3) -> HexTriQuadraticBasis27:
    """
    Triquadratic (27-node) hex basis with tensor-product Gauss quadrature.

    ``intorder`` is a polynomial exactness degree; each axis gets
    n = ceil((intorder + 1) / 2) Gauss-Legendre points, for n**3 in total.

    NOTE(review): the default ``intorder=3`` yields a 2x2x2 rule. For a
    triquadratic element that is reduced integration; full integration of
    e.g. a stiffness matrix needs 3x3x3 (``intorder`` 4 or 5). Confirm that
    the default is intended.
    """
    n_axis = _gl_points_for_degree(intorder)
    nodes_1d, weights_1d = _gauss_legendre_1d(n_axis)

    gx, gy, gz = jnp.meshgrid(nodes_1d, nodes_1d, nodes_1d, indexing="ij")
    points = jnp.stack([gx.ravel(), gy.ravel(), gz.ravel()], axis=1)  # (n**3, 3)

    wgrid = jnp.meshgrid(weights_1d, weights_1d, weights_1d, indexing="ij")
    weights = jnp.prod(jnp.stack(wgrid, axis=-1), axis=-1).ravel()  # (n**3,)
    return HexTriQuadraticBasis27(points, weights)
|
|
974
|
+
|
|
975
|
+
|
|
976
|
+
def make_hex27_basis_pytree(intorder: int = 3) -> HexTriQuadraticBasis27Pytree:
    """
    Pytree-registered triquadratic (27-node) hex basis with tensor-product quadrature.

    ``intorder`` is a polynomial exactness degree; each axis gets
    n = ceil((intorder + 1) / 2) Gauss-Legendre points, for n**3 in total.

    NOTE(review): the default ``intorder=3`` yields a 2x2x2 (reduced) rule for
    this triquadratic element. Confirm that the default is intended.
    """
    n_axis = _gl_points_for_degree(intorder)
    nodes_1d, weights_1d = _gauss_legendre_1d(n_axis)

    gx, gy, gz = jnp.meshgrid(nodes_1d, nodes_1d, nodes_1d, indexing="ij")
    points = jnp.stack([gx.ravel(), gy.ravel(), gz.ravel()], axis=1)

    wgrid = jnp.meshgrid(weights_1d, weights_1d, weights_1d, indexing="ij")
    weights = jnp.prod(jnp.stack(wgrid, axis=-1), axis=-1).ravel()
    return HexTriQuadraticBasis27Pytree(points, weights)
|
|
985
|
+
|
|
986
|
+
|
|
987
|
+
def make_tet10_basis(intorder: int = 2) -> TetQuadraticBasis10:
    """
    Build a quadratic (10-node) tetrahedral basis.

    ``intorder`` is the polynomial degree the quadrature must integrate
    exactly; it is forwarded unchanged to ``_tet_quadrature``.
    """
    points, weights = _tet_quadrature(intorder)
    basis = TetQuadraticBasis10(points, weights)
    return basis
|
|
991
|
+
|
|
992
|
+
|
|
993
|
+
def make_tet10_basis_pytree(intorder: int = 2) -> TetQuadraticBasis10Pytree:
    """
    Build the pytree-registered quadratic (10-node) tetrahedral basis.

    ``intorder`` is the exactness degree handed to ``_tet_quadrature``.
    """
    return TetQuadraticBasis10Pytree(*_tet_quadrature(intorder))
|