PyMHD 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pymhd/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # PyMHD: Python for Magnetohydrodynamic Turbulence.
2
+ # Copyright (c) 2026 Yuyang Hua (华宇阳)
3
+ # License: MIT
4
+
5
+ """
6
+ pymhd/__init__.py
7
+ -----------------
8
+
9
+ Unified interface for PyMHD.
10
+ """
11
+
12
+ from .turbulence import Scalar, Vector, ScalarField, VectorField, Turbulence, avg, std
13
+ from .derivatives.derivative import Dx, Dy, Dz, grad, div, curl, laplacian, Algorithm
14
+
15
+ from .spectra import Spectrum, Spectrum1D, EnergySpectra
16
+ from .numdiss import NumericalDissipation
17
+ from .preprocess import output2turbulence
18
+
19
+ from .plot import plot
20
+
21
__all__ = [
    # re-exported from .turbulence
    "Scalar", "Vector", "ScalarField", "VectorField", "avg", "std", "Turbulence",
    # re-exported from .derivatives.derivative
    "Algorithm", "Dx", "Dy", "Dz", "grad", "div", "curl", "laplacian",
    # re-exported from .spectra
    "Spectrum", "Spectrum1D", "EnergySpectra",
    # re-exported from .numdiss, .preprocess and .plot
    "NumericalDissipation", "output2turbulence", "plot",
]
@@ -0,0 +1,278 @@
1
+ # PyMHD: Python for Magnetohydrodynamic Turbulence.
2
+ # Copyright (c) 2026 Yuyang Hua (华宇阳)
3
+ # License: MIT
4
+
5
+ """
6
+ pymhd/derivatives/TENO.py
7
+ --------------------------
8
+
9
+ Implements the TENO7-M reconstruction scheme with two backends:
10
+ - NumPy (default): more compatible, but limited to a single CPU core
11
+ - JAX: requires JAX, but leverages multi-core CPU and GPU acceleration
12
+ """
13
+
14
+ from typing import Any
15
+
16
+ import numpy as np
17
+ from numpy.typing import ArrayLike
18
+
19
hasJAX = False

try:
    import jax
except ImportError:
    # JAX is optional: without it every kernel runs on the NumPy backend.
    pass
else:
    jax.config.update("jax_enable_x64", True)

    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    try:
        # Request 8 (virtual) CPU devices so the (4, 2) mesh below can be
        # built on a CPU-only machine. JAX only accepts this option before its
        # backends are initialised (e.g. it may already be initialised when
        # another module ran this setup first), and older JAX releases may not
        # know the option at all; in both cases keep the existing devices.
        jax.config.update("jax_num_cpu_devices", 8)
    except (AttributeError, RuntimeError):
        pass

    try:
        mesh: Any = jax.make_mesh(axis_shapes=(4, 2), axis_names=("X", "Y"))
    except (AttributeError, ValueError):
        # Fewer than 8 devices (e.g. a single GPU) or no jax.make_mesh:
        # keep using JAX, but let jax.device_put place arrays on the default
        # device without sharding (device_put accepts None).
        mesh = None
        sharding = {"x": None, "y": None, "z": None}
    else:
        # Each layout keeps the reconstruction axis whole and splits the two
        # transverse axes over the mesh axes "X" and "Y".
        sharding = {
            "x": NamedSharding(mesh, P(None, "X", "Y")),
            "y": NamedSharding(mesh, P("X", None, "Y")),
            "z": NamedSharding(mesh, P("X", "Y", None)),
        }

    hasJAX = True
40
+
41
def coreTENO7M(
    u: ArrayLike, direction: str, axis: int, CT: float = 0.01
) -> tuple[Any, Any]:
    """TENO7-M reconstruction.

    Reconstruct the cell interface values from the cell averages along a
    single axis. Neighbours are fetched with a periodic roll, so the data is
    treated as periodic along ``axis``. Uses jax.numpy when JAX is available,
    NumPy otherwise.

    The stencil variables are named for direction 'R': uR1 is u[i + 1], uL1 is
    u[i - 1], and so on. For direction 'L' every shift is mirrored, so the same
    formulas yield the value at i - 1/2.

    Parameters
    ----------
    u : npt.ArrayLike, cell-centered or cell-averaged values
    direction: str, reconstruction direction, 'L' (for i - 1/2) or 'R' (for i + 1/2);
        any value other than 'R' is handled like 'L'
    axis : int, the axis direction (0, 1, 2 corresponds to x, y, z)
    CT : float, smoothness threshold, defaults to 0.01; a candidate stencil
        whose normalised weight falls below CT is marked non-smooth

    Returns
    -------
    uHalf : np.ndarray or jax.Array, reconstructed value at the cell interfaces (i ± 1/2)
    shockflag: np.ndarray or jax.Array, boolean shock detection flag; True
        wherever at least one of the four candidate stencils is judged non-smooth
    """
    if hasJAX:
        import jax.numpy as jnp
        xp = jnp
    else:
        xp = np
    u = xp.asarray(u)

    def roll(arr, offset: int):
        # For 'R', roll(arr, k)[i] == arr[i + k]; otherwise roll(arr, k)[i] == arr[i - k].
        shift = -offset if direction == "R" else offset
        return xp.roll(arr, shift, axis=axis)

    uR3 = roll(u, 3)
    uR2 = roll(u, 2)
    uR1 = roll(u, 1)
    u0 = u
    uL1 = roll(u, -1)
    uL2 = roll(u, -2)
    uL3 = roll(u, -3)

    # Calculate the smoothness indicators for (7, 4) stencil
    # Ref: Balsara & Shu, Journal of Computational Physics 160, no. 2 (May 2000): 405–52
    # beta0..beta3 belong to the 4-cell candidate stencils
    # {i-3..i}, {i-2..i+1}, {i-1..i+2} and {i..i+3} respectively.
    beta0 = 1 / 240 * (
        uL3 * ( 547 * uL3 - 3882 * uL2 + 4642 * uL1 - 1854 * u0) + \
        uL2 * ( 7043 * uL2 - 17246 * uL1 + 7042 * u0) + \
        uL1 * (11003 * uL1 - 9402 * u0) + \
        u0 * ( 2107 * u0)
    )
    beta1 = 1 / 240 * (
        uL2 * ( 267 * uL2 - 1642 * uL1 + 1602 * u0 - 494 * uR1) + \
        uL1 * (2843 * uL1 - 5966 * u0 + 1922 * uR1) + \
        u0 * (3443 * u0 - 2522 * uR1) + \
        uR1 * ( 547 * uR1)
    )
    beta2 = 1 / 240 * (
        uL1 * ( 547 * uL1 - 2522 * u0 + 1922 * uR1 - 494 * uR2) + \
        u0 * (3443 * u0 - 5966 * uR1 + 1602 * uR2) + \
        uR1 * (2843 * uR1 - 1642 * uR2) + \
        uR2 * ( 267 * uR2)
    )
    beta3 = 1 / 240 * (
        u0 * ( 2107 * u0 - 9402 * uR1 + 7042 * uR2 - 1854 * uR3) + \
        uR1 * (11003 * uR1 - 17246 * uR2 + 4642 * uR3) + \
        uR2 * ( 7043 * uR2 - 3882 * uR3) + \
        uR3 * ( 547 * uR3)
    )

    # epsilon only keeps the denominators non-zero; the power q makes the
    # weights of rough stencils collapse quickly relative to smooth ones.
    epsilon = 1e-40
    q = 6

    gamma0 = 1 / (beta0 + epsilon) ** q
    gamma1 = 1 / (beta1 + epsilon) ** q
    gamma2 = 1 / (beta2 + epsilon) ** q
    gamma3 = 1 / (beta3 + epsilon) ** q

    # chiKmask: stencil K's normalised weight is below CT -> flagged non-smooth.
    chi0mask = (gamma0 / (gamma0 + gamma1 + gamma2 + gamma3)) < CT
    chi1mask = (gamma1 / (gamma0 + gamma1 + gamma2 + gamma3)) < CT
    chi2mask = (gamma2 / (gamma0 + gamma1 + gamma2 + gamma3)) < CT
    chi3mask = (gamma3 / (gamma0 + gamma1 + gamma2 + gamma3)) < CT

    # deltaK is True (stencil K usable) unless the same four cells were flagged
    # from every point whose candidate stencils contain them; e.g. the cells
    # {i-3..i} are stencil 0 at i, 1 at i-1, 2 at i-2 and 3 at i-3.
    delta0 = ~(chi0mask & roll(chi1mask, -1) & roll(chi2mask, -2) & roll(chi3mask, -3))
    delta1 = ~(chi1mask & roll(chi2mask, -1) & roll(chi3mask, -2) & roll(chi0mask, 1))
    delta2 = ~(chi2mask & roll(chi3mask, -1) & roll(chi1mask, 1) & roll(chi0mask, 2))
    delta3 = ~(chi3mask & roll(chi2mask, 1) & roll(chi1mask, 2) & roll(chi0mask, 3))

    # ----- Dispersion-relation-preserving (DRP) -----
    # Ref: Tam. Computational aeroacoustics: A wave number approach, 2012.
    # The 7-point interface coefficients cL3..cR3 are built from the 7-point
    # (a7x) and 9-point (a9x) DRP coefficients.
    a71 = 0.77088238051822552
    a72 = -0.166705904414580469
    a73 = 0.02084314277031176

    a91 = 0.8301178834769906875382633360472085
    a92 = -0.23175338776901819008451262109655756
    a93 = 0.05287205020483696423592156502901203
    a94 = -6.306814638366300019250697235282424e-3

    cL3 = 2 * a94
    cL2 = 2 * a93 - a73
    cL1 = 2 * a92 + 2 * a94 - a72 + a73
    c0 = 2 * a91 + 2 * a93 - a71 + a72 - a73
    cR1 = 2 * a92 + 2 * a94 - a72 + a73 + a71
    cR2 = 2 * a93 - a73 + a72
    cR3 = 2 * a94 + a73

    # Optimal reconstruction polynomial for smooth stencils
    uHalf = cL3 * uL3 + cL2 * uL2 + cL1 * uL1 + c0 * u0 + cR1 * uR1 + cR2 * uR2 + cR3 * uR3

    # Fallbacks: each smoothness pattern below replaces uHalf with an
    # interpolation that skips the flagged cells. The patterns are mutually
    # exclusive; any pattern not listed keeps the optimal DRP value above.
    # Interval names refer to direction 'R' (mirrored for 'L').

    # a single discontinuity lies in (uL3, uL2)
    uHalf = xp.where(
        (~delta0) & delta1 & delta2 & delta3,
        1 / 60 * uL2 - 2 / 15 * uL1 + 37 / 60 * u0 + 37 / 60 * uR1 - 2 / 15 * uR2 + 1 / 60 * uR3,
        uHalf,
    )

    # a single discontinuity lies in (uL2, uL1)
    uHalf = xp.where(
        (~delta0) & (~delta1) & delta2 & delta3,
        -1 / 20 * uL1 + 9 / 20 * u0 + 47 / 60 * uR1 - 13 / 60 * uR2 + 1 / 30 * uR3,
        uHalf,
    )

    # a single discontinuity lies in (uL1, u0)
    uHalf = xp.where(
        (~delta0) & (~delta1) & (~delta2) & delta3,
        1 / 4 * u0 + 13 / 12 * uR1 - 5 / 12 * uR2 + 1 / 12 * uR3,
        uHalf,
    )

    # a single discontinuity lies in (u0, uR1)
    uHalf = xp.where(
        delta0 & (~delta1) & (~delta2) & (~delta3),
        -1 / 4 * uL3 + 13 / 12 * uL2 - 23 / 12 * uL1 + 25 / 12 * u0,
        uHalf,
    )

    # a single discontinuity lies in (uR1, uR2)
    uHalf = xp.where(
        delta0 & delta1 & (~delta2) & (~delta3),
        -1 / 20 * uL3 + 17 / 60 * uL2 - 43 / 60 * uL1 + 77 / 60 * u0 + 1 / 5 * uR1,
        uHalf,
    )

    # a single discontinuity lies in (uR2, uR3)
    uHalf = xp.where(
        delta0 & delta1 & delta2 & (~delta3),
        -1 / 60 * uL3 + 7 / 60 * uL2 - 23 / 60 * uL1 + 19 / 20 * u0 + 11 / 30 * uR1 - 1 / 30 * uR2,
        uHalf,
    )

    # two discontinuities lie in (uL3, uL2) and (uR2, uR3)
    uHalf = xp.where(
        (~delta0) & delta1 & delta2 & (~delta3),
        2 / 60 * uL2 - 13 / 60 * uL1 + 47 / 60 * u0 + 27 / 60 * uR1 - 3 / 60 * uR2,
        uHalf,
    )

    # Any stencil other than fully smooth is considered discontinuous.
    shockflag = ~(delta0 & delta1 & delta2 & delta3)

    return uHalf, shockflag
199
+
200
if hasJAX:
    import jax

    # direction, axis and CT are plain Python values that steer tracing
    # (the 'R' comparison, the roll axis, the CT threshold), so they are
    # declared static; JAX compiles one specialisation per combination.
    coreTENO7M = jax.jit(
        coreTENO7M,
        static_argnames=["direction", "axis", "CT"],
    )
203
+
204
def TENO7M(
    u: ArrayLike, axis: int, mode: str = "hybrid", CT: float = 0.01
) -> tuple[np.ndarray, np.ndarray]:
    """TENO7-M reconstruction of interface values along one axis.

    Parameters
    ----------
    u : array_like, cell-averaged values (axis 2 is special-cased, so
        presumably 3D -- TODO confirm for other ranks)
    axis : int, axis along which to reconstruct
    mode : str, how the one-sided reconstructions are combined:
        'inner'  -- uL[i] at i - 1/2 and uR[i] at i + 1/2, both from cell i's stencil
        'upwind' -- uL[i] is replaced by uR[i - 1], so both outputs come from
                    the 'R' reconstruction
        'hybrid' -- where a reconstruction is not flagged as a shock, it is
                    averaged with the neighbouring cell's value at the same
                    interface; flagged values are kept as they are
    CT : float, smoothness threshold forwarded to coreTENO7M

    Returns
    -------
    (uL, uR) : NumPy arrays with the layout of u

    Raises
    ------
    ValueError : if mode is not 'inner', 'upwind' or 'hybrid'
    """
    # Reject a bad mode before running two full reconstructions.
    if mode not in ("inner", "upwind", "hybrid"):
        raise ValueError(f"Invalid mode: {mode}. Options: 'inner', 'upwind', and 'hybrid'.")

    if hasJAX:
        import jax.numpy as jnp
        xp: Any = jnp
    else:
        xp = np
    u = xp.asarray(u)

    # The z-axis is moved to the front for the computation and moved back at
    # the end; every other axis is processed in place.
    moved = axis == 2
    work_axis = 0 if moved else axis
    if moved:
        u = xp.moveaxis(u, 2, 0)

    uL, shockFlagL = coreTENO7M(u, direction="L", axis=work_axis, CT=CT)
    uR, shockFlagR = coreTENO7M(u, direction="R", axis=work_axis, CT=CT)

    if mode == "upwind":
        uL = xp.roll(uR, 1, axis=work_axis)
    elif mode == "hybrid":
        # Tuple assignment: both right-hand sides use the unmodified uL/uR.
        uL, uR = (
            xp.where(~shockFlagL, (uL + xp.roll(uR, 1, axis=work_axis)) / 2, uL),
            xp.where(~shockFlagR, (uR + xp.roll(uL, -1, axis=work_axis)) / 2, uR),
        )

    if moved:
        uL, uR = xp.moveaxis(uL, 0, 2), xp.moveaxis(uR, 0, 2)
    return np.asarray(uL), np.asarray(uR)
251
+
252
+
253
def TENO7Mx(
    u: ArrayLike, mode: str = "hybrid", CT: float = 0.01
) -> tuple[np.ndarray, np.ndarray]:
    """TENO7-M reconstruction in x direction (axis 0).

    With JAX available, u is first placed with sharding["x"]. If JAX rejects
    that placement with a ValueError (e.g. a transverse size the device mesh
    cannot tile evenly, or a non-3D input), u is put on the default device
    unsharded instead of failing. mode and CT are forwarded to TENO7M.
    """
    if hasJAX:
        import jax
        try:
            u = jax.device_put(u, sharding["x"])
        except ValueError:
            # The sharding layout does not fit this array; compute unsharded.
            u = jax.device_put(u)
    return TENO7M(u, axis=0, mode=mode, CT=CT)

def TENO7My(
    u: ArrayLike, mode: str = "hybrid", CT: float = 0.01
) -> tuple[np.ndarray, np.ndarray]:
    """TENO7-M reconstruction in y direction (axis 1).

    Same placement strategy as TENO7Mx, using sharding["y"] with an unsharded
    fallback. mode and CT are forwarded to TENO7M.
    """
    if hasJAX:
        import jax
        try:
            u = jax.device_put(u, sharding["y"])
        except ValueError:
            # The sharding layout does not fit this array; compute unsharded.
            u = jax.device_put(u)
    return TENO7M(u, axis=1, mode=mode, CT=CT)

def TENO7Mz(
    u: ArrayLike, mode: str = "hybrid", CT: float = 0.01
) -> tuple[np.ndarray, np.ndarray]:
    """TENO7-M reconstruction in z direction (axis 2).

    Same placement strategy as TENO7Mx, using sharding["z"] with an unsharded
    fallback. mode and CT are forwarded to TENO7M.
    """
    if hasJAX:
        import jax
        try:
            u = jax.device_put(u, sharding["z"])
        except ValueError:
            # The sharding layout does not fit this array; compute unsharded.
            u = jax.device_put(u)
    return TENO7M(u, axis=2, mode=mode, CT=CT)
@@ -0,0 +1,323 @@
1
+ # PyMHD: Python for Magnetohydrodynamic Turbulence.
2
+ # Copyright (c) 2026 Yuyang Hua (华宇阳)
3
+ # License: MIT
4
+
5
+ """
6
+ pymhd/derivatives/WENO.py
7
+ --------------------------
8
+
9
+ Implements the WENO5-Z and WENO7-Z reconstruction schemes with two backends:
10
+ - NumPy (default): more compatible, but limited to a single CPU core
11
+ - JAX: requires JAX, but leverages multi-core CPU and GPU acceleration
12
+ """
13
+
14
+ from typing import Any
15
+
16
+ import numpy as np
17
+ from numpy.typing import ArrayLike
18
+
19
hasJAX = False

try:
    import jax
except ImportError:
    # JAX is optional: without it every kernel runs on the NumPy backend.
    pass
else:
    jax.config.update("jax_enable_x64", True)

    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    try:
        # Request 8 (virtual) CPU devices so the (4, 2) mesh below can be
        # built on a CPU-only machine. JAX only accepts this option before its
        # backends are initialised (e.g. it may already be initialised when
        # another module ran this setup first), and older JAX releases may not
        # know the option at all; in both cases keep the existing devices.
        jax.config.update("jax_num_cpu_devices", 8)
    except (AttributeError, RuntimeError):
        pass

    try:
        mesh: Any = jax.make_mesh(axis_shapes=(4, 2), axis_names=("X", "Y"))
    except (AttributeError, ValueError):
        # Fewer than 8 devices (e.g. a single GPU) or no jax.make_mesh:
        # keep using JAX, but let jax.device_put place arrays on the default
        # device without sharding (device_put accepts None).
        mesh = None
        sharding = {"x": None, "y": None, "z": None}
    else:
        # Each layout keeps the reconstruction axis whole and splits the two
        # transverse axes over the mesh axes "X" and "Y".
        sharding = {
            "x": NamedSharding(mesh, P(None, "X", "Y")),
            "y": NamedSharding(mesh, P("X", None, "Y")),
            "z": NamedSharding(mesh, P("X", "Y", None)),
        }

    hasJAX = True
40
+
41
+ # ===== WENO5 =====
42
+
43
def coreWENO5(u: ArrayLike, axis: int) -> tuple[Any, Any]:
    """WENO5-Z reconstruction along a single axis.

    Neighbours are taken with a periodic roll, so the data is treated as
    periodic along ``axis``. Uses jax.numpy when JAX is available, NumPy
    otherwise.

    Parameters
    ----------
    u : array_like, cell-averaged values
    axis : int, axis along which to reconstruct

    Returns
    -------
    uL : interface value at i - 1/2 reconstructed from the stencil around cell i
    uR : interface value at i + 1/2 reconstructed from the stencil around cell i
    """
    if hasJAX:
        import jax.numpy as jnp
        xp = jnp
    else:
        xp = np
    u = xp.asarray(u)

    # uRk[i] == u[i + k] and uLk[i] == u[i - k] (periodic).
    uR2 = xp.roll(u, -2, axis=axis)
    uR1 = xp.roll(u, -1, axis=axis)
    u0 = u
    uL1 = xp.roll(u, 1, axis=axis)
    uL2 = xp.roll(u, 2, axis=axis)

    # Smoothness indicators of the 3-cell candidate stencils
    # {i-2..i}, {i-1..i+1} and {i..i+2}.
    beta1 = 13/12 * (uL2 - 2 * uL1 + u0 )**2 + 1/4 * (uL2 - 4 * uL1 + 3 * u0)**2
    beta2 = 13/12 * (uL1 - 2 * u0 + uR1)**2 + 1/4 * (uL1 - uR1)**2
    beta3 = 13/12 * (u0 - 2 * uR1 + uR2)**2 + 1/4 * (3 * u0 - 4 * uR1 + uR2)**2

    # Global smoothness measure of the Z-type weights.
    tau5 = xp.abs(beta1 - beta3)
    epsilon = 1e-6
    p = 4

    # ----- uR: value at i + 1/2 -----
    # Linear (optimal) weights of the three candidate stencils.
    gamma1, gamma2, gamma3 = 1/10, 3/5, 3/10

    # Per-stencil coefficients multiplying (uL2, uL1, u0, uR1, uR2).
    bL1, aL1, c1, aR1, bR1 = 1/3, -7/6, 11/6, 0, 0
    bL2, aL2, c2, aR2, bR2 = 0, -1/6, 5/6, 1/3, 0
    bL3, aL3, c3, aR3, bR3 = 0, 0, 1/3, 5/6, -1/6

    # Z-type nonlinear weights, then normalised to sum to one.
    weights1 = gamma1 * (1 + (tau5/(beta1 + epsilon)) ** p)
    weights2 = gamma2 * (1 + (tau5/(beta2 + epsilon)) ** p)
    weights3 = gamma3 * (1 + (tau5/(beta3 + epsilon)) ** p)

    w1 = weights1 / (weights1 + weights2 + weights3)
    w2 = weights2 / (weights1 + weights2 + weights3)
    w3 = weights3 / (weights1 + weights2 + weights3)

    # Blend the coefficients first, then apply them to the five-point stencil.
    bL = w1 * bL1 + w2 * bL2 + w3 * bL3
    aL = w1 * aL1 + w2 * aL2 + w3 * aL3
    c = w1 * c1 + w2 * c2 + w3 * c3
    aR = w1 * aR1 + w2 * aR2 + w3 * aR3
    bR = w1 * bR1 + w2 * bR2 + w3 * bR3

    uR = bL * uL2 + aL * uL1 + c * u0 + aR * uR1 + bR * uR2

    # ----- uL: value at i - 1/2 (mirrored weights and coefficients) -----
    gamma1, gamma2, gamma3 = 3/10, 3/5, 1/10

    bL1, aL1, c1, aR1, bR1 = -1/6, 5/6, 1/3, 0, 0
    bL2, aL2, c2, aR2, bR2 = 0, 1/3, 5/6, -1/6, 0
    bL3, aL3, c3, aR3, bR3 = 0, 0, 11/6, -7/6, 1/3

    weights1 = gamma1 * (1 + (tau5/(beta1 + epsilon)) ** p)
    weights2 = gamma2 * (1 + (tau5/(beta2 + epsilon)) ** p)
    weights3 = gamma3 * (1 + (tau5/(beta3 + epsilon)) ** p)

    w1 = weights1 / (weights1 + weights2 + weights3)
    w2 = weights2 / (weights1 + weights2 + weights3)
    w3 = weights3 / (weights1 + weights2 + weights3)

    bL = w1 * bL1 + w2 * bL2 + w3 * bL3
    aL = w1 * aL1 + w2 * aL2 + w3 * aL3
    c = w1 * c1 + w2 * c2 + w3 * c3
    aR = w1 * aR1 + w2 * aR2 + w3 * aR3
    bR = w1 * bR1 + w2 * bR2 + w3 * bR3

    uL = bL * uL2 + aL * uL1 + c * u0 + aR * uR1 + bR * uR2

    return uL, uR
111
+
112
if hasJAX:
    import jax

    # ``axis`` chooses which axis the rolls shift, so it must be known at
    # trace time: mark it static (one compiled variant per axis value).
    coreWENO5 = jax.jit(
        coreWENO5,
        static_argnames=["axis"],
    )
115
+
116
def WENO5(u: ArrayLike, axis: int) -> tuple[np.ndarray, np.ndarray]:
    """WENO5-Z reconstruction on a 3D array along the given axis.

    Returns the (uL, uR) pair produced by coreWENO5, converted to NumPy
    arrays. For axis 2, the computation runs on an array whose z-axis has been
    moved to the front, and both results are moved back before returning.
    """
    xp: Any = np
    if hasJAX:
        import jax.numpy as jnp
        xp = jnp
    arr = xp.asarray(u)

    z_first = axis == 2
    if z_first:
        arr = xp.moveaxis(arr, 2, 0)
    left, right = coreWENO5(arr, axis=0 if z_first else axis)
    if z_first:
        left, right = xp.moveaxis(left, 0, 2), xp.moveaxis(right, 0, 2)
    return np.asarray(left), np.asarray(right)
132
+
133
def WENO5x(u: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
    """WENO5-Z reconstruction in x direction (axis 0).

    With JAX available, u is first placed with sharding["x"]. If JAX rejects
    that placement with a ValueError (e.g. a transverse size the device mesh
    cannot tile evenly, or a non-3D input), u is put on the default device
    unsharded instead of failing.
    """
    if hasJAX:
        import jax
        try:
            u = jax.device_put(u, sharding["x"])
        except ValueError:
            # The sharding layout does not fit this array; compute unsharded.
            u = jax.device_put(u)
    return WENO5(u, axis=0)

def WENO5y(u: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
    """WENO5-Z reconstruction in y direction (axis 1).

    Same placement strategy as WENO5x, using sharding["y"] with an unsharded
    fallback.
    """
    if hasJAX:
        import jax
        try:
            u = jax.device_put(u, sharding["y"])
        except ValueError:
            # The sharding layout does not fit this array; compute unsharded.
            u = jax.device_put(u)
    return WENO5(u, axis=1)

def WENO5z(u: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
    """WENO5-Z reconstruction in z direction (axis 2).

    Same placement strategy as WENO5x, using sharding["z"] with an unsharded
    fallback.
    """
    if hasJAX:
        import jax
        try:
            u = jax.device_put(u, sharding["z"])
        except ValueError:
            # The sharding layout does not fit this array; compute unsharded.
            u = jax.device_put(u)
    return WENO5(u, axis=2)
153
+
154
+ # ===== WENO7 =====
155
+
156
def coreWENO7(u: ArrayLike, axis: int) -> tuple[Any, Any]:
    """WENO7-Z reconstruction along a single axis.

    Neighbours are taken with a periodic roll, so the data is treated as
    periodic along ``axis``. Uses jax.numpy when JAX is available, NumPy
    otherwise.

    Parameters
    ----------
    u : array_like, cell-averaged values
    axis : int, axis along which to reconstruct

    Returns
    -------
    uL : interface value at i - 1/2 reconstructed from the stencil around cell i
    uR : interface value at i + 1/2 reconstructed from the stencil around cell i
    """
    if hasJAX:
        import jax.numpy as jnp
        xp = jnp
    else:
        xp = np
    u = xp.asarray(u)

    # uRk[i] == u[i + k] and uLk[i] == u[i - k] (periodic).
    uR3 = xp.roll(u, -3, axis=axis)
    uR2 = xp.roll(u, -2, axis=axis)
    uR1 = xp.roll(u, -1, axis=axis)
    u0 = u
    uL1 = xp.roll(u, 1, axis=axis)
    uL2 = xp.roll(u, 2, axis=axis)
    uL3 = xp.roll(u, 3, axis=axis)

    # Smoothness indicators of the 4-cell candidate stencils
    # {i-3..i}, {i-2..i+1}, {i-1..i+2} and {i..i+3}.
    beta1 = 1 / 240 * (
        uL3 * ( 547 * uL3 - 3882 * uL2 + 4642 * uL1 - 1854 * u0) + \
        uL2 * ( 7043 * uL2 - 17246 * uL1 + 7042 * u0) + \
        uL1 * (11003 * uL1 - 9402 * u0) + \
        u0 * ( 2107 * u0)
    )
    beta2 = 1 / 240 * (
        uL2 * ( 267 * uL2 - 1642 * uL1 + 1602 * u0 - 494 * uR1) + \
        uL1 * (2843 * uL1 - 5966 * u0 + 1922 * uR1) + \
        u0 * (3443 * u0 - 2522 * uR1) + \
        uR1 * ( 547 * uR1)
    )
    beta3 = 1 / 240 * (
        uL1 * ( 547 * uL1 - 2522 * u0 + 1922 * uR1 - 494 * uR2) + \
        u0 * (3443 * u0 - 5966 * uR1 + 1602 * uR2) + \
        uR1 * (2843 * uR1 - 1642 * uR2) + \
        uR2 * ( 267 * uR2)
    )
    beta4 = 1 / 240 * (
        u0 * ( 2107 * u0 - 9402 * uR1 + 7042 * uR2 - 1854 * uR3) + \
        uR1 * (11003 * uR1 - 17246 * uR2 + 4642 * uR3) + \
        uR2 * ( 7043 * uR2 - 3882 * uR3) + \
        uR3 * ( 547 * uR3)
    )

    # Global smoothness measure of the Z-type weights.
    tau = xp.abs(beta1 - beta4)
    epsilon = 1e-6
    p = 4

    # ----- uR: value at i + 1/2 -----
    # Linear (optimal) weights of the four candidate stencils.
    gamma1, gamma2, gamma3, gamma4 = 1/35, 12/35, 18/35, 4/35

    # Per-stencil coefficients. NOTE: here the names bL, aL, c, aR, bR, cR, dR
    # multiply uL3, uL2, uL1, u0, uR1, uR2, uR3 respectively (see the sums).
    bL1, aL1, c1, aR1, bR1, cR1, dR1 = -1/4, 13/12, -23/12, 25/12, 0, 0, 0
    bL2, aL2, c2, aR2, bR2, cR2, dR2 = 0, 1/12, -5/12, 13/12, 1/4, 0, 0
    bL3, aL3, c3, aR3, bR3, cR3, dR3 = 0, 0, -1/12, 7/12, 7/12, -1/12, 0
    bL4, aL4, c4, aR4, bR4, cR4, dR4 = 0, 0, 0, 1/4, 13/12, -5/12, 1/12

    # Z-type nonlinear weights, then normalised to sum to one.
    weights1 = gamma1 * (1 + (tau/(beta1 + epsilon)) ** p)
    weights2 = gamma2 * (1 + (tau/(beta2 + epsilon)) ** p)
    weights3 = gamma3 * (1 + (tau/(beta3 + epsilon)) ** p)
    weights4 = gamma4 * (1 + (tau/(beta4 + epsilon)) ** p)

    w1 = weights1 / (weights1 + weights2 + weights3 + weights4)
    w2 = weights2 / (weights1 + weights2 + weights3 + weights4)
    w3 = weights3 / (weights1 + weights2 + weights3 + weights4)
    w4 = weights4 / (weights1 + weights2 + weights3 + weights4)

    # Blend the coefficients first, then apply them to the seven-point stencil.
    bL = w1 * bL1 + w2 * bL2 + w3 * bL3 + w4 * bL4
    aL = w1 * aL1 + w2 * aL2 + w3 * aL3 + w4 * aL4
    c = w1 * c1 + w2 * c2 + w3 * c3 + w4 * c4
    aR = w1 * aR1 + w2 * aR2 + w3 * aR3 + w4 * aR4
    bR = w1 * bR1 + w2 * bR2 + w3 * bR3 + w4 * bR4
    cR = w1 * cR1 + w2 * cR2 + w3 * cR3 + w4 * cR4
    dR = w1 * dR1 + w2 * dR2 + w3 * dR3 + w4 * dR4
    uR = bL * uL3 + aL * uL2 + c * uL1 + aR * u0 + bR * uR1 + cR * uR2 + dR * uR3

    # ----- uL: value at i - 1/2 (mirrored weights and coefficients) -----
    gamma1, gamma2, gamma3, gamma4 = 4/35, 18/35, 12/35, 1/35

    bL1, aL1, c1, aR1, bR1, cR1, dR1 = 1/12, -5/12, 13/12, 1/4, 0, 0, 0
    bL2, aL2, c2, aR2, bR2, cR2, dR2 = 0, -1/12, 7/12, 7/12, -1/12, 0, 0
    bL3, aL3, c3, aR3, bR3, cR3, dR3 = 0, 0, 1/4, 13/12, -5/12, 1/12, 0
    bL4, aL4, c4, aR4, bR4, cR4, dR4 = 0, 0, 0, 25/12, -23/12, 13/12, -1/4

    weights1 = gamma1 * (1 + (tau/(beta1 + epsilon)) ** p)
    weights2 = gamma2 * (1 + (tau/(beta2 + epsilon)) ** p)
    weights3 = gamma3 * (1 + (tau/(beta3 + epsilon)) ** p)
    weights4 = gamma4 * (1 + (tau/(beta4 + epsilon)) ** p)

    w1 = weights1 / (weights1 + weights2 + weights3 + weights4)
    w2 = weights2 / (weights1 + weights2 + weights3 + weights4)
    w3 = weights3 / (weights1 + weights2 + weights3 + weights4)
    w4 = weights4 / (weights1 + weights2 + weights3 + weights4)

    bL = w1 * bL1 + w2 * bL2 + w3 * bL3 + w4 * bL4
    aL = w1 * aL1 + w2 * aL2 + w3 * aL3 + w4 * aL4
    c = w1 * c1 + w2 * c2 + w3 * c3 + w4 * c4
    aR = w1 * aR1 + w2 * aR2 + w3 * aR3 + w4 * aR4
    bR = w1 * bR1 + w2 * bR2 + w3 * bR3 + w4 * bR4
    cR = w1 * cR1 + w2 * cR2 + w3 * cR3 + w4 * cR4
    dR = w1 * dR1 + w2 * dR2 + w3 * dR3 + w4 * dR4
    uL = bL * uL3 + aL * uL2 + c * uL1 + aR * u0 + bR * uR1 + cR * uR2 + dR * uR3

    return uL, uR
255
+
256
if hasJAX:
    import jax

    # ``axis`` chooses which axis the rolls shift, so it must be known at
    # trace time: mark it static (one compiled variant per axis value).
    coreWENO7 = jax.jit(
        coreWENO7,
        static_argnames=["axis"],
    )
259
+
260
def WENO7(u: ArrayLike, axis: int) -> tuple[np.ndarray, np.ndarray]:
    """WENO7-Z reconstruction on a 3D array along the given axis.

    Returns the (uL, uR) pair produced by coreWENO7, converted to NumPy
    arrays. For axis 2, the computation runs on an array whose z-axis has been
    moved to the front, and both results are moved back before returning.
    """
    xp: Any = np
    if hasJAX:
        import jax.numpy as jnp
        xp = jnp
    arr = xp.asarray(u)

    z_first = axis == 2
    if z_first:
        arr = xp.moveaxis(arr, 2, 0)
    left, right = coreWENO7(arr, axis=0 if z_first else axis)
    if z_first:
        left, right = xp.moveaxis(left, 0, 2), xp.moveaxis(right, 0, 2)
    return np.asarray(left), np.asarray(right)
276
+
277
def WENO7x(u: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
    """WENO7-Z reconstruction in x direction (axis 0).

    With JAX available, u is first placed with sharding["x"]. If JAX rejects
    that placement with a ValueError (e.g. a transverse size the device mesh
    cannot tile evenly, or a non-3D input), u is put on the default device
    unsharded instead of failing.
    """
    if hasJAX:
        import jax
        try:
            u = jax.device_put(u, sharding["x"])
        except ValueError:
            # The sharding layout does not fit this array; compute unsharded.
            u = jax.device_put(u)
    return WENO7(u, axis=0)

def WENO7y(u: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
    """WENO7-Z reconstruction in y direction (axis 1).

    Same placement strategy as WENO7x, using sharding["y"] with an unsharded
    fallback.
    """
    if hasJAX:
        import jax
        try:
            u = jax.device_put(u, sharding["y"])
        except ValueError:
            # The sharding layout does not fit this array; compute unsharded.
            u = jax.device_put(u)
    return WENO7(u, axis=1)

def WENO7z(u: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
    """WENO7-Z reconstruction in z direction (axis 2).

    Same placement strategy as WENO7x, using sharding["z"] with an unsharded
    fallback.
    """
    if hasJAX:
        import jax
        try:
            u = jax.device_put(u, sharding["z"])
        except ValueError:
            # The sharding layout does not fit this array; compute unsharded.
            u = jax.device_put(u)
    return WENO7(u, axis=2)
297
+
298
def WENOx(u: ArrayLike, stencils: int = 7) -> tuple[np.ndarray, np.ndarray]:
    """WENO5-Z or WENO7-Z reconstruction in x direction.

    ``stencils`` selects the scheme: 5 -> WENO5x, 7 -> WENO7x (default).
    Any other value raises ValueError.
    """
    if stencils not in (5, 7):
        raise ValueError(f"Invalid stencils: {stencils}. Supported stencils: 5, 7.")
    return WENO5x(u) if stencils == 5 else WENO7x(u)

def WENOy(u: ArrayLike, stencils: int = 7) -> tuple[np.ndarray, np.ndarray]:
    """WENO5-Z or WENO7-Z reconstruction in y direction.

    ``stencils`` selects the scheme: 5 -> WENO5y, 7 -> WENO7y (default).
    Any other value raises ValueError.
    """
    if stencils not in (5, 7):
        raise ValueError(f"Invalid stencils: {stencils}. Supported stencils: 5, 7.")
    return WENO5y(u) if stencils == 5 else WENO7y(u)

def WENOz(u: ArrayLike, stencils: int = 7) -> tuple[np.ndarray, np.ndarray]:
    """WENO5-Z or WENO7-Z reconstruction in z direction.

    ``stencils`` selects the scheme: 5 -> WENO5z, 7 -> WENO7z (default).
    Any other value raises ValueError.
    """
    if stencils not in (5, 7):
        raise ValueError(f"Invalid stencils: {stencils}. Supported stencils: 5, 7.")
    return WENO5z(u) if stencils == 5 else WENO7z(u)
@@ -0,0 +1,24 @@
1
+ # PyMHD: Python for Magnetohydrodynamic Turbulence.
2
+ # Copyright (c) 2026 Yuyang Hua (华宇阳)
3
+ # License: MIT
4
+
5
+ """
6
+ pymhd/derivatives/__init__.py
7
+ """
8
+
9
+ from .derivative import (
10
+ Field, Algorithm,
11
+ Dx, Dy, Dz,
12
+ grad, div, curl, laplacian,
13
+ average2center,
14
+ )
15
+
16
+ from . import derivative
17
+
18
__all__ = [
    # the submodule itself
    'derivative',
    # re-exported from .derivative
    'Field', 'Algorithm', 'Dx', 'Dy', 'Dz',
    'grad', 'div', 'curl', 'laplacian', 'average2center',
]