pymhd 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pymhd/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # PyMHD: Python for Magnetohydrodynamic Turbulence.
2
+ # Copyright (c) 2026 Yuyang Hua (华宇阳)
3
+ # License: MIT
4
+
5
+ """
6
+ pymhd/__init__.py
7
+ -----------------
8
+
9
+ Unified interface for PyMHD.
10
+ """
11
+
12
+ from .turbulence import Scalar, Vector, ScalarField, VectorField, Turbulence, avg, std
13
+ from .derivatives.derivative import Dx, Dy, Dz, grad, div, curl, laplacian, Algorithm
14
+
15
+ from .spectra import Spectrum, Spectrum1D, EnergySpectra
16
+ from .numdiss import NumericalDissipation
17
+ from .preprocess import output2turbulence
18
+
19
+ from .plot import plot
20
+
21
+ __all__ = [
22
+ "Scalar", "Vector",
23
+ "ScalarField", "VectorField", "avg", "std", "Turbulence",
24
+ "Algorithm",
25
+ "Dx", "Dy", "Dz", "grad", "div", "curl", "laplacian",
26
+ "Spectrum", "Spectrum1D",
27
+ "EnergySpectra",
28
+ "NumericalDissipation",
29
+ "output2turbulence",
30
+ "plot",
31
+ ]
@@ -0,0 +1,282 @@
1
+ # PyMHD: Python for Magnetohydrodynamic Turbulence.
2
+ # Copyright (c) 2026 Yuyang Hua (华宇阳)
3
+ # License: MIT
4
+
5
+ """
6
+ pymhd/derivatives/TENO.py
7
+ --------------------------
8
+
9
+ Implements the TENO7-M reconstruction scheme with two backends:
10
+ - NumPy (default): more compatible, but limited to a single CPU core
11
+ - JAX: requires JAX, but leverages multi-core CPU and (future) GPU acceleration
12
+ """
13
+
14
+ from typing import Any
15
+
16
+ import numpy as np
17
+ from numpy.typing import ArrayLike
18
+
19
hasJAX = False

try:
    import jax

    # Reconstructions are carried out in double precision when JAX is used.
    jax.config.update("jax_enable_x64", True)

    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    try:
        # Ask for 8 host (CPU) devices so the 4 x 2 mesh below can be built
        # on a CPU-only machine.
        jax.config.update("jax_num_cpu_devices", 8)
    except (RuntimeError, AttributeError):
        # RuntimeError: some JAX versions refuse to change this option after
        # the backend has been initialised.
        # AttributeError: JAX versions that predate this option reject it as
        # an unrecognised config name.
        pass

    mesh: Mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=("X", "Y"))

    # One sharding per sweep direction. The reconstructed axis is left
    # unsharded (None), so every shard holds that axis in full; the two
    # remaining axes are split over the mesh axes "X" and "Y".
    sharding = {
        "x": NamedSharding(mesh, P(None, "X", "Y")),
        "y": NamedSharding(mesh, P("X", None, "Y")),
        "z": NamedSharding(mesh, P("X", "Y", None)),
    }

    hasJAX = True

except ImportError:
    # JAX is optional; the NumPy backend is used instead.
    pass
except (AttributeError, RuntimeError, ValueError) as err:
    # JAX is installed but the sharded backend cannot be set up. Examples:
    # - the default backend does not expose exactly 8 devices (a single GPU,
    #   or JAX was initialised before ``jax_num_cpu_devices`` took effect);
    # - this JAX version lacks ``jax.make_mesh``.
    # Fall back to NumPy rather than making the whole package unimportable.
    import warnings

    warnings.warn(
        f"JAX backend unavailable ({err!r}); falling back to NumPy.",
        RuntimeWarning,
        stacklevel=2,
    )
44
+
45
def coreTENO7M(
    u: ArrayLike, direction: str, axis: int, CT: float = 0.01
) -> tuple[Any, Any]:
    """TENO7-M reconstruction.

    Reconstruct the cell interface values from the cell averages.

    Parameters
    ----------
    u : ArrayLike, cell-centered or cell-averaged values. Treated as periodic
        along ``axis``, since neighbours are obtained with ``roll``.
    direction: str, reconstruction direction, 'L' (for i - 1/2) or 'R' (for i + 1/2).
        For 'R', the local names uL3, ..., u0, ..., uR3 refer to cells
        i-3, ..., i, ..., i+3. For 'L' (in fact any value other than 'R'),
        the same formulas are applied to the mirrored stencil i+3, ..., i-3.
    axis : int, the axis direction (0, 1, 2 corresponds to x, y, z)
    CT : float, smoothness threshold, defaults to 0.01

    Returns
    -------
    uHalf : np.ndarray or jax.Array, reconstructed value at the cell interfaces (i ± 1/2)
    shockflag: np.ndarray or jax.Array, boolean; True where at least one of the
        four candidate stencils was judged non-smooth
    """
    # Pick the array backend: jax.numpy when JAX is available, NumPy otherwise.
    if hasJAX:
        import jax.numpy as jnp
        xp = jnp
    else:
        xp = np
    u = xp.asarray(u)

    def roll(arr, offset: int):
        # roll(arr, k)[i] == arr[i + k] for 'R' and arr[i - k] otherwise
        # (periodic wrap-around). The stencil below is therefore written once
        # and mirrored by the choice of direction.
        shift = -offset if direction == "R" else offset
        return xp.roll(arr, shift, axis=axis)

    uR3 = roll(u, 3)
    uR2 = roll(u, 2)
    uR1 = roll(u, 1)
    u0 = u
    uL1 = roll(u, -1)
    uL2 = roll(u, -2)
    uL3 = roll(u, -3)

    # Calculate the smoothness indicators for (7, 4) stencil, i.e. the four
    # 4-point candidate stencils:
    #   beta0 on (uL3..u0), beta1 on (uL2..uR1),
    #   beta2 on (uL1..uR2), beta3 on (u0..uR3).
    # Ref: Balsara & Shu, Journal of Computational Physics 160, no. 2 (May 2000): 405–52
    beta0 = 1 / 240 * (
        uL3 * ( 547 * uL3 - 3882 * uL2 + 4642 * uL1 - 1854 * u0) + \
        uL2 * ( 7043 * uL2 - 17246 * uL1 + 7042 * u0) + \
        uL1 * (11003 * uL1 - 9402 * u0) + \
        u0 * ( 2107 * u0)
    )
    beta1 = 1 / 240 * (
        uL2 * ( 267 * uL2 - 1642 * uL1 + 1602 * u0 - 494 * uR1) + \
        uL1 * (2843 * uL1 - 5966 * u0 + 1922 * uR1) + \
        u0 * (3443 * u0 - 2522 * uR1) + \
        uR1 * ( 547 * uR1)
    )
    beta2 = 1 / 240 * (
        uL1 * ( 547 * uL1 - 2522 * u0 + 1922 * uR1 - 494 * uR2) + \
        u0 * (3443 * u0 - 5966 * uR1 + 1602 * uR2) + \
        uR1 * (2843 * uR1 - 1642 * uR2) + \
        uR2 * ( 267 * uR2)
    )
    beta3 = 1 / 240 * (
        u0 * ( 2107 * u0 - 9402 * uR1 + 7042 * uR2 - 1854 * uR3) + \
        uR1 * (11003 * uR1 - 17246 * uR2 + 4642 * uR3) + \
        uR2 * ( 7043 * uR2 - 3882 * uR3) + \
        uR3 * ( 547 * uR3)
    )

    # NOTE(review): with float32 input, (beta + epsilon) ** q can underflow
    # to 0 for very small betas, giving gamma = inf and chi = NaN. Presumably
    # inputs are float64; confirm.
    epsilon = 1e-40
    q = 6

    gamma0 = 1 / (beta0 + epsilon) ** q
    gamma1 = 1 / (beta1 + epsilon) ** q
    gamma2 = 1 / (beta2 + epsilon) ** q
    gamma3 = 1 / (beta3 + epsilon) ** q

    # chiK = gammaK / sum(gamma). chiKmask is True where stencil K's normalised
    # weight falls below the cut-off CT, i.e. stencil K is flagged non-smooth
    # at this cell.
    chi0mask = (gamma0 / (gamma0 + gamma1 + gamma2 + gamma3)) < CT
    chi1mask = (gamma1 / (gamma0 + gamma1 + gamma2 + gamma3)) < CT
    chi2mask = (gamma2 / (gamma0 + gamma1 + gamma2 + gamma3)) < CT
    chi3mask = (gamma3 / (gamma0 + gamma1 + gamma2 + gamma3)) < CT

    # deltaK is False (stencil K discarded) only if the same four cells were
    # flagged by every cell whose candidate set contains them. For example,
    # the cells (uL3..u0) are:
    #   - stencil 0 of this cell,
    #   - stencil 1 of the cell at offset -1,
    #   - stencil 2 at offset -2,
    #   - stencil 3 at offset -3;
    # hence the rolled masks in delta0. delta1..delta3 follow the same pattern.
    delta0 = ~(chi0mask & roll(chi1mask, -1) & roll(chi2mask, -2) & roll(chi3mask, -3))
    delta1 = ~(chi1mask & roll(chi2mask, -1) & roll(chi3mask, -2) & roll(chi0mask, 1))
    delta2 = ~(chi2mask & roll(chi3mask, -1) & roll(chi1mask, 1) & roll(chi0mask, 2))
    delta3 = ~(chi3mask & roll(chi2mask, 1) & roll(chi1mask, 2) & roll(chi0mask, 3))

    # ----- Dispersion-relation-preserving (DRP) -----
    # Ref: Tam. Computational aeroacoustics: A wave number approach, 2012.
    a71 = 0.77088238051822552
    a72 = -0.166705904414580469
    a73 = 0.02084314277031176

    a91 = 0.8301178834769906875382633360472085
    a92 = -0.23175338776901819008451262109655756
    a93 = 0.05287205020483696423592156502901203
    a94 = -6.306814638366300019250697235282424e-3

    # Interface coefficients built from the 7- and 9-point DRP coefficients.
    # They sum to 1: the a7 terms cancel and 2*(a91 + 2*a92 + 3*a93 + 4*a94) = 1.
    cL3 = 2 * a94
    cL2 = 2 * a93 - a73
    cL1 = 2 * a92 + 2 * a94 - a72 + a73
    c0 = 2 * a91 + 2 * a93 - a71 + a72 - a73
    cR1 = 2 * a92 + 2 * a94 - a72 + a73 + a71
    cR2 = 2 * a93 - a73 + a72
    cR3 = 2 * a94 + a73

    # Optimal reconstruction polynomial for smooth stencils
    uHalf = cL3 * uL3 + cL2 * uL2 + cL1 * uL1 + c0 * u0 + cR1 * uR1 + cR2 * uR2 + cR3 * uR3

    # Each xp.where below overrides uHalf where its (delta0..delta3) pattern
    # matches. The patterns are mutually exclusive. A pattern not listed here
    # (e.g. all four stencils discarded) keeps the DRP value above.
    # NOTE(review): confirm that this fallback is intended.

    # a single discontinuity lies in (uL3, uL2)
    uHalf = xp.where(
        (~delta0) & delta1 & delta2 & delta3,
        1 / 60 * uL2 - 2 / 15 * uL1 + 37 / 60 * u0 + 37 / 60 * uR1 - 2 / 15 * uR2 + 1 / 60 * uR3,
        uHalf,
    )

    # a single discontinuity lies in (uL2, uL1)
    uHalf = xp.where(
        (~delta0) & (~delta1) & delta2 & delta3,
        -1 / 20 * uL1 + 9 / 20 * u0 + 47 / 60 * uR1 - 13 / 60 * uR2 + 1 / 30 * uR3,
        uHalf,
    )

    # a single discontinuity lies in (uL1, u0)
    uHalf = xp.where(
        (~delta0) & (~delta1) & (~delta2) & delta3,
        1 / 4 * u0 + 13 / 12 * uR1 - 5 / 12 * uR2 + 1 / 12 * uR3,
        uHalf,
    )

    # a single discontinuity lies in (u0, uR1)
    uHalf = xp.where(
        delta0 & (~delta1) & (~delta2) & (~delta3),
        -1 / 4 * uL3 + 13 / 12 * uL2 - 23 / 12 * uL1 + 25 / 12 * u0,
        uHalf,
    )

    # a single discontinuity lies in (uR1, uR2)
    uHalf = xp.where(
        delta0 & delta1 & (~delta2) & (~delta3),
        -1 / 20 * uL3 + 17 / 60 * uL2 - 43 / 60 * uL1 + 77 / 60 * u0 + 1 / 5 * uR1,
        uHalf,
    )

    # a single discontinuity lies in (uR2, uR3)
    uHalf = xp.where(
        delta0 & delta1 & delta2 & (~delta3),
        -1 / 60 * uL3 + 7 / 60 * uL2 - 23 / 60 * uL1 + 19 / 20 * u0 + 11 / 30 * uR1 - 1 / 30 * uR2,
        uHalf,
    )

    # two discontinuities lie in (uL3, uL2) and (uR2, uR3)
    uHalf = xp.where(
        (~delta0) & delta1 & delta2 & (~delta3),
        2 / 60 * uL2 - 13 / 60 * uL1 + 47 / 60 * u0 + 27 / 60 * uR1 - 3 / 60 * uR2,
        uHalf,
    )

    # Any stencil other than fully smooth is considered discontinuous.
    shockflag = ~(delta0 & delta1 & delta2 & delta3)

    return uHalf, shockflag


if hasJAX:
    import jax
    # Compile the kernel. direction, axis and CT are static arguments, so each
    # distinct combination of their values is traced and compiled separately.
    coreTENO7M = jax.jit(coreTENO7M, static_argnames=("direction", "axis", "CT"))
207
+
208
def TENO7M(
    u: ArrayLike, axis: int, mode: str = "hybrid", CT: float = 0.01
) -> tuple[np.ndarray, np.ndarray]:
    """TENO7-M reconstruction of both faces of every cell along ``axis``.

    Parameters
    ----------
    u : ArrayLike
        Cell-centered or cell-averaged values, treated as periodic along ``axis``.
    axis : int
        Axis to reconstruct along (0, 1, 2 for x, y, z). Axis 2 is moved to
        the front for the kernel calls and moved back before returning.
    mode : str, default "hybrid"
        How the two kernel outputs are combined:

        * ``"inner"``: return the ``'L'`` and ``'R'`` reconstructions of
          each cell unchanged.
        * ``"upwind"``: use the ``'R'`` reconstruction on both faces; the
          left-face value of cell i is the right-face value of cell i-1.
        * ``"hybrid"``: where the kernel reports a fully smooth stencil,
          average the two reconstructions that meet at the face; elsewhere
          keep the ``"inner"`` value.
    CT : float, default 0.01
        Smoothness threshold passed to ``coreTENO7M``.

    Returns
    -------
    (uLeft, uRight) : tuple of np.ndarray
        Values at face i-1/2 and face i+1/2 of every cell i.

    Raises
    ------
    ValueError
        If ``mode`` is not ``'inner'``, ``'upwind'`` or ``'hybrid'``. The mode
        is checked before any reconstruction work is done.
    """
    # Fail fast: the original checked the mode only after two full kernel runs.
    if mode not in ("inner", "upwind", "hybrid"):
        raise ValueError(f"Invalid mode: {mode}. Options: 'inner', 'upwind', and 'hybrid'.")

    if hasJAX:
        import jax.numpy as jnp
        xp: Any = jnp
    else:
        xp = np
    u = xp.asarray(u)

    # Axis 2 is processed as axis 0; the results are moved back at the end.
    moved = axis == 2
    if moved:
        u = xp.moveaxis(u, 2, 0)
    work_axis = 0 if moved else axis

    uL, shockFlagL = coreTENO7M(u, direction="L", axis=work_axis, CT=CT)
    uR, shockFlagR = coreTENO7M(u, direction="R", axis=work_axis, CT=CT)

    if mode == "inner":
        left, right = uL, uR
    elif mode == "upwind":
        # Face i-1/2 of cell i is face i+1/2 of cell i-1.
        left, right = xp.roll(uR, 1, axis=work_axis), uR
    else:  # "hybrid"
        # Central average of the two one-sided values at the same face, used
        # only where the kernel flagged no discontinuity.
        left = xp.where(~shockFlagL, (uL + xp.roll(uR, 1, axis=work_axis)) / 2, uL)
        right = xp.where(~shockFlagR, (uR + xp.roll(uL, -1, axis=work_axis)) / 2, uR)

    if moved:
        left = xp.moveaxis(left, 0, 2)
        right = xp.moveaxis(right, 0, 2)
    return np.asarray(left), np.asarray(right)
255
+
256
+
257
def TENO7Mx(
    u: ArrayLike, mode: str = "hybrid", CT: float = 0.01
) -> tuple[np.ndarray, np.ndarray]:
    """TENO7-M reconstruction in x direction.

    Delegates to ``TENO7M`` with ``axis=0``. With the JAX backend, the input
    is first placed on the device mesh using ``sharding["x"]``.
    """
    if not hasJAX:
        return TENO7M(u, axis=0, mode=mode, CT=CT)
    import jax
    placed = jax.device_put(u, sharding["x"])
    return TENO7M(placed, axis=0, mode=mode, CT=CT)
265
+
266
def TENO7My(
    u: ArrayLike, mode: str = "hybrid", CT: float = 0.01
) -> tuple[np.ndarray, np.ndarray]:
    """TENO7-M reconstruction in y direction.

    Delegates to ``TENO7M`` with ``axis=1``. With the JAX backend, the input
    is first placed on the device mesh using ``sharding["y"]``.
    """
    if not hasJAX:
        return TENO7M(u, axis=1, mode=mode, CT=CT)
    import jax
    placed = jax.device_put(u, sharding["y"])
    return TENO7M(placed, axis=1, mode=mode, CT=CT)
274
+
275
def TENO7Mz(
    u: ArrayLike, mode: str = "hybrid", CT: float = 0.01
) -> tuple[np.ndarray, np.ndarray]:
    """TENO7-M reconstruction in z direction.

    Delegates to ``TENO7M`` with ``axis=2``. With the JAX backend, the input
    is first placed on the device mesh using ``sharding["z"]``.
    """
    if not hasJAX:
        return TENO7M(u, axis=2, mode=mode, CT=CT)
    import jax
    placed = jax.device_put(u, sharding["z"])
    return TENO7M(placed, axis=2, mode=mode, CT=CT)
@@ -0,0 +1,327 @@
1
+ # PyMHD: Python for Magnetohydrodynamic Turbulence.
2
+ # Copyright (c) 2026 Yuyang Hua (华宇阳)
3
+ # License: MIT
4
+
5
+ """
6
+ pymhd/derivatives/WENO.py
7
+ --------------------------
8
+
9
+ Implements the WENO5-Z and WENO7-Z reconstruction schemes with two backends:
10
+ - NumPy (default): more compatible, but limited to a single CPU core
11
+ - JAX: requires JAX, but leverages multi-core CPU and (future) GPU acceleration
12
+ """
13
+
14
+ from typing import Any
15
+
16
+ import numpy as np
17
+ from numpy.typing import ArrayLike
18
+
19
hasJAX = False

try:
    import jax

    # Reconstructions are carried out in double precision when JAX is used.
    jax.config.update("jax_enable_x64", True)

    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    try:
        # Ask for 8 host (CPU) devices so the 4 x 2 mesh below can be built
        # on a CPU-only machine.
        jax.config.update("jax_num_cpu_devices", 8)
    except (RuntimeError, AttributeError):
        # RuntimeError: some JAX versions refuse to change this option after
        # the backend has been initialised.
        # AttributeError: JAX versions that predate this option reject it as
        # an unrecognised config name.
        pass

    mesh: Mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=("X", "Y"))

    # One sharding per sweep direction. The reconstructed axis is left
    # unsharded (None), so every shard holds that axis in full; the two
    # remaining axes are split over the mesh axes "X" and "Y".
    sharding = {
        "x": NamedSharding(mesh, P(None, "X", "Y")),
        "y": NamedSharding(mesh, P("X", None, "Y")),
        "z": NamedSharding(mesh, P("X", "Y", None)),
    }

    hasJAX = True

except ImportError:
    # JAX is optional; the NumPy backend is used instead.
    pass
except (AttributeError, RuntimeError, ValueError) as err:
    # JAX is installed but the sharded backend cannot be set up. Examples:
    # - the default backend does not expose exactly 8 devices (a single GPU,
    #   or JAX was initialised before ``jax_num_cpu_devices`` took effect);
    # - this JAX version lacks ``jax.make_mesh``.
    # Fall back to NumPy rather than making the whole package unimportable.
    import warnings

    warnings.warn(
        f"JAX backend unavailable ({err!r}); falling back to NumPy.",
        RuntimeWarning,
        stacklevel=2,
    )
44
+
45
+ # ===== WENO5 =====
46
+
47
def coreWENO5(u: ArrayLike, axis: int) -> tuple[Any, Any]:
    """WENO5-Z reconstruction along a single axis.

    Parameters
    ----------
    u : ArrayLike, cell-averaged values; treated as periodic along ``axis``
        (neighbours are obtained with ``roll``)
    axis : int, axis along which to reconstruct

    Returns
    -------
    uL : value at the left face (i - 1/2) of each cell i, i.e. the state on
         the right-hand side of that face, from cells i-2, ..., i+2
    uR : value at the right face (i + 1/2) of each cell i, i.e. the state on
         the left-hand side of that face, from the same five cells
    """
    # Pick the array backend: jax.numpy when JAX is available, NumPy otherwise.
    if hasJAX:
        import jax.numpy as jnp
        xp = jnp
    else:
        xp = np
    u = xp.asarray(u)

    # Periodic neighbours: uR1[i] = u[i+1], uL1[i] = u[i-1], and so on.
    uR2 = xp.roll(u, -2, axis=axis)
    uR1 = xp.roll(u, -1, axis=axis)
    u0 = u
    uL1 = xp.roll(u, 1, axis=axis)
    uL2 = xp.roll(u, 2, axis=axis)

    # Smoothness indicators of the three candidate stencils
    # (uL2, uL1, u0), (uL1, u0, uR1) and (u0, uR1, uR2).
    beta1 = 13/12 * (uL2 - 2 * uL1 + u0 )**2 + 1/4 * (uL2 - 4 * uL1 + 3 * u0)**2
    beta2 = 13/12 * (uL1 - 2 * u0 + uR1)**2 + 1/4 * (uL1 - uR1)**2
    beta3 = 13/12 * (u0 - 2 * uR1 + uR2)**2 + 1/4 * (3 * u0 - 4 * uR1 + uR2)**2

    # WENO-Z global indicator: difference of the two outer stencils' betas.
    tau5 = xp.abs(beta1 - beta3)
    epsilon = 1e-6
    p = 4

    # ----- Face i + 1/2 (returned as uR) -----
    gamma1, gamma2, gamma3 = 1/10, 3/5, 3/10

    # Each row gives one stencil's coefficients on (uL2, uL1, u0, uR1, uR2).
    bL1, aL1, c1, aR1, bR1 = 1/3, -7/6, 11/6, 0, 0
    bL2, aL2, c2, aR2, bR2 = 0, -1/6, 5/6, 1/3, 0
    bL3, aL3, c3, aR3, bR3 = 0, 0, 1/3, 5/6, -1/6

    # Unnormalised WENO-Z weights: gamma_k * (1 + (tau5 / (beta_k + eps))**p).
    weights1 = gamma1 * (1 + (tau5/(beta1 + epsilon)) ** p)
    weights2 = gamma2 * (1 + (tau5/(beta2 + epsilon)) ** p)
    weights3 = gamma3 * (1 + (tau5/(beta3 + epsilon)) ** p)

    w1 = weights1 / (weights1 + weights2 + weights3)
    w2 = weights2 / (weights1 + weights2 + weights3)
    w3 = weights3 / (weights1 + weights2 + weights3)

    # Blend the stencil coefficients with the nonlinear weights, then apply.
    bL = w1 * bL1 + w2 * bL2 + w3 * bL3
    aL = w1 * aL1 + w2 * aL2 + w3 * aL3
    c = w1 * c1 + w2 * c2 + w3 * c3
    aR = w1 * aR1 + w2 * aR2 + w3 * aR3
    bR = w1 * bR1 + w2 * bR2 + w3 * bR3

    uR = bL * uL2 + aL * uL1 + c * u0 + aR * uR1 + bR * uR2

    # ----- Face i - 1/2 (returned as uL): mirrored linear weights -----
    gamma1, gamma2, gamma3 = 3/10, 3/5, 1/10

    bL1, aL1, c1, aR1, bR1 = -1/6, 5/6, 1/3, 0, 0
    bL2, aL2, c2, aR2, bR2 = 0, 1/3, 5/6, -1/6, 0
    bL3, aL3, c3, aR3, bR3 = 0, 0, 11/6, -7/6, 1/3

    weights1 = gamma1 * (1 + (tau5/(beta1 + epsilon)) ** p)
    weights2 = gamma2 * (1 + (tau5/(beta2 + epsilon)) ** p)
    weights3 = gamma3 * (1 + (tau5/(beta3 + epsilon)) ** p)

    w1 = weights1 / (weights1 + weights2 + weights3)
    w2 = weights2 / (weights1 + weights2 + weights3)
    w3 = weights3 / (weights1 + weights2 + weights3)

    bL = w1 * bL1 + w2 * bL2 + w3 * bL3
    aL = w1 * aL1 + w2 * aL2 + w3 * aL3
    c = w1 * c1 + w2 * c2 + w3 * c3
    aR = w1 * aR1 + w2 * aR2 + w3 * aR3
    bR = w1 * bR1 + w2 * bR2 + w3 * bR3

    uL = bL * uL2 + aL * uL1 + c * u0 + aR * uR1 + bR * uR2

    return uL, uR


if hasJAX:
    import jax
    # Compile the kernel; axis is static, so each axis value compiles separately.
    coreWENO5 = jax.jit(coreWENO5, static_argnames=("axis",))
119
+
120
def WENO5(u: ArrayLike, axis: int) -> tuple[np.ndarray, np.ndarray]:
    """WENO5-Z reconstruction on a 3D array along the given axis.

    Returns the ``(uL, uR)`` pair produced by ``coreWENO5`` as NumPy arrays.
    For ``axis == 2``, the last axis is brought to the front, reconstructed
    as axis 0, and the results are moved back to the original layout.
    """
    if not hasJAX:
        xp: Any = np
    else:
        import jax.numpy as jnp
        xp = jnp
    arr = xp.asarray(u)

    if axis != 2:
        left, right = coreWENO5(arr, axis=axis)
        return np.asarray(left), np.asarray(right)

    left, right = coreWENO5(xp.moveaxis(arr, 2, 0), axis=0)
    return (
        np.asarray(xp.moveaxis(left, 0, 2)),
        np.asarray(xp.moveaxis(right, 0, 2)),
    )
136
+
137
def WENO5x(u: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
    """WENO5-Z reconstruction in x direction.

    Delegates to ``WENO5`` with ``axis=0``. With the JAX backend, the input
    is first distributed using ``sharding["x"]``.
    """
    if not hasJAX:
        return WENO5(u, axis=0)
    import jax
    return WENO5(jax.device_put(u, sharding["x"]), axis=0)
143
+
144
def WENO5y(u: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
    """WENO5-Z reconstruction in y direction.

    Delegates to ``WENO5`` with ``axis=1``. With the JAX backend, the input
    is first distributed using ``sharding["y"]``.
    """
    if not hasJAX:
        return WENO5(u, axis=1)
    import jax
    return WENO5(jax.device_put(u, sharding["y"]), axis=1)
150
+
151
def WENO5z(u: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
    """WENO5-Z reconstruction in z direction.

    Delegates to ``WENO5`` with ``axis=2``. With the JAX backend, the input
    is first distributed using ``sharding["z"]``.
    """
    if not hasJAX:
        return WENO5(u, axis=2)
    import jax
    return WENO5(jax.device_put(u, sharding["z"]), axis=2)
157
+
158
+ # ===== WENO7 =====
159
+
160
def coreWENO7(u: ArrayLike, axis: int) -> tuple[Any, Any]:
    """WENO7-Z reconstruction along a single axis.

    Parameters
    ----------
    u : ArrayLike, cell-averaged values; treated as periodic along ``axis``
        (neighbours are obtained with ``roll``)
    axis : int, axis along which to reconstruct

    Returns
    -------
    uL : value at the left face (i - 1/2) of each cell i, from cells i-3..i+3
    uR : value at the right face (i + 1/2) of each cell i, from the same cells
    """
    # Pick the array backend: jax.numpy when JAX is available, NumPy otherwise.
    if hasJAX:
        import jax.numpy as jnp
        xp = jnp
    else:
        xp = np
    u = xp.asarray(u)

    # Periodic neighbours: uR1[i] = u[i+1], uL1[i] = u[i-1], and so on.
    uR3 = xp.roll(u, -3, axis=axis)
    uR2 = xp.roll(u, -2, axis=axis)
    uR1 = xp.roll(u, -1, axis=axis)
    u0 = u
    uL1 = xp.roll(u, 1, axis=axis)
    uL2 = xp.roll(u, 2, axis=axis)
    uL3 = xp.roll(u, 3, axis=axis)

    # Smoothness indicators of the four 4-point candidate stencils
    # (uL3..u0), (uL2..uR1), (uL1..uR2), (u0..uR3), in the Balsara & Shu
    # (2000) form, matching the indicators used by coreTENO7M.
    beta1 = 1 / 240 * (
        uL3 * ( 547 * uL3 - 3882 * uL2 + 4642 * uL1 - 1854 * u0) + \
        uL2 * ( 7043 * uL2 - 17246 * uL1 + 7042 * u0) + \
        uL1 * (11003 * uL1 - 9402 * u0) + \
        u0 * ( 2107 * u0)
    )
    beta2 = 1 / 240 * (
        uL2 * ( 267 * uL2 - 1642 * uL1 + 1602 * u0 - 494 * uR1) + \
        uL1 * (2843 * uL1 - 5966 * u0 + 1922 * uR1) + \
        u0 * (3443 * u0 - 2522 * uR1) + \
        uR1 * ( 547 * uR1)
    )
    beta3 = 1 / 240 * (
        uL1 * ( 547 * uL1 - 2522 * u0 + 1922 * uR1 - 494 * uR2) + \
        u0 * (3443 * u0 - 5966 * uR1 + 1602 * uR2) + \
        uR1 * (2843 * uR1 - 1642 * uR2) + \
        uR2 * ( 267 * uR2)
    )
    beta4 = 1 / 240 * (
        u0 * ( 2107 * u0 - 9402 * uR1 + 7042 * uR2 - 1854 * uR3) + \
        uR1 * (11003 * uR1 - 17246 * uR2 + 4642 * uR3) + \
        uR2 * ( 7043 * uR2 - 3882 * uR3) + \
        uR3 * ( 547 * uR3)
    )

    # WENO-Z global indicator from the two outer stencils.
    # NOTE(review): some 7th-order WENO-Z formulations combine all four betas
    # for tau; confirm that |beta1 - beta4| is the intended choice.
    tau = xp.abs(beta1 - beta4)
    epsilon = 1e-6
    p = 4

    # ----- Face i + 1/2 (returned as uR) -----
    gamma1, gamma2, gamma3, gamma4 = 1/35, 12/35, 18/35, 4/35

    # Each row gives one stencil's coefficients on
    # (uL3, uL2, uL1, u0, uR1, uR2, uR3), in that order. The names are
    # positional and, unlike in coreWENO5, do not match the neighbour names.
    bL1, aL1, c1, aR1, bR1, cR1, dR1 = -1/4, 13/12, -23/12, 25/12, 0, 0, 0
    bL2, aL2, c2, aR2, bR2, cR2, dR2 = 0, 1/12, -5/12, 13/12, 1/4, 0, 0
    bL3, aL3, c3, aR3, bR3, cR3, dR3 = 0, 0, -1/12, 7/12, 7/12, -1/12, 0
    bL4, aL4, c4, aR4, bR4, cR4, dR4 = 0, 0, 0, 1/4, 13/12, -5/12, 1/12

    # Unnormalised WENO-Z weights: gamma_k * (1 + (tau / (beta_k + eps))**p).
    weights1 = gamma1 * (1 + (tau/(beta1 + epsilon)) ** p)
    weights2 = gamma2 * (1 + (tau/(beta2 + epsilon)) ** p)
    weights3 = gamma3 * (1 + (tau/(beta3 + epsilon)) ** p)
    weights4 = gamma4 * (1 + (tau/(beta4 + epsilon)) ** p)

    w1 = weights1 / (weights1 + weights2 + weights3 + weights4)
    w2 = weights2 / (weights1 + weights2 + weights3 + weights4)
    w3 = weights3 / (weights1 + weights2 + weights3 + weights4)
    w4 = weights4 / (weights1 + weights2 + weights3 + weights4)

    # Blend the stencil coefficients with the nonlinear weights, then apply.
    bL = w1 * bL1 + w2 * bL2 + w3 * bL3 + w4 * bL4
    aL = w1 * aL1 + w2 * aL2 + w3 * aL3 + w4 * aL4
    c = w1 * c1 + w2 * c2 + w3 * c3 + w4 * c4
    aR = w1 * aR1 + w2 * aR2 + w3 * aR3 + w4 * aR4
    bR = w1 * bR1 + w2 * bR2 + w3 * bR3 + w4 * bR4
    cR = w1 * cR1 + w2 * cR2 + w3 * cR3 + w4 * cR4
    dR = w1 * dR1 + w2 * dR2 + w3 * dR3 + w4 * dR4
    uR = bL * uL3 + aL * uL2 + c * uL1 + aR * u0 + bR * uR1 + cR * uR2 + dR * uR3

    # ----- Face i - 1/2 (returned as uL): mirrored linear weights -----
    gamma1, gamma2, gamma3, gamma4 = 4/35, 18/35, 12/35, 1/35

    bL1, aL1, c1, aR1, bR1, cR1, dR1 = 1/12, -5/12, 13/12, 1/4, 0, 0, 0
    bL2, aL2, c2, aR2, bR2, cR2, dR2 = 0, -1/12, 7/12, 7/12, -1/12, 0, 0
    bL3, aL3, c3, aR3, bR3, cR3, dR3 = 0, 0, 1/4, 13/12, -5/12, 1/12, 0
    bL4, aL4, c4, aR4, bR4, cR4, dR4 = 0, 0, 0, 25/12, -23/12, 13/12, -1/4

    weights1 = gamma1 * (1 + (tau/(beta1 + epsilon)) ** p)
    weights2 = gamma2 * (1 + (tau/(beta2 + epsilon)) ** p)
    weights3 = gamma3 * (1 + (tau/(beta3 + epsilon)) ** p)
    weights4 = gamma4 * (1 + (tau/(beta4 + epsilon)) ** p)

    w1 = weights1 / (weights1 + weights2 + weights3 + weights4)
    w2 = weights2 / (weights1 + weights2 + weights3 + weights4)
    w3 = weights3 / (weights1 + weights2 + weights3 + weights4)
    w4 = weights4 / (weights1 + weights2 + weights3 + weights4)

    bL = w1 * bL1 + w2 * bL2 + w3 * bL3 + w4 * bL4
    aL = w1 * aL1 + w2 * aL2 + w3 * aL3 + w4 * aL4
    c = w1 * c1 + w2 * c2 + w3 * c3 + w4 * c4
    aR = w1 * aR1 + w2 * aR2 + w3 * aR3 + w4 * aR4
    bR = w1 * bR1 + w2 * bR2 + w3 * bR3 + w4 * bR4
    cR = w1 * cR1 + w2 * cR2 + w3 * cR3 + w4 * cR4
    dR = w1 * dR1 + w2 * dR2 + w3 * dR3 + w4 * dR4
    uL = bL * uL3 + aL * uL2 + c * uL1 + aR * u0 + bR * uR1 + cR * uR2 + dR * uR3

    return uL, uR


if hasJAX:
    import jax
    # Compile the kernel; axis is static, so each axis value compiles separately.
    coreWENO7 = jax.jit(coreWENO7, static_argnames=("axis",))
263
+
264
def WENO7(u: ArrayLike, axis: int) -> tuple[np.ndarray, np.ndarray]:
    """WENO7-Z reconstruction on a 3D array along the given axis.

    Returns the ``(uL, uR)`` pair produced by ``coreWENO7`` as NumPy arrays.
    For ``axis == 2``, the last axis is brought to the front, reconstructed
    as axis 0, and the results are moved back to the original layout.
    """
    if not hasJAX:
        xp: Any = np
    else:
        import jax.numpy as jnp
        xp = jnp
    arr = xp.asarray(u)

    if axis != 2:
        left, right = coreWENO7(arr, axis=axis)
        return np.asarray(left), np.asarray(right)

    left, right = coreWENO7(xp.moveaxis(arr, 2, 0), axis=0)
    return (
        np.asarray(xp.moveaxis(left, 0, 2)),
        np.asarray(xp.moveaxis(right, 0, 2)),
    )
280
+
281
def WENO7x(u: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
    """WENO7-Z reconstruction in x direction.

    Delegates to ``WENO7`` with ``axis=0``. With the JAX backend, the input
    is first distributed using ``sharding["x"]``.
    """
    if not hasJAX:
        return WENO7(u, axis=0)
    import jax
    return WENO7(jax.device_put(u, sharding["x"]), axis=0)
287
+
288
def WENO7y(u: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
    """WENO7-Z reconstruction in y direction.

    Delegates to ``WENO7`` with ``axis=1``. With the JAX backend, the input
    is first distributed using ``sharding["y"]``.
    """
    if not hasJAX:
        return WENO7(u, axis=1)
    import jax
    return WENO7(jax.device_put(u, sharding["y"]), axis=1)
294
+
295
def WENO7z(u: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
    """WENO7-Z reconstruction in z direction.

    Delegates to ``WENO7`` with ``axis=2``. With the JAX backend, the input
    is first distributed using ``sharding["z"]``.
    """
    if not hasJAX:
        return WENO7(u, axis=2)
    import jax
    return WENO7(jax.device_put(u, sharding["z"]), axis=2)
301
+
302
def WENOx(u: ArrayLike, stencils: int = 7) -> tuple[np.ndarray, np.ndarray]:
    """WENO5-Z or WENO7-Z reconstruction in x direction.

    ``stencils`` selects the scheme: 5 uses ``WENO5x`` and 7 uses ``WENO7x``.

    Raises
    ------
    ValueError
        For any other value of ``stencils``.
    """
    if stencils == 5:
        reconstruct = WENO5x
    elif stencils == 7:
        reconstruct = WENO7x
    else:
        raise ValueError(f"Invalid stencils: {stencils}. Supported stencils: 5, 7.")
    return reconstruct(u)
310
+
311
def WENOy(u: ArrayLike, stencils: int = 7) -> tuple[np.ndarray, np.ndarray]:
    """WENO5-Z or WENO7-Z reconstruction in y direction.

    ``stencils`` selects the scheme: 5 uses ``WENO5y`` and 7 uses ``WENO7y``.

    Raises
    ------
    ValueError
        For any other value of ``stencils``.
    """
    if stencils == 5:
        reconstruct = WENO5y
    elif stencils == 7:
        reconstruct = WENO7y
    else:
        raise ValueError(f"Invalid stencils: {stencils}. Supported stencils: 5, 7.")
    return reconstruct(u)
319
+
320
def WENOz(u: ArrayLike, stencils: int = 7) -> tuple[np.ndarray, np.ndarray]:
    """WENO5-Z or WENO7-Z reconstruction in z direction.

    ``stencils`` selects the scheme: 5 uses ``WENO5z`` and 7 uses ``WENO7z``.

    Raises
    ------
    ValueError
        For any other value of ``stencils``.
    """
    if stencils == 5:
        reconstruct = WENO5z
    elif stencils == 7:
        reconstruct = WENO7z
    else:
        raise ValueError(f"Invalid stencils: {stencils}. Supported stencils: 5, 7.")
    return reconstruct(u)
@@ -0,0 +1,24 @@
1
+ # PyMHD: Python for Magnetohydrodynamic Turbulence.
2
+ # Copyright (c) 2026 Yuyang Hua (华宇阳)
3
+ # License: MIT
4
+
5
+ """
6
+ pymhd/derivatives/__init__.py
7
+ """
8
+
9
+ from .derivative import (
10
+ Field, Algorithm,
11
+ Dx, Dy, Dz,
12
+ grad, div, curl, laplacian,
13
+ average2center,
14
+ )
15
+
16
+ from . import derivative
17
+
18
+ __all__ = [
19
+ 'derivative',
20
+ 'Field', 'Algorithm',
21
+ 'Dx', 'Dy', 'Dz',
22
+ 'grad', 'div', 'curl', 'laplacian',
23
+ 'average2center',
24
+ ]