fluxfem 0.1.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. fluxfem/__init__.py +343 -0
  2. fluxfem/core/__init__.py +316 -0
  3. fluxfem/core/assembly.py +788 -0
  4. fluxfem/core/basis.py +996 -0
  5. fluxfem/core/data.py +64 -0
  6. fluxfem/core/dtypes.py +4 -0
  7. fluxfem/core/forms.py +234 -0
  8. fluxfem/core/interp.py +55 -0
  9. fluxfem/core/solver.py +113 -0
  10. fluxfem/core/space.py +419 -0
  11. fluxfem/core/weakform.py +818 -0
  12. fluxfem/helpers_num.py +11 -0
  13. fluxfem/helpers_wf.py +42 -0
  14. fluxfem/mesh/__init__.py +29 -0
  15. fluxfem/mesh/base.py +244 -0
  16. fluxfem/mesh/hex.py +327 -0
  17. fluxfem/mesh/io.py +87 -0
  18. fluxfem/mesh/predicate.py +45 -0
  19. fluxfem/mesh/surface.py +257 -0
  20. fluxfem/mesh/tet.py +246 -0
  21. fluxfem/physics/__init__.py +53 -0
  22. fluxfem/physics/diffusion.py +18 -0
  23. fluxfem/physics/elasticity/__init__.py +39 -0
  24. fluxfem/physics/elasticity/hyperelastic.py +99 -0
  25. fluxfem/physics/elasticity/linear.py +58 -0
  26. fluxfem/physics/elasticity/materials.py +32 -0
  27. fluxfem/physics/elasticity/stress.py +46 -0
  28. fluxfem/physics/operators.py +109 -0
  29. fluxfem/physics/postprocess.py +113 -0
  30. fluxfem/solver/__init__.py +47 -0
  31. fluxfem/solver/bc.py +439 -0
  32. fluxfem/solver/cg.py +326 -0
  33. fluxfem/solver/dirichlet.py +126 -0
  34. fluxfem/solver/history.py +31 -0
  35. fluxfem/solver/newton.py +400 -0
  36. fluxfem/solver/result.py +62 -0
  37. fluxfem/solver/solve_runner.py +534 -0
  38. fluxfem/solver/solver.py +148 -0
  39. fluxfem/solver/sparse.py +188 -0
  40. fluxfem/tools/__init__.py +7 -0
  41. fluxfem/tools/jit.py +51 -0
  42. fluxfem/tools/timer.py +659 -0
  43. fluxfem/tools/visualizer.py +101 -0
  44. fluxfem-0.1.1a0.dist-info/METADATA +111 -0
  45. fluxfem-0.1.1a0.dist-info/RECORD +47 -0
  46. fluxfem-0.1.1a0.dist-info/WHEEL +4 -0
  47. fluxfem-0.1.1a0.dist-info/licenses/LICENSE +201 -0
fluxfem/core/basis.py ADDED
@@ -0,0 +1,996 @@
1
+ from dataclasses import dataclass
2
+ from typing import Protocol
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import numpy as np
7
+ from .dtypes import DEFAULT_DTYPE as _FDTYPE
8
+ # from .dtypes import DEFAULT_DTYPE
9
+
10
+
11
def build_B_matrices(dN_dx: jnp.ndarray) -> jnp.ndarray:
    """
    Build small-strain B matrices (Voigt) for all quadrature points.

    The node count is taken from ``dN_dx`` (it used to be hard-coded to 8),
    so the same routine serves hex8 as well as any other 3D element.

    Args:
        dN_dx: (n_q, n_nodes, 3) spatial gradients of the shape functions.
    Returns:
        B: (n_q, 6, 3 * n_nodes) with rows ordered [xx, yy, zz, xy, yz, zx].
        Shear rows hold the engineering-shear coefficients
        (e.g. row xy = [dN/dy, dN/dx, 0] per node).
        Columns are node-major: column ``3 * i + c`` is component c of node i.
    """
    n_q, n_nodes, _ = dN_dx.shape

    dNdx = dN_dx[..., 0]  # (n_q, n_nodes)
    dNdy = dN_dx[..., 1]
    dNdz = dN_dx[..., 2]
    zero = jnp.zeros_like(dNdx)

    # One (n_q, n_nodes, 3) slab per Voigt row; the last axis is the
    # displacement component (ux, uy, uz) of the node.
    voigt_rows = (
        (dNdx, zero, zero),  # eps_xx
        (zero, dNdy, zero),  # eps_yy
        (zero, zero, dNdz),  # eps_zz
        (dNdy, dNdx, zero),  # gamma_xy
        (zero, dNdz, dNdy),  # gamma_yz
        (dNdz, zero, dNdx),  # gamma_zx
    )
    B = jnp.stack(
        [jnp.stack(row, axis=-1) for row in voigt_rows], axis=1
    )  # (n_q, 6, n_nodes, 3)

    # Flatten (node, component) into the node-major dof index.
    return B.reshape(n_q, 6, 3 * n_nodes)
59
+
60
+
61
def build_B_matrices_finite(dN_dX: jnp.ndarray, F: jnp.ndarray) -> jnp.ndarray:
    """
    Build finite-strain (Total Lagrangian) B matrices at each quadrature point.

    Column ``3 * a + k`` holds the variation of the Green-Lagrange strain
    E = 0.5 * (F^T F - I) with respect to displacement component k of node a:

        dE = 0.5 * (F^T G + G^T F),   G = e_k (x) dN_a/dX

    The previous implementation used ``G @ F`` instead of ``F^T @ G``; the two
    only agree for F = I.

    Args:
        dN_dX: (n_q, n_nodes, 3) gradients of shape functions in reference config.
        F: (n_q, 3, 3) deformation gradient at each quadrature point,
           F[k, j] = d x_k / d X_j.
    Returns:
        B: (n_q, 6, n_dofs) where n_dofs = 3 * n_nodes
        Voigt order: [xx, yy, zz, xy, yz, zx]. Shear rows store the tensor
        components dE_ij themselves (no factor 2), as before.
    """
    n_q, n_nodes, _ = dN_dX.shape
    n_dofs = 3 * n_nodes

    # (F^T G)[i, j] = F[k, i] * dN_a[j] for the dof (a, k).
    FtG = jnp.einsum("qki,qaj->qakij", F, dN_dX)  # (n_q, n_nodes, 3, 3, 3)
    dE = 0.5 * (FtG + jnp.swapaxes(FtG, -1, -2))

    # Pick the six Voigt entries (xx, yy, zz, xy, yz, zx) of each dE.
    voigt_i = jnp.array([0, 1, 2, 0, 1, 2])
    voigt_j = jnp.array([0, 1, 2, 1, 2, 0])
    voigt = dE[..., voigt_i, voigt_j]  # (n_q, n_nodes, 3, 6)

    # Move the Voigt axis in front of (node, component) and flatten the dofs.
    B = jnp.moveaxis(voigt, -1, 1).reshape(n_q, 6, n_dofs)
    return B.astype(dN_dX.dtype)  # keep the dtype contract of the old code
115
+
116
+
117
class SmallStrainBMixin:
    """Mixin providing the small-strain B-matrix builder for 3D elements."""

    dofs_per_node: int = 3

    def B_small_strain(self, dN_dx: jnp.ndarray) -> jnp.ndarray:
        """
        Assemble the small-strain B matrix at every quadrature point.

        dN_dx: (n_q, n_nodes, 3)
        returns: (n_q, 6, 3*n_nodes) Voigt: [xx,yy,zz,xy,yz,zx]
        (shear rows carry the engineering-shear coefficients).
        """
        _, n_nodes, _ = dN_dx.shape
        stride = self.dofs_per_node
        n_dofs = stride * n_nodes

        # The nine non-zero entries contributed by one node, listed as
        # (Voigt row, column offset inside the node block, gradient component).
        voigt_row = jnp.array([0, 1, 2, 3, 3, 4, 4, 5, 5])
        col_offset = jnp.array([0, 1, 2, 0, 1, 1, 2, 0, 2])
        grad_comp = jnp.array([0, 1, 2, 1, 0, 2, 1, 2, 0])

        def assemble_point(grad_q):
            def add_node(a, acc):
                # All nine (row, col) targets of a node are distinct, so a
                # single scatter equals the nine sequential updates.
                return acc.at[voigt_row, stride * a + col_offset].set(
                    grad_q[a, grad_comp]
                )

            empty = jnp.zeros((6, n_dofs), dtype=grad_q.dtype)
            return jax.lax.fori_loop(0, n_nodes, add_node, empty)

        return jax.vmap(assemble_point)(dN_dx)
153
+
154
+
155
class TotalLagrangeBMixin:
    """Mixin providing the Total Lagrangian (finite-strain) B-matrix builder."""

    dofs_per_node: int = 3

    def B_total_lagrange(self, dN_dX: jnp.ndarray, F: jnp.ndarray) -> jnp.ndarray:
        """
        Assemble the Green-Lagrange strain-variation matrix per quadrature point.

        Column ``dofs_per_node * a + k`` holds

            dE = 0.5 * (F^T G + G^T F),   G = e_k (x) dN_a/dX

        i.e. the derivative of E = 0.5 * (F^T F - I) with respect to
        displacement component k of node a. The previous code used ``G @ F``
        in place of ``F^T @ G``, which is only correct for F = I.

        dN_dX: (n_q, n_nodes, 3), F: (n_q, 3, 3) with F[k, j] = dx_k/dX_j
        returns: (n_q, 6, 3*n_nodes) Voigt: [xx,yy,zz,xy,yz,zx]
        (shear rows store the tensor components dE_ij, no factor 2).
        """
        _, n_nodes, _ = dN_dX.shape
        n_dofs = self.dofs_per_node * n_nodes

        # Tensor indices of the six Voigt entries (xx, yy, zz, xy, yz, zx).
        voigt_row = jnp.array([0, 1, 2, 3, 4, 5])
        voigt_i = jnp.array([0, 1, 2, 0, 1, 2])
        voigt_j = jnp.array([0, 1, 2, 1, 2, 0])

        def B_single(dN, Fq):
            B = jnp.zeros((6, n_dofs), dtype=dN.dtype)

            def node_fun(i, B):
                dNa = dN[i, :]
                col = self.dofs_per_node * i

                def dof_fun(k, B):
                    # F^T (e_k (x) dNa) collapses to outer(F[k, :], dNa).
                    FtG = jnp.outer(Fq[k, :], dNa)
                    dE = 0.5 * (FtG + FtG.T)
                    return B.at[voigt_row, col + k].set(dE[voigt_i, voigt_j])

                return jax.lax.fori_loop(0, self.dofs_per_node, dof_fun, B)

            return jax.lax.fori_loop(0, n_nodes, node_fun, B)

        return jax.vmap(B_single)(dN_dX, F)
199
+
200
+
201
class Basis3D(Protocol):
    """Structural interface implemented by the 3D element bases in this module."""

    # --- data carried by a basis ---
    quad_points: jnp.ndarray  # (n_q, 3)
    quad_weights: jnp.ndarray  # (n_q,)
    dofs_per_node: int  # usually 3 for vector mechanics

    # --- sizes ---
    @property
    def n_q(self) -> int:
        """Number of quadrature points."""
        ...

    @property
    def n_nodes(self) -> int:
        """Number of nodes of the element."""
        ...

    # --- reference-element evaluation ---
    def shape_functions(self) -> jnp.ndarray:
        """Shape functions evaluated at the quadrature points."""
        ...

    def shape_grads_ref(self) -> jnp.ndarray:
        """Reference-coordinate gradients of the shape functions."""
        ...

    def spatial_grads_and_detJ(
        self, elem_coords: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Spatial gradients and Jacobian determinants for one element."""
        ...

    # --- strain-displacement operators ---
    def B_small_strain(self, dN_dx: jnp.ndarray) -> jnp.ndarray:
        """Small-strain B matrices from spatial gradients."""
        ...

    def B_total_lagrange(
        self, dN_dX: jnp.ndarray, F: jnp.ndarray
    ) -> jnp.ndarray:
        """Total Lagrangian B matrices from reference gradients and F."""
        ...
223
+
224
+
225
def _quadratic_1d(x: jnp.ndarray):
    """
    1D quadratic shape functions on [-1, 1] with nodes ordered (-1, 0, 1).

    Returns ``((N1, N2, N3), (dN1, dN2, dN3))``: the three functions and their
    derivatives with respect to x, evaluated element-wise at ``x``.
    """
    values = (
        -0.5 * x * (1.0 - x),  # equals 1 at x = -1
        1.0 - x * x,  # equals 1 at x = 0
        0.5 * x * (1.0 + x),  # equals 1 at x = +1
    )
    derivatives = (
        0.5 * (2.0 * x - 1.0),
        -2.0 * x,
        0.5 * (2.0 * x + 1.0),
    )
    return values, derivatives
234
+
235
+
236
def _quad1d_full(x: jnp.ndarray):
    """
    1D quadratic Lagrange polynomials for the nodes (-1, 0, 1).

    Returns ``((N0, N1, N2), (dN0, dN1, dN2))`` evaluated element-wise at ``x``.
    """
    half_x = 0.5 * x
    shape = (
        half_x * (x - 1.0),  # node at -1
        1.0 - x * x,  # node at 0
        half_x * (x + 1.0),  # node at +1
    )
    slope = (x - 0.5, -2.0 * x, x + 0.5)
    return shape, slope
245
+
246
+
247
@dataclass(eq=False)
class TetLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
    """
    4-node linear tetra basis with simple quadrature.

    Reference element: vertices (0,0,0), (1,0,0), (0,1,0), (0,0,1).
    """

    quad_points: jnp.ndarray  # (n_q, 3)
    quad_weights: jnp.ndarray  # (n_q,)
    dofs_per_node: int = 3

    @property
    def n_nodes(self) -> int:
        """Number of element nodes."""
        return 4

    def tree_flatten(self):
        """
        Pytree flattening: arrays are children, ``dofs_per_node`` is static.

        The auxiliary data is a tuple so that it is hashable (a dict is not)
        and so that a non-default ``dofs_per_node`` survives a round trip.
        """
        children = (self.quad_points, self.quad_weights)
        return children, (self.dofs_per_node,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild the basis from ``tree_flatten`` output."""
        qp, qw = children
        # Empty aux data (the former ``{}``) falls back to the default.
        dofs_per_node = aux_data[0] if aux_data else 3
        return cls(qp, qw, dofs_per_node)

    @property
    def n_q(self) -> int:
        """Number of quadrature points."""
        return int(self.quad_points.shape[0])

    @property
    def ref_node_coords(self) -> jnp.ndarray:
        """Reference coordinates of the four vertices, shape (4, 3)."""
        return jnp.array(
            [
                [0.0, 0.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 0.0, 1.0],
            ],
            dtype=_FDTYPE,
        )

    def shape_functions(self) -> jnp.ndarray:
        """Barycentric shape functions at the quadrature points, (n_q, 4)."""
        qp = self.quad_points  # (n_q, 3)
        xi = qp[:, 0]
        eta = qp[:, 1]
        zeta = qp[:, 2]
        N1 = 1.0 - xi - eta - zeta
        return jnp.stack([N1, xi, eta, zeta], axis=1)  # (n_q,4)

    def shape_grads_ref(self) -> jnp.ndarray:
        """Reference gradients, (n_q, 4, 3); constant over the element."""
        dN = jnp.array(
            [
                [-1.0, -1.0, -1.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 0.0, 1.0],
            ],
            dtype=_FDTYPE,
        )
        return jnp.tile(dN[None, :, :], (self.n_q, 1, 1))  # (n_q,4,3)

    def spatial_grads_and_detJ(
        self, elem_coords: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Spatial gradients and detJ for one element.

        elem_coords: (4, 3) vertex coordinates.
        returns: dN_dx (n_q, 4, 3) and detJ (n_q,); both are the same at
        every quadrature point because the map is affine.
        """
        dN_dxi = self.shape_grads_ref()[0]  # (4,3) constant
        # J[a, k] = sum_i X_i[a] * dN_i/dxi_k
        J = jnp.einsum("ia,ik->ak", elem_coords, dN_dxi)  # (3,3)
        J_inv = jnp.linalg.inv(J)
        detJ = jnp.linalg.det(J)
        dN_dx = jnp.einsum("ik,ka->ia", dN_dxi, J_inv)  # (4,3)
        # Replicate the constant values for every quadrature point.
        dN_dx = jnp.tile(dN_dx[None, :, :], (self.n_q, 1, 1))
        detJ = jnp.full((self.n_q,), detJ, dtype=elem_coords.dtype)
        return dN_dx, detJ
320
+
321
+
322
@dataclass(eq=False)
class TetQuadraticBasis10(SmallStrainBMixin, TotalLagrangeBMixin):
    """
    10-node quadratic tetra basis (corner + edge mids).

    Node order: the four vertices, then the mid-edge nodes of the vertex
    pairs (0,1), (1,2), (0,2), (0,3), (1,3), (2,3).
    """

    quad_points: jnp.ndarray  # (n_q, 3)
    quad_weights: jnp.ndarray  # (n_q,)
    dofs_per_node: int = 3

    # Vertex pairs of the six mid-edge nodes. Unannotated on purpose so that
    # the dataclass machinery does not turn them into fields.
    _EDGE_A = (0, 1, 0, 0, 1, 2)
    _EDGE_B = (1, 2, 2, 3, 3, 3)

    @property
    def n_nodes(self) -> int:
        """Number of element nodes."""
        return 10

    def tree_flatten(self):
        """
        Pytree flattening: arrays are children, ``dofs_per_node`` is static.

        The auxiliary data is a hashable tuple (a dict is not hashable) and
        carries ``dofs_per_node`` so it is not reset on unflatten.
        """
        return (self.quad_points, self.quad_weights), (self.dofs_per_node,)

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild the basis from ``tree_flatten`` output."""
        qp, qw = children
        # Empty aux data (the former ``{}``) falls back to the default.
        dofs_per_node = aux[0] if aux else 3
        return cls(qp, qw, dofs_per_node)

    @property
    def n_q(self) -> int:
        """Number of quadrature points."""
        return int(self.quad_points.shape[0])

    def _barycentric(self) -> jnp.ndarray:
        """Barycentric coordinates (L1..L4) of the quadrature points, (n_q, 4)."""
        qp = self.quad_points
        L1 = 1.0 - qp[:, 0] - qp[:, 1] - qp[:, 2]
        return jnp.stack([L1, qp[:, 0], qp[:, 1], qp[:, 2]], axis=1)

    def shape_functions(self) -> jnp.ndarray:
        """Shape functions at the quadrature points, (n_q, 10)."""
        L = self._barycentric()
        a = np.array(self._EDGE_A)
        b = np.array(self._EDGE_B)
        N_corner = L * (2 * L - 1)  # L_a (2 L_a - 1)
        N_edge = 4 * L[:, a] * L[:, b]  # 4 L_a L_b
        return jnp.concatenate([N_corner, N_edge], axis=1)

    def shape_grads_ref(self) -> jnp.ndarray:
        """
        Reference gradients of the shape functions, (n_q, 10, 3).

        Corner nodes: d[L (2L - 1)] = (4L - 1) dL. The previous code used
        (2L - 1) dL, so the ten gradients did not sum to zero.
        Edge nodes: d[4 La Lb] = 4 (Lb dLa + La dLb).
        """
        L = self._barycentric()  # (n_q, 4)
        # Constant gradients of the barycentric coordinates, (4, 3).
        dL = jnp.array(
            [
                [-1.0, -1.0, -1.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 0.0, 1.0],
            ],
            dtype=L.dtype,
        )
        a = np.array(self._EDGE_A)
        b = np.array(self._EDGE_B)

        dN_corner = (4 * L - 1)[:, :, None] * dL[None, :, :]  # (n_q, 4, 3)
        dN_edge = 4 * (
            L[:, b, None] * dL[a][None, :, :] + L[:, a, None] * dL[b][None, :, :]
        )  # (n_q, 6, 3)
        return jnp.concatenate([dN_corner, dN_edge], axis=1)  # (n_q, 10, 3)

    def spatial_grads_and_detJ(
        self, elem_coords: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Spatial gradients and detJ for one element.

        elem_coords: (10, 3) node coordinates.
        returns: dN_dx (n_q, 10, 3) and detJ (n_q,).
        """
        dN_dxi = self.shape_grads_ref()  # (n_q,10,3)
        # J[q, a, k] = sum_i X_i[a] * dN_i/dxi_k
        J = jnp.einsum("ia,qik->qak", elem_coords, dN_dxi)
        J_inv = jnp.linalg.inv(J)
        detJ = jnp.linalg.det(J)
        dN_dx = jnp.einsum("qik,qka->qia", dN_dxi, J_inv)
        return dN_dx, detJ
402
+
403
+
404
+
405
@dataclass(eq=False)
class HexTriLinearBasis(SmallStrainBMixin, TotalLagrangeBMixin):
    """
    Trilinear 8-node hex element basis with given quadrature rule.

    quad_points: (n_q, 3)  # (xi, eta, zeta) in [-1, 1]^3
    quad_weights: (n_q,)   # weights
    """

    quad_points: jnp.ndarray  # (n_q, 3)
    quad_weights: jnp.ndarray  # (n_q,)
    dofs_per_node: int = 3

    @property
    def n_nodes(self) -> int:
        """Number of element nodes."""
        return 8

    def tree_flatten(self):
        """
        Pytree flattening: arrays are children, ``dofs_per_node`` is static.

        The auxiliary data is a hashable tuple (a dict is not hashable) and
        carries ``dofs_per_node`` so it is not reset on unflatten.
        """
        children = (self.quad_points, self.quad_weights)
        aux_data = (self.dofs_per_node,)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild the basis from ``tree_flatten`` output."""
        qp, qw = children
        # Empty aux data (the former ``{}``) falls back to the default.
        dofs_per_node = aux_data[0] if aux_data else 3
        return cls(qp, qw, dofs_per_node)

    @property
    def n_q(self) -> int:
        """Number of quadrature points."""
        return int(self.quad_points.shape[0])

    @property
    def ref_node_signs(self) -> jnp.ndarray:
        """
        Node signs (8,3) for (-1,1)^3 reference hex.
        Node ordering:
          0: (-1,-1,-1)
          1: ( 1,-1,-1)
          2: ( 1, 1,-1)
          3: (-1, 1,-1)
          4: (-1,-1, 1)
          5: ( 1,-1, 1)
          6: ( 1, 1, 1)
          7: (-1, 1, 1)
        """
        return jnp.array(
            [
                [-1.0, -1.0, -1.0],
                [1.0, -1.0, -1.0],
                [1.0, 1.0, -1.0],
                [-1.0, 1.0, -1.0],
                [-1.0, -1.0, 1.0],
                [1.0, -1.0, 1.0],
                [1.0, 1.0, 1.0],
                [-1.0, 1.0, 1.0],
            ],
            dtype=_FDTYPE,
        )

    def _linear_factors(self):
        """Per-axis factors (1 + s * coordinate), each of shape (n_q, 8)."""
        qp = self.quad_points  # (n_q, 3)
        s = self.ref_node_signs  # (8, 3)
        # (n_q, 1) * (8,) broadcasts to (n_q, 8)
        f_xi = 1.0 + qp[:, 0:1] * s[:, 0]
        f_eta = 1.0 + qp[:, 1:2] * s[:, 1]
        f_zeta = 1.0 + qp[:, 2:3] * s[:, 2]
        return s, f_xi, f_eta, f_zeta

    # ---------- reference shape functions & gradients ----------
    def shape_functions(self) -> jnp.ndarray:
        """
        Evaluate shape functions at all quadrature points.
        Returns: (n_q, 8)
        """
        _, f_xi, f_eta, f_zeta = self._linear_factors()
        return 0.125 * f_xi * f_eta * f_zeta  # (n_q, 8)

    def shape_grads_ref(self) -> jnp.ndarray:
        """
        Gradients in reference coords (xi, eta, zeta) at all quad points.
        Returns: (n_q, 8, 3)  [dN/dxi, dN/deta, dN/dzeta]
        """
        s, f_xi, f_eta, f_zeta = self._linear_factors()
        dN_dxi = 0.125 * s[:, 0] * f_eta * f_zeta  # (n_q, 8)
        dN_deta = 0.125 * s[:, 1] * f_xi * f_zeta
        dN_dzeta = 0.125 * s[:, 2] * f_xi * f_eta
        return jnp.stack([dN_dxi, dN_deta, dN_dzeta], axis=-1)  # (n_q, 8, 3)

    # ---------- mapping to physical element ----------
    def spatial_grads_and_detJ(
        self, elem_coords: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Compute spatial gradients and detJ for one element.

        Parameters
        ----------
        elem_coords : jnp.ndarray
            Element coordinates of shape (8, 3).

        Returns
        -------
        dN_dx : jnp.ndarray
            Spatial gradients of shape (n_q, 8, 3).
        detJ : jnp.ndarray
            Determinant of the Jacobian of shape (n_q,).
        """
        dN_dxi = self.shape_grads_ref()  # (n_q, 8, 3)

        # Jacobian: J(q)[a, k] = sum_i X_i[a] * dN_i/dxi_k
        J = jnp.einsum("ia,qik->qak", elem_coords, dN_dxi)  # (n_q, 3, 3)
        J_inv = jnp.linalg.inv(J)  # (n_q, 3, 3)
        detJ = jnp.linalg.det(J)  # (n_q,)

        # grad_x N = grad_xi N . J^{-1}
        dN_dx = jnp.einsum("qik,qka->qia", dN_dxi, J_inv)  # (n_q, 8, 3)
        return dN_dx, detJ
553
+
554
+
555
@dataclass(eq=False)
class HexSerendipityBasis20(SmallStrainBMixin, TotalLagrangeBMixin):
    """
    20-node serendipity hex element basis (corner + edge midpoints).

    Node order is the one listed by ``ref_node_coords``: the 8 corners, then
    the mid-edge nodes of the bottom ring (0-1, 1-2, 2-3, 3-0), the top ring
    (4-5, 5-6, 6-7, 7-4) and the vertical edges (0-4, 1-5, 2-6, 3-7).
    ``shape_functions`` and ``shape_grads_ref`` follow the same order, so that
    N_i evaluated at reference node j is the Kronecker delta. Previously the
    edge functions were emitted grouped by axis (x-edges, y-edges, z-edges),
    which disagreed with ``ref_node_coords`` for nodes 9..15.

    NOTE(review): this assumes the mesh connectivity uses the
    ``ref_node_coords`` ordering -- confirm against the mesh module.
    """

    quad_points: jnp.ndarray  # (n_q, 3)
    quad_weights: jnp.ndarray  # (n_q,)
    dofs_per_node: int = 3

    # Reference coordinates; unannotated so they are not dataclass fields.
    _CORNER_SIGNS = (
        (-1.0, -1.0, -1.0),
        (1.0, -1.0, -1.0),
        (1.0, 1.0, -1.0),
        (-1.0, 1.0, -1.0),
        (-1.0, -1.0, 1.0),
        (1.0, -1.0, 1.0),
        (1.0, 1.0, 1.0),
        (-1.0, 1.0, 1.0),
    )
    _EDGE_COORDS = (
        (0.0, -1.0, -1.0),  # 0-1
        (1.0, 0.0, -1.0),  # 1-2
        (0.0, 1.0, -1.0),  # 2-3
        (-1.0, 0.0, -1.0),  # 3-0
        (0.0, -1.0, 1.0),  # 4-5
        (1.0, 0.0, 1.0),  # 5-6
        (0.0, 1.0, 1.0),  # 6-7
        (-1.0, 0.0, 1.0),  # 7-4
        (-1.0, -1.0, 0.0),  # 0-4
        (1.0, -1.0, 0.0),  # 1-5
        (1.0, 1.0, 0.0),  # 2-6
        (-1.0, 1.0, 0.0),  # 3-7
    )

    @property
    def n_nodes(self) -> int:
        """Number of element nodes."""
        return 20

    def tree_flatten(self):
        """
        Pytree flattening: arrays are children, ``dofs_per_node`` is static.

        The auxiliary data is a hashable tuple (a dict is not hashable) and
        carries ``dofs_per_node`` so it is not reset on unflatten.
        """
        children = (self.quad_points, self.quad_weights)
        aux_data = (self.dofs_per_node,)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild the basis from ``tree_flatten`` output."""
        qp, qw = children
        # Empty aux data (the former ``{}``) falls back to the default.
        dofs_per_node = aux_data[0] if aux_data else 3
        return cls(qp, qw, dofs_per_node)

    @property
    def n_q(self) -> int:
        """Number of quadrature points."""
        return int(self.quad_points.shape[0])

    @property
    def ref_node_coords(self) -> jnp.ndarray:
        """Reference coordinates of all 20 nodes, shape (20, 3)."""
        corners = jnp.array(self._CORNER_SIGNS, dtype=_FDTYPE)
        edges = jnp.array(self._EDGE_COORDS, dtype=_FDTYPE)
        return jnp.concatenate([corners, edges], axis=0)  # (20,3)

    def _corner_factors(self):
        """
        Building blocks of the corner functions.

        Returns (c, lin, term): corner signs (8, 3), the per-axis factors
        1 + c * t of shape (n_q, 8, 3) and the serendipity correction
        (c . t - 2) of shape (n_q, 8, 1).
        """
        t = self.quad_points[:, None, :]  # (n_q, 1, 3)
        c = jnp.array(self._CORNER_SIGNS, dtype=_FDTYPE)  # (8, 3)
        lin = 1.0 + c * t
        term = jnp.sum(c * t, axis=-1, keepdims=True) - 2.0
        return c, lin, term

    def _edge_factors(self):
        """
        Per-axis factors of the mid-edge functions and their derivatives.

        For an edge node with reference coordinate e, the factor along an axis
        is (1 - t^2) where e = 0 (the axis the edge runs along) and
        (1 + e * t) otherwise. Both results have shape (n_q, 12, 3).
        """
        t = self.quad_points[:, None, :]  # (n_q, 1, 3)
        e_np = np.asarray(self._EDGE_COORDS)
        along_edge = e_np == 0.0  # (12, 3) static mask
        e = jnp.array(e_np, dtype=_FDTYPE)
        g = jnp.where(along_edge, 1.0 - t * t, 1.0 + e * t)
        dg = jnp.where(along_edge, -2.0 * t, e)
        return g, dg

    def shape_functions(self) -> jnp.ndarray:
        """Shape functions at the quadrature points, (n_q, 20)."""
        _, lin, term = self._corner_factors()
        # N = 1/8 (1 + sx xi)(1 + sy eta)(1 + sz zeta)(sx xi + sy eta + sz zeta - 2)
        N_corner = 0.125 * jnp.prod(lin, axis=-1) * term[..., 0]  # (n_q, 8)

        g, _ = self._edge_factors()
        # e.g. N = 1/4 (1 - xi^2)(1 + sy eta)(1 + sz zeta) for an x-edge
        N_edges = 0.25 * jnp.prod(g, axis=-1)  # (n_q, 12)
        return jnp.concatenate([N_corner, N_edges], axis=1)  # (n_q,20)

    def shape_grads_ref(self) -> jnp.ndarray:
        """Reference gradients of the shape functions, (n_q, 20, 3)."""
        c, lin, term = self._corner_factors()
        lx, ly, lz = lin[..., 0], lin[..., 1], lin[..., 2]
        # Product of the two factors that do not depend on each axis.
        others = jnp.stack([ly * lz, lx * lz, lx * ly], axis=-1)
        # dN/dt_k = (c_k / 8) * prod_{m != k} lin_m * (term + lin_k)
        d_corner = 0.125 * c * others * (term + lin)  # (n_q, 8, 3)

        g, dg = self._edge_factors()
        gx, gy, gz = g[..., 0], g[..., 1], g[..., 2]
        dgx, dgy, dgz = dg[..., 0], dg[..., 1], dg[..., 2]
        # Product rule on N = 1/4 gx gy gz.
        d_edges = 0.25 * jnp.stack(
            [dgx * gy * gz, gx * dgy * gz, gx * gy * dgz], axis=-1
        )  # (n_q, 12, 3)
        return jnp.concatenate([d_corner, d_edges], axis=1)  # (n_q,20,3)

    def spatial_grads_and_detJ(
        self, elem_coords: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Spatial gradients and detJ for one element.

        elem_coords: (20, 3) node coordinates.
        returns: dN_dx (n_q, 20, 3) and detJ (n_q,).
        """
        dN_dxi = self.shape_grads_ref()  # (n_q, 20, 3)
        # J[q, a, k] = sum_i X_i[a] * dN_i/dxi_k
        J = jnp.einsum("ia,qik->qak", elem_coords, dN_dxi)  # (n_q, 3, 3)
        J_inv = jnp.linalg.inv(J)
        detJ = jnp.linalg.det(J)
        dN_dx = jnp.einsum("qik,qka->qia", dN_dxi, J_inv)
        return dN_dx, detJ
740
+
741
+
742
@dataclass(eq=False)
class HexTriQuadraticBasis27(SmallStrainBMixin, TotalLagrangeBMixin):
    """
    27-node triquadratic hex (tensor-product).

    Nodes are numbered lexicographically over the 1D nodes (-1, 0, 1):
    the xi index runs fastest, then eta, then zeta.
    """

    quad_points: jnp.ndarray  # (n_q, 3)
    quad_weights: jnp.ndarray  # (n_q,)
    dofs_per_node: int = 3

    @property
    def n_nodes(self) -> int:
        """Number of element nodes."""
        return 27

    def tree_flatten(self):
        """
        Pytree flattening: arrays are children, ``dofs_per_node`` is static.

        The auxiliary data is a hashable tuple (a dict is not hashable) and
        carries ``dofs_per_node`` so it is not reset on unflatten.
        """
        return (self.quad_points, self.quad_weights), (self.dofs_per_node,)

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild the basis from ``tree_flatten`` output."""
        qp, qw = children
        # Empty aux data (the former ``{}``) falls back to the default.
        dofs_per_node = aux[0] if aux else 3
        return cls(qp, qw, dofs_per_node)

    @property
    def n_q(self) -> int:
        """Number of quadrature points."""
        return int(self.quad_points.shape[0])

    def shape_functions(self) -> jnp.ndarray:
        """Shape functions at the quadrature points, (n_q, 27)."""
        qp = self.quad_points
        Nx, _ = _quad1d_full(qp[:, 0])
        Ny, _ = _quad1d_full(qp[:, 1])
        Nz, _ = _quad1d_full(qp[:, 2])

        # zeta index outermost, xi index innermost (fastest).
        N = [Ni * Nj * Nk for Nk in Nz for Nj in Ny for Ni in Nx]
        return jnp.stack(N, axis=1)  # (n_q, 27)

    def shape_grads_ref(self) -> jnp.ndarray:
        """Reference gradients of the shape functions, (n_q, 27, 3)."""
        qp = self.quad_points
        Nx, dNx = _quad1d_full(qp[:, 0])
        Ny, dNy = _quad1d_full(qp[:, 1])
        Nz, dNz = _quad1d_full(qp[:, 2])

        grads = []
        for k in range(3):
            for j in range(3):
                for i in range(3):
                    # Product rule on N = Nx_i * Ny_j * Nz_k.
                    dxi = dNx[i] * Ny[j] * Nz[k]
                    deta = Nx[i] * dNy[j] * Nz[k]
                    dzeta = Nx[i] * Ny[j] * dNz[k]
                    grads.append(jnp.stack([dxi, deta, dzeta], axis=1))
        return jnp.stack(grads, axis=1)  # (n_q, 27, 3)

    def spatial_grads_and_detJ(
        self, elem_coords: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Spatial gradients and detJ for one element.

        elem_coords: (27, 3) node coordinates.
        returns: dN_dx (n_q, 27, 3) and detJ (n_q,).
        """
        dN_dxi = self.shape_grads_ref()  # (n_q, 27, 3)
        # J[q, a, k] = sum_i X_i[a] * dN_i/dxi_k
        J = jnp.einsum("ia,qik->qak", elem_coords, dN_dxi)
        J_inv = jnp.linalg.inv(J)
        detJ = jnp.linalg.det(J)
        dN_dx = jnp.einsum("qik,qka->qia", dN_dxi, J_inv)
        return dN_dx, detJ
817
+
818
+
819
@jax.tree_util.register_pytree_node_class
class TetLinearBasisPytree(TetLinearBasis):
    """``TetLinearBasis`` registered as a JAX pytree node class."""
822
+
823
+
824
@jax.tree_util.register_pytree_node_class
class TetQuadraticBasis10Pytree(TetQuadraticBasis10):
    """``TetQuadraticBasis10`` registered as a JAX pytree node class."""
827
+
828
+
829
@jax.tree_util.register_pytree_node_class
class HexTriLinearBasisPytree(HexTriLinearBasis):
    """``HexTriLinearBasis`` registered as a JAX pytree node class."""
832
+
833
+
834
@jax.tree_util.register_pytree_node_class
class HexSerendipityBasis20Pytree(HexSerendipityBasis20):
    """``HexSerendipityBasis20`` registered as a JAX pytree node class."""
837
+
838
+
839
@jax.tree_util.register_pytree_node_class
class HexTriQuadraticBasis27Pytree(HexTriQuadraticBasis27):
    """``HexTriQuadraticBasis27`` registered as a JAX pytree node class."""
842
+
843
+
844
def _gauss_legendre_1d(order: int) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Return the ``order``-point Gauss-Legendre rule on [-1, 1].

    The nodes and weights come from ``numpy.polynomial.legendre.leggauss`` and
    are converted to JAX arrays of the module default dtype.
    Raises ValueError when ``order`` is not positive.
    """
    if order <= 0:
        raise ValueError("quadrature order must be positive")
    nodes, weights = np.polynomial.legendre.leggauss(order)
    as_default = [jnp.array(values, dtype=_FDTYPE) for values in (nodes, weights)]
    return as_default[0], as_default[1]
850
+
851
+
852
def _gl_points_for_degree(degree: int) -> int:
    """
    Number of 1D Gauss-Legendre points needed for a polynomial degree.

    An n-point rule is exact up to degree 2n - 1, so the smallest sufficient
    n is ceil((degree + 1) / 2). Non-positive degrees map to a single point.
    """
    if degree > 0:
        points_needed = np.ceil((degree + 1) / 2)
        return int(points_needed)
    return 1
860
+
861
+
862
def _tet_quadrature(degree: int) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Degree-based quadrature rules for reference tetra (volume = 1/6).

    degree<=1: 1-point; degree<=2: 4-point; degree>=3: 5-point rule
    (centroid with a negative weight plus 4 symmetric points).
    Requests above degree 3 also receive the 5-point rule.

    Returns ``(qp, qw)`` with shapes (n_q, 3) and (n_q,); the weights sum to
    the reference volume 1/6.
    """
    if degree <= 1:
        qp = jnp.array([[0.25, 0.25, 0.25]], dtype=_FDTYPE)
        qw = jnp.array([1.0 / 6.0], dtype=_FDTYPE)
        return qp, qw
    if degree <= 2:
        # Closed-form abscissae instead of 8-digit literals, so the rule is
        # exact to working precision: a = (5 - sqrt 5)/20, b = (5 + 3 sqrt 5)/20.
        root5 = 5.0 ** 0.5
        a = (5.0 - root5) / 20.0  # 0.1381966...
        b = (5.0 + 3.0 * root5) / 20.0  # 0.5854101...
        qp = jnp.array(
            [
                [b, a, a],
                [a, b, a],
                [a, a, b],
                [a, a, a],
            ],
            dtype=_FDTYPE,
        )
        qw = jnp.full((4,), (1.0 / 24.0), dtype=_FDTYPE)
        return qp, qw
    # degree 3 rule: centroid + 4 symmetric points
    sixth = 1.0 / 6.0
    qp = jnp.array(
        [
            [0.25, 0.25, 0.25],
            [0.50, sixth, sixth],
            [sixth, 0.50, sixth],
            [sixth, sixth, 0.50],
            [sixth, sixth, sixth],
        ],
        dtype=_FDTYPE,
    )
    qw = jnp.array(
        [-2.0 / 15.0, 3.0 / 40.0, 3.0 / 40.0, 3.0 / 40.0, 3.0 / 40.0],
        dtype=_FDTYPE,
    )
    return qp, qw
899
+
900
+
901
def make_tet_basis(intorder: int = 2) -> TetLinearBasis:
    """Build a 4-node linear tet basis; ``intorder`` selects the quadrature degree."""
    points, weights = _tet_quadrature(intorder)
    return TetLinearBasis(points, weights)
905
+
906
+
907
def make_tet_basis_pytree(intorder: int = 2) -> TetLinearBasisPytree:
    """Build the pytree-registered linear tet basis; ``intorder`` selects the quadrature degree."""
    points, weights = _tet_quadrature(intorder)
    return TetLinearBasisPytree(points, weights)
911
+
912
+
913
def make_hex_basis(intorder: int = 2) -> HexTriLinearBasis:
    """
    Trilinear hex basis with tensor-product Gauss-Legendre quadrature.

    ``intorder`` is the polynomial exactness degree (scikit-fem style):
    degree 1 gives 1x1x1 points, degree 2/3 gives 2x2x2, degree 4/5 gives
    3x3x3, and so on.
    """
    pts, wts = _gauss_legendre_1d(_gl_points_for_degree(intorder))

    # Tensor-product grid; with "ij" indexing the zeta index runs fastest.
    point_grid = jnp.meshgrid(pts, pts, pts, indexing="ij")
    qp = jnp.stack(list(point_grid), axis=-1).reshape(-1, 3)  # (n_1d**3, 3)

    weight_grid = jnp.meshgrid(wts, wts, wts, indexing="ij")
    qw = jnp.stack(weight_grid, axis=-1).prod(axis=-1).reshape(-1)  # (n_1d**3,)
    return HexTriLinearBasis(qp, qw)
927
+
928
+
929
def make_hex_basis_pytree(intorder: int = 2) -> HexTriLinearBasisPytree:
    """Build the pytree-registered trilinear hex basis with a tensor-product Gauss-Legendre rule."""
    pts, wts = _gauss_legendre_1d(_gl_points_for_degree(intorder))

    point_grid = jnp.meshgrid(pts, pts, pts, indexing="ij")
    qp = jnp.stack(list(point_grid), axis=-1).reshape(-1, 3)  # (n_1d**3, 3)

    weight_grid = jnp.meshgrid(wts, wts, wts, indexing="ij")
    qw = jnp.stack(weight_grid, axis=-1).prod(axis=-1).reshape(-1)  # (n_1d**3,)
    return HexTriLinearBasisPytree(qp, qw)
939
+
940
+
941
def make_hex20_basis(intorder: int = 2) -> HexSerendipityBasis20:
    """Build a 20-node serendipity hex basis with a tensor-product Gauss-Legendre rule."""
    pts, wts = _gauss_legendre_1d(_gl_points_for_degree(intorder))

    point_grid = jnp.meshgrid(pts, pts, pts, indexing="ij")
    qp = jnp.stack(list(point_grid), axis=-1).reshape(-1, 3)  # (n_1d**3, 3)

    weight_grid = jnp.meshgrid(wts, wts, wts, indexing="ij")
    qw = jnp.stack(weight_grid, axis=-1).prod(axis=-1).reshape(-1)  # (n_1d**3,)
    return HexSerendipityBasis20(qp, qw)
951
+
952
+
953
def make_hex20_basis_pytree(intorder: int = 2) -> HexSerendipityBasis20Pytree:
    """Build the pytree-registered serendipity hex basis with a tensor-product Gauss-Legendre rule."""
    pts, wts = _gauss_legendre_1d(_gl_points_for_degree(intorder))

    point_grid = jnp.meshgrid(pts, pts, pts, indexing="ij")
    qp = jnp.stack(list(point_grid), axis=-1).reshape(-1, 3)  # (n_1d**3, 3)

    weight_grid = jnp.meshgrid(wts, wts, wts, indexing="ij")
    qw = jnp.stack(weight_grid, axis=-1).prod(axis=-1).reshape(-1)  # (n_1d**3,)
    return HexSerendipityBasis20Pytree(qp, qw)
963
+
964
+
965
def make_hex27_basis(intorder: int = 3) -> HexTriQuadraticBasis27:
    """Build a 27-node triquadratic hex basis with a tensor-product Gauss-Legendre rule."""
    pts, wts = _gauss_legendre_1d(_gl_points_for_degree(intorder))

    point_grid = jnp.meshgrid(pts, pts, pts, indexing="ij")
    qp = jnp.stack(list(point_grid), axis=-1).reshape(-1, 3)  # (n_1d**3, 3)

    weight_grid = jnp.meshgrid(wts, wts, wts, indexing="ij")
    qw = jnp.stack(weight_grid, axis=-1).prod(axis=-1).reshape(-1)  # (n_1d**3,)
    return HexTriQuadraticBasis27(qp, qw)
974
+
975
+
976
def make_hex27_basis_pytree(intorder: int = 3) -> HexTriQuadraticBasis27Pytree:
    """Build the pytree-registered triquadratic hex basis with a tensor-product Gauss-Legendre rule."""
    pts, wts = _gauss_legendre_1d(_gl_points_for_degree(intorder))

    point_grid = jnp.meshgrid(pts, pts, pts, indexing="ij")
    qp = jnp.stack(list(point_grid), axis=-1).reshape(-1, 3)  # (n_1d**3, 3)

    weight_grid = jnp.meshgrid(wts, wts, wts, indexing="ij")
    qw = jnp.stack(weight_grid, axis=-1).prod(axis=-1).reshape(-1)  # (n_1d**3,)
    return HexTriQuadraticBasis27Pytree(qp, qw)
985
+
986
+
987
def make_tet10_basis(intorder: int = 2) -> TetQuadraticBasis10:
    """Build a 10-node quadratic tet basis; ``intorder`` selects the quadrature degree."""
    points, weights = _tet_quadrature(intorder)
    return TetQuadraticBasis10(points, weights)
991
+
992
+
993
def make_tet10_basis_pytree(intorder: int = 2) -> TetQuadraticBasis10Pytree:
    """Build the pytree-registered quadratic tet basis; ``intorder`` selects the quadrature degree."""
    points, weights = _tet_quadrature(intorder)
    return TetQuadraticBasis10Pytree(points, weights)