PyMHD 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymhd/__init__.py +31 -0
- pymhd/derivatives/TENO.py +278 -0
- pymhd/derivatives/WENO.py +323 -0
- pymhd/derivatives/__init__.py +24 -0
- pymhd/derivatives/compact.py +365 -0
- pymhd/derivatives/derivative.py +926 -0
- pymhd/numdiss.py +598 -0
- pymhd/plot/__init__.py +48 -0
- pymhd/plot/nd.py +1519 -0
- pymhd/plot/slc.py +648 -0
- pymhd/plot/spc.py +249 -0
- pymhd/preprocess/Athena.py +847 -0
- pymhd/preprocess/__init__.py +69 -0
- pymhd/preprocess/helper/NOTICE +42 -0
- pymhd/preprocess/helper/bin_convert.py +2000 -0
- pymhd/preprocess/helper/make_athdf.py +45 -0
- pymhd/spectra.py +376 -0
- pymhd/turbulence.py +917 -0
- pymhd-0.1.0.dist-info/METADATA +100 -0
- pymhd-0.1.0.dist-info/RECORD +22 -0
- pymhd-0.1.0.dist-info/WHEEL +4 -0
- pymhd-0.1.0.dist-info/licenses/LICENSE +21 -0
pymhd/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# PyMHD: Python for Magnetohydrodynamic Turbulence.
|
|
2
|
+
# Copyright (c) 2026 Yuyang Hua (华宇阳)
|
|
3
|
+
# License: MIT
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
pymhd/__init__.py
|
|
7
|
+
-----------------
|
|
8
|
+
|
|
9
|
+
Unified interface for PyMHD.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from .turbulence import Scalar, Vector, ScalarField, VectorField, Turbulence, avg, std
|
|
13
|
+
from .derivatives.derivative import Dx, Dy, Dz, grad, div, curl, laplacian, Algorithm
|
|
14
|
+
|
|
15
|
+
from .spectra import Spectrum, Spectrum1D, EnergySpectra
|
|
16
|
+
from .numdiss import NumericalDissipation
|
|
17
|
+
from .preprocess import output2turbulence
|
|
18
|
+
|
|
19
|
+
from .plot import plot
|
|
20
|
+
|
|
21
|
+
# Public API, grouped by the submodule each name is imported from above.
__all__ = [
    # .turbulence
    "Scalar",
    "Vector",
    "ScalarField",
    "VectorField",
    "avg",
    "std",
    "Turbulence",
    # .derivatives.derivative
    "Algorithm",
    "Dx",
    "Dy",
    "Dz",
    "grad",
    "div",
    "curl",
    "laplacian",
    # .spectra
    "Spectrum",
    "Spectrum1D",
    "EnergySpectra",
    # .numdiss
    "NumericalDissipation",
    # .preprocess
    "output2turbulence",
    # .plot
    "plot",
]
|
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
# PyMHD: Python for Magnetohydrodynamic Turbulence.
|
|
2
|
+
# Copyright (c) 2026 Yuyang Hua (华宇阳)
|
|
3
|
+
# License: MIT
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
pymhd/derivatives/TENO.py
|
|
7
|
+
--------------------------
|
|
8
|
+
|
|
9
|
+
Implements the TENO7-M reconstruction scheme with two backends:
|
|
10
|
+
- NumPy (fallback when JAX is unavailable): more compatible, but limited to a single CPU core
|
|
11
|
+
- JAX (preferred when available): leverages multi-core CPU and GPU acceleration
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
from numpy.typing import ArrayLike
|
|
18
|
+
|
|
19
|
+
# Backend selection. ``hasJAX`` is True only when JAX imports AND the device
# mesh used by the TENO7Mx/y/z wrappers can be built; otherwise every function
# in this module falls back to NumPy.
hasJAX = False

try:
    import jax
except ImportError:
    pass  # JAX not installed: NumPy backend
else:
    try:
        jax.config.update("jax_enable_x64", True)

        from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

        try:
            # Only effective before JAX backends are initialised. It may
            # fail when another module has already built its mesh (backends
            # initialised) or on JAX versions without this option; the mesh
            # below still works whenever 8 devices are available.
            jax.config.update("jax_num_cpu_devices", 8)
        except Exception:
            pass

        mesh: Mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=("X", "Y"))

        # One layout per reconstruction direction: the reconstructed axis is
        # left unsplit, the other two are split over the mesh axes.
        sharding = {
            "x": NamedSharding(mesh, P(None, "X", "Y")),
            "y": NamedSharding(mesh, P("X", None, "Y")),
            "z": NamedSharding(mesh, P("X", "Y", None)),
        }

        hasJAX = True
    except Exception:
        # e.g. make_mesh raises ValueError when the default backend does not
        # provide 4 * 2 = 8 devices (a single-GPU machine). This used to crash
        # the import; use the NumPy backend instead.
        hasJAX = False
|
|
40
|
+
|
|
41
|
+
def coreTENO7M(
    u: ArrayLike, direction: str, axis: int, CT: float = 0.01
) -> tuple[Any, Any]:
    """TENO7-M reconstruction.

    Reconstruct the cell interface values from the cell averages along a
    single axis. Neighbouring cells are fetched with ``xp.roll``, so the data
    is treated as periodic along ``axis``. The array backend is JAX when
    ``hasJAX`` is set, NumPy otherwise.

    Parameters
    ----------
    u : ArrayLike, cell-centered or cell-averaged values
    direction : str, reconstruction direction, 'L' (for i - 1/2) or 'R' (for i + 1/2);
        any value other than 'R' is treated like 'L'
    axis : int, the axis direction (0, 1, 2 corresponds to x, y, z)
    CT : float, smoothness threshold, defaults to 0.01

    Returns
    -------
    uHalf : np.ndarray or jax.Array, reconstructed value at the cell interfaces (i ± 1/2)
    shockflag : np.ndarray or jax.Array of bool, True where at least one of the
        four candidate stencils was rejected as non-smooth
    """
    if hasJAX:
        import jax.numpy as jnp
        xp = jnp
    else:
        xp = np
    u = xp.asarray(u)

    # For 'R', roll(arr, k) is the value k cells towards +axis (i + k); for
    # 'L' the offset is mirrored (i - k), so the single set of formulas below
    # gives the i + 1/2 or the i - 1/2 reconstruction respectively.
    def roll(arr, offset: int):
        shift = -offset if direction == "R" else offset
        return xp.roll(arr, shift, axis=axis)

    # Seven-cell stencil (labels for 'R', mirrored for 'L'):
    # uL3 = u[i-3], uL2 = u[i-2], uL1 = u[i-1], u0 = u[i], uR1 = u[i+1], ...
    uR3 = roll(u, 3)
    uR2 = roll(u, 2)
    uR1 = roll(u, 1)
    u0 = u
    uL1 = roll(u, -1)
    uL2 = roll(u, -2)
    uL3 = roll(u, -3)

    # Calculate the smoothness indicators for (7, 4) stencil
    # Ref: Balsara & Shu, Journal of Computational Physics 160, no. 2 (May 2000): 405–52
    # Candidate stencils: 0 -> uL3..u0, 1 -> uL2..uR1, 2 -> uL1..uR2, 3 -> u0..uR3
    beta0 = 1 / 240 * (
        uL3 * ( 547 * uL3 - 3882 * uL2 + 4642 * uL1 - 1854 * u0) + \
        uL2 * ( 7043 * uL2 - 17246 * uL1 + 7042 * u0) + \
        uL1 * (11003 * uL1 - 9402 * u0) + \
        u0 * ( 2107 * u0)
    )
    beta1 = 1 / 240 * (
        uL2 * ( 267 * uL2 - 1642 * uL1 + 1602 * u0 - 494 * uR1) + \
        uL1 * (2843 * uL1 - 5966 * u0 + 1922 * uR1) + \
        u0 * (3443 * u0 - 2522 * uR1) + \
        uR1 * ( 547 * uR1)
    )
    beta2 = 1 / 240 * (
        uL1 * ( 547 * uL1 - 2522 * u0 + 1922 * uR1 - 494 * uR2) + \
        u0 * (3443 * u0 - 5966 * uR1 + 1602 * uR2) + \
        uR1 * (2843 * uR1 - 1642 * uR2) + \
        uR2 * ( 267 * uR2)
    )
    beta3 = 1 / 240 * (
        u0 * ( 2107 * u0 - 9402 * uR1 + 7042 * uR2 - 1854 * uR3) + \
        uR1 * (11003 * uR1 - 17246 * uR2 + 4642 * uR3) + \
        uR2 * ( 7043 * uR2 - 3882 * uR3) + \
        uR3 * ( 547 * uR3)
    )

    epsilon = 1e-40
    q = 6

    gamma0 = 1 / (beta0 + epsilon) ** q
    gamma1 = 1 / (beta1 + epsilon) ** q
    gamma2 = 1 / (beta2 + epsilon) ** q
    gamma3 = 1 / (beta3 + epsilon) ** q

    # chi_k = gamma_k / sum(gamma); True means stencil k's share is below CT
    # at this point, i.e. it looks non-smooth from here.
    chi0mask = (gamma0 / (gamma0 + gamma1 + gamma2 + gamma3)) < CT
    chi1mask = (gamma1 / (gamma0 + gamma1 + gamma2 + gamma3)) < CT
    chi2mask = (gamma2 / (gamma0 + gamma1 + gamma2 + gamma3)) < CT
    chi3mask = (gamma3 / (gamma0 + gamma1 + gamma2 + gamma3)) < CT

    # delta_k False = stencil k rejected. It is rejected only if every
    # candidate that spans the same four cells (candidate j evaluated at the
    # neighbouring point where it covers them, hence the rolls) is below CT too.
    delta0 = ~(chi0mask & roll(chi1mask, -1) & roll(chi2mask, -2) & roll(chi3mask, -3))
    delta1 = ~(chi1mask & roll(chi2mask, -1) & roll(chi3mask, -2) & roll(chi0mask, 1))
    delta2 = ~(chi2mask & roll(chi3mask, -1) & roll(chi1mask, 1) & roll(chi0mask, 2))
    delta3 = ~(chi3mask & roll(chi2mask, 1) & roll(chi1mask, 2) & roll(chi0mask, 3))

    # ----- Dispersion-relation-preserving (DRP) -----
    # Ref: Tam. Computational aeroacoustics: A wave number approach, 2012.
    # a7k / a9k: 7- and 9-point DRP coefficients (per the reference above).
    a71 = 0.77088238051822552
    a72 = -0.166705904414580469
    a73 = 0.02084314277031176

    a91 = 0.8301178834769906875382633360472085
    a92 = -0.23175338776901819008451262109655756
    a93 = 0.05287205020483696423592156502901203
    a94 = -6.306814638366300019250697235282424e-3

    # Seven interface weights; they sum to 2*a91 + 4*a92 + 6*a93 + 8*a94 = 1
    # (up to rounding), so constants are reproduced exactly.
    cL3 = 2 * a94
    cL2 = 2 * a93 - a73
    cL1 = 2 * a92 + 2 * a94 - a72 + a73
    c0 = 2 * a91 + 2 * a93 - a71 + a72 - a73
    cR1 = 2 * a92 + 2 * a94 - a72 + a73 + a71
    cR2 = 2 * a93 - a73 + a72
    cR3 = 2 * a94 + a73

    # Optimal reconstruction polynomial for smooth stencils
    uHalf = cL3 * uL3 + cL2 * uL2 + cL1 * uL1 + c0 * u0 + cR1 * uR1 + cR2 * uR2 + cR3 * uR3

    # Where stencils are rejected, override uHalf with a reconstruction that
    # only uses cells on the smooth side. Each mask below is a distinct
    # (delta0, delta1, delta2, delta3) combination, so the overrides are
    # mutually exclusive; any unlisted combination keeps the DRP value.

    # a single discontinuity lies in (uL3, uL2) -> cells uL2..uR3
    uHalf = xp.where(
        (~delta0) & delta1 & delta2 & delta3,
        1 / 60 * uL2 - 2 / 15 * uL1 + 37 / 60 * u0 + 37 / 60 * uR1 - 2 / 15 * uR2 + 1 / 60 * uR3,
        uHalf,
    )

    # a single discontinuity lies in (uL2, uL1) -> cells uL1..uR3
    uHalf = xp.where(
        (~delta0) & (~delta1) & delta2 & delta3,
        -1 / 20 * uL1 + 9 / 20 * u0 + 47 / 60 * uR1 - 13 / 60 * uR2 + 1 / 30 * uR3,
        uHalf,
    )

    # a single discontinuity lies in (uL1, u0) -> cells u0..uR3
    uHalf = xp.where(
        (~delta0) & (~delta1) & (~delta2) & delta3,
        1 / 4 * u0 + 13 / 12 * uR1 - 5 / 12 * uR2 + 1 / 12 * uR3,
        uHalf,
    )

    # a single discontinuity lies in (u0, uR1) -> cells uL3..u0
    uHalf = xp.where(
        delta0 & (~delta1) & (~delta2) & (~delta3),
        -1 / 4 * uL3 + 13 / 12 * uL2 - 23 / 12 * uL1 + 25 / 12 * u0,
        uHalf,
    )

    # a single discontinuity lies in (uR1, uR2) -> cells uL3..uR1
    uHalf = xp.where(
        delta0 & delta1 & (~delta2) & (~delta3),
        -1 / 20 * uL3 + 17 / 60 * uL2 - 43 / 60 * uL1 + 77 / 60 * u0 + 1 / 5 * uR1,
        uHalf,
    )

    # a single discontinuity lies in (uR2, uR3) -> cells uL3..uR2
    uHalf = xp.where(
        delta0 & delta1 & delta2 & (~delta3),
        -1 / 60 * uL3 + 7 / 60 * uL2 - 23 / 60 * uL1 + 19 / 20 * u0 + 11 / 30 * uR1 - 1 / 30 * uR2,
        uHalf,
    )

    # two discontinuities lie in (uL3, uL2) and (uR2, uR3) -> cells uL2..uR2
    uHalf = xp.where(
        (~delta0) & delta1 & delta2 & (~delta3),
        2 / 60 * uL2 - 13 / 60 * uL1 + 47 / 60 * u0 + 27 / 60 * uR1 - 3 / 60 * uR2,
        uHalf,
    )

    # Any stencil other than fully smooth is considered discontinuous.
    shockflag = ~(delta0 & delta1 & delta2 & delta3)

    return uHalf, shockflag
|
|
199
|
+
|
|
200
|
+
if hasJAX:
    import jax

    # `direction` and `axis` steer trace-time Python logic (roll sign and roll
    # axis), so they must be static; CT is static as well, which means every
    # distinct CT value compiles its own kernel.
    coreTENO7M = jax.jit(
        coreTENO7M,
        static_argnames=["direction", "axis", "CT"],
    )
|
|
203
|
+
|
|
204
|
+
def TENO7M(
    u: ArrayLike, axis: int, mode: str = "hybrid", CT: float = 0.01
) -> tuple[np.ndarray, np.ndarray]:
    """TENO7-M reconstruction of both interface states of every cell.

    Parameters
    ----------
    u : ArrayLike, cell-centered or cell-averaged values
    axis : int, reconstruction axis (0, 1, 2 corresponds to x, y, z)
    mode : str, how the two states of cell i are assembled:
        'inner'  -> the i - 1/2 ('L') and i + 1/2 ('R') reconstructions of cell i
        'upwind' -> both taken from 'R' reconstructions: the state at i - 1/2
                    is cell i-1's 'R' value, the state at i + 1/2 is cell i's
        'hybrid' -> where coreTENO7M flags cell i smooth, each 'inner' state is
                    averaged with the neighbouring cell's reconstruction of the
                    same interface; the 'inner' state is kept elsewhere
    CT : float, smoothness threshold forwarded to coreTENO7M, defaults to 0.01

    Returns
    -------
    (uL, uR) : tuple of np.ndarray, states at i - 1/2 and i + 1/2

    Raises
    ------
    ValueError
        If ``mode`` is not 'inner', 'upwind' or 'hybrid'.
    """
    # Validate first: an invalid mode used to be rejected only after two full
    # TENO passes (and, with JAX, two jit compilations).
    if mode not in ("inner", "upwind", "hybrid"):
        raise ValueError(f"Invalid mode: {mode}. Options: 'inner', 'upwind', and 'hybrid'.")

    if hasJAX:
        import jax.numpy as jnp
        xp: Any = jnp
    else:
        xp = np
    u = xp.asarray(u)

    # z is reconstructed along a leading axis and moved back afterwards.
    moved = axis == 2
    if moved:
        u = xp.moveaxis(u, 2, 0)
    ax = 0 if moved else axis

    def restore(arr):
        """Undo the axis move (if any) and convert to a NumPy array."""
        return np.asarray(xp.moveaxis(arr, 0, 2) if moved else arr)

    uL, shockFlagL = coreTENO7M(u, direction="L", axis=ax, CT=CT)
    uR, shockFlagR = coreTENO7M(u, direction="R", axis=ax, CT=CT)

    if mode == "inner":
        return restore(uL), restore(uR)

    if mode == "upwind":
        uLupwind = xp.roll(uR, 1, axis=ax)  # cell i-1's value at i - 1/2
        return restore(uLupwind), restore(uR)

    # mode == "hybrid"
    uLhybrid = xp.where(~shockFlagL, (uL + xp.roll(uR, 1, axis=ax)) / 2, uL)
    uRhybrid = xp.where(~shockFlagR, (uR + xp.roll(uL, -1, axis=ax)) / 2, uR)
    return restore(uLhybrid), restore(uRhybrid)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def TENO7Mx(
    u: ArrayLike, mode: str = "hybrid", CT: float = 0.01
) -> tuple[np.ndarray, np.ndarray]:
    """TENO7-M reconstruction in x direction.

    With JAX, ``u`` is first placed on the device mesh using ``sharding["x"]``
    (x unsplit, y/z split across devices). If JAX rejects that placement with
    a ValueError (e.g. rank below 3, or y/z extents not divisible by the
    mesh), the array is reconstructed unsharded, matching the NumPy backend,
    which accepts any shape.
    """
    if hasJAX:
        import jax
        try:
            u = jax.device_put(u, sharding["x"])
        except ValueError:
            pass  # shape incompatible with the mesh: run unsharded
    return TENO7M(u, axis=0, mode=mode, CT=CT)
|
|
261
|
+
|
|
262
|
+
def TENO7My(
    u: ArrayLike, mode: str = "hybrid", CT: float = 0.01
) -> tuple[np.ndarray, np.ndarray]:
    """TENO7-M reconstruction in y direction.

    With JAX, ``u`` is first placed on the device mesh using ``sharding["y"]``
    (y unsplit, x/z split across devices). If JAX rejects that placement with
    a ValueError (e.g. rank below 3, or x/z extents not divisible by the
    mesh), the array is reconstructed unsharded, matching the NumPy backend,
    which accepts any shape.
    """
    if hasJAX:
        import jax
        try:
            u = jax.device_put(u, sharding["y"])
        except ValueError:
            pass  # shape incompatible with the mesh: run unsharded
    return TENO7M(u, axis=1, mode=mode, CT=CT)
|
|
270
|
+
|
|
271
|
+
def TENO7Mz(
    u: ArrayLike, mode: str = "hybrid", CT: float = 0.01
) -> tuple[np.ndarray, np.ndarray]:
    """TENO7-M reconstruction in z direction.

    With JAX, ``u`` is first placed on the device mesh using ``sharding["z"]``
    (z unsplit, x/y split across devices). If JAX rejects that placement with
    a ValueError (e.g. rank below 3, or x/y extents not divisible by the
    mesh), the array is reconstructed unsharded, matching the NumPy backend,
    which accepts any shape.
    """
    if hasJAX:
        import jax
        try:
            u = jax.device_put(u, sharding["z"])
        except ValueError:
            pass  # shape incompatible with the mesh: run unsharded
    return TENO7M(u, axis=2, mode=mode, CT=CT)
|
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
# PyMHD: Python for Magnetohydrodynamic Turbulence.
|
|
2
|
+
# Copyright (c) 2026 Yuyang Hua (华宇阳)
|
|
3
|
+
# License: MIT
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
pymhd/derivatives/WENO.py
|
|
7
|
+
--------------------------
|
|
8
|
+
|
|
9
|
+
Implements the WENO5-Z and WENO7-Z reconstruction schemes with two backends:
|
|
10
|
+
- NumPy (fallback when JAX is unavailable): more compatible, but limited to a single CPU core
|
|
11
|
+
- JAX (preferred when available): leverages multi-core CPU and GPU acceleration
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
from numpy.typing import ArrayLike
|
|
18
|
+
|
|
19
|
+
# Backend selection. ``hasJAX`` is True only when JAX imports AND the device
# mesh used by the WENO*x/y/z wrappers can be built; otherwise every function
# in this module falls back to NumPy.
hasJAX = False

try:
    import jax
except ImportError:
    pass  # JAX not installed: NumPy backend
else:
    try:
        jax.config.update("jax_enable_x64", True)

        from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

        try:
            # Only effective before JAX backends are initialised. It may
            # fail when another module has already built its mesh (backends
            # initialised) or on JAX versions without this option; the mesh
            # below still works whenever 8 devices are available.
            jax.config.update("jax_num_cpu_devices", 8)
        except Exception:
            pass

        mesh: Mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=("X", "Y"))

        # One layout per reconstruction direction: the reconstructed axis is
        # left unsplit, the other two are split over the mesh axes.
        sharding = {
            "x": NamedSharding(mesh, P(None, "X", "Y")),
            "y": NamedSharding(mesh, P("X", None, "Y")),
            "z": NamedSharding(mesh, P("X", "Y", None)),
        }

        hasJAX = True
    except Exception:
        # e.g. make_mesh raises ValueError when the default backend does not
        # provide 4 * 2 = 8 devices (a single-GPU machine). This used to crash
        # the import; use the NumPy backend instead.
        hasJAX = False
|
|
40
|
+
|
|
41
|
+
# ===== WENO5 =====
|
|
42
|
+
|
|
43
|
+
def coreWENO5(u: ArrayLike, axis: int) -> tuple[Any, Any]:
    """WENO5-Z reconstruction along a single axis.

    Neighbouring cells are fetched with ``xp.roll``, so the data is treated as
    periodic along ``axis``. The array backend is JAX when ``hasJAX`` is set,
    NumPy otherwise.

    Returns
    -------
    (uL, uR) : uL is the value at i - 1/2 reconstructed from cell i's side
        (weights mirrored towards i + 1..i + 2), uR is the value at i + 1/2
        reconstructed from cell i's side.
    """
    if hasJAX:
        import jax.numpy as jnp
        xp = jnp
    else:
        xp = np
    u = xp.asarray(u)

    # Five-cell stencil: uL2 = u[i-2], uL1 = u[i-1], u0 = u[i], uR1 = u[i+1], uR2 = u[i+2]
    uR2 = xp.roll(u, -2, axis=axis)
    uR1 = xp.roll(u, -1, axis=axis)
    u0 = u
    uL1 = xp.roll(u, 1, axis=axis)
    uL2 = xp.roll(u, 2, axis=axis)

    # Smoothness indicators of the 3-cell stencils uL2..u0, uL1..uR1, u0..uR2
    beta1 = 13/12 * (uL2 - 2 * uL1 + u0 )**2 + 1/4 * (uL2 - 4 * uL1 + 3 * u0)**2
    beta2 = 13/12 * (uL1 - 2 * u0 + uR1)**2 + 1/4 * (uL1 - uR1)**2
    beta3 = 13/12 * (u0 - 2 * uR1 + uR2)**2 + 1/4 * (3 * u0 - 4 * uR1 + uR2)**2

    # Z-type global indicator from the two outer stencils
    tau5 = xp.abs(beta1 - beta3)
    epsilon = 1e-6
    p = 4

    # ---- value at i + 1/2 (uR) ----
    # Linear (optimal) weights of the three stencils for i + 1/2
    gamma1, gamma2, gamma3 = 1/10, 3/5, 3/10

    # Per-stencil coefficients multiplying uL2, uL1, u0, uR1, uR2 (in that order)
    bL1, aL1, c1, aR1, bR1 = 1/3, -7/6, 11/6, 0, 0
    bL2, aL2, c2, aR2, bR2 = 0, -1/6, 5/6, 1/3, 0
    bL3, aL3, c3, aR3, bR3 = 0, 0, 1/3, 5/6, -1/6

    weights1 = gamma1 * (1 + (tau5/(beta1 + epsilon)) ** p)
    weights2 = gamma2 * (1 + (tau5/(beta2 + epsilon)) ** p)
    weights3 = gamma3 * (1 + (tau5/(beta3 + epsilon)) ** p)

    w1 = weights1 / (weights1 + weights2 + weights3)
    w2 = weights2 / (weights1 + weights2 + weights3)
    w3 = weights3 / (weights1 + weights2 + weights3)

    # Blend the stencil coefficients first, then apply them once to the data
    bL = w1 * bL1 + w2 * bL2 + w3 * bL3
    aL = w1 * aL1 + w2 * aL2 + w3 * aL3
    c = w1 * c1 + w2 * c2 + w3 * c3
    aR = w1 * aR1 + w2 * aR2 + w3 * aR3
    bR = w1 * bR1 + w2 * bR2 + w3 * bR3

    uR = bL * uL2 + aL * uL1 + c * u0 + aR * uR1 + bR * uR2

    # ---- value at i - 1/2 (uL): mirrored linear weights and coefficients ----
    gamma1, gamma2, gamma3 = 3/10, 3/5, 1/10

    bL1, aL1, c1, aR1, bR1 = -1/6, 5/6, 1/3, 0, 0
    bL2, aL2, c2, aR2, bR2 = 0, 1/3, 5/6, -1/6, 0
    bL3, aL3, c3, aR3, bR3 = 0, 0, 11/6, -7/6, 1/3

    weights1 = gamma1 * (1 + (tau5/(beta1 + epsilon)) ** p)
    weights2 = gamma2 * (1 + (tau5/(beta2 + epsilon)) ** p)
    weights3 = gamma3 * (1 + (tau5/(beta3 + epsilon)) ** p)

    w1 = weights1 / (weights1 + weights2 + weights3)
    w2 = weights2 / (weights1 + weights2 + weights3)
    w3 = weights3 / (weights1 + weights2 + weights3)

    bL = w1 * bL1 + w2 * bL2 + w3 * bL3
    aL = w1 * aL1 + w2 * aL2 + w3 * aL3
    c = w1 * c1 + w2 * c2 + w3 * c3
    aR = w1 * aR1 + w2 * aR2 + w3 * aR3
    bR = w1 * bR1 + w2 * bR2 + w3 * bR3

    uL = bL * uL2 + aL * uL1 + c * u0 + aR * uR1 + bR * uR2

    return uL, uR
|
|
111
|
+
|
|
112
|
+
if hasJAX:
    import jax

    # `axis` selects the roll axis at trace time, so it must be static.
    coreWENO5 = jax.jit(
        coreWENO5,
        static_argnames=["axis"],
    )
|
|
115
|
+
|
|
116
|
+
def WENO5(u: ArrayLike, axis: int) -> tuple[np.ndarray, np.ndarray]:
    """WENO5-Z reconstruction on a 3D array along the given axis.

    Returns the (i - 1/2, i + 1/2) interface values from coreWENO5 as NumPy
    arrays. For axis 2 the data is reconstructed along a leading axis and the
    results are moved back into place.
    """
    xp: Any = np
    if hasJAX:
        import jax.numpy as jnp
        xp = jnp
    arr = xp.asarray(u)

    if axis != 2:
        left, right = coreWENO5(arr, axis=axis)
        return np.asarray(left), np.asarray(right)

    left, right = coreWENO5(xp.moveaxis(arr, 2, 0), axis=0)
    return (
        np.asarray(xp.moveaxis(left, 0, 2)),
        np.asarray(xp.moveaxis(right, 0, 2)),
    )
|
|
132
|
+
|
|
133
|
+
def WENO5x(u: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
    """WENO5-Z reconstruction in x direction.

    With JAX, ``u`` is first placed on the device mesh using ``sharding["x"]``
    (x unsplit, y/z split across devices). If JAX rejects that placement with
    a ValueError (e.g. rank below 3, or y/z extents not divisible by the
    mesh), the array is reconstructed unsharded, matching the NumPy backend,
    which accepts any shape.
    """
    if hasJAX:
        import jax
        try:
            u = jax.device_put(u, sharding["x"])
        except ValueError:
            pass  # shape incompatible with the mesh: run unsharded
    return WENO5(u, axis=0)
|
|
139
|
+
|
|
140
|
+
def WENO5y(u: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
    """WENO5-Z reconstruction in y direction.

    With JAX, ``u`` is first placed on the device mesh using ``sharding["y"]``
    (y unsplit, x/z split across devices). If JAX rejects that placement with
    a ValueError (e.g. rank below 3, or x/z extents not divisible by the
    mesh), the array is reconstructed unsharded, matching the NumPy backend,
    which accepts any shape.
    """
    if hasJAX:
        import jax
        try:
            u = jax.device_put(u, sharding["y"])
        except ValueError:
            pass  # shape incompatible with the mesh: run unsharded
    return WENO5(u, axis=1)
|
|
146
|
+
|
|
147
|
+
def WENO5z(u: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
    """WENO5-Z reconstruction in z direction.

    With JAX, ``u`` is first placed on the device mesh using ``sharding["z"]``
    (z unsplit, x/y split across devices). If JAX rejects that placement with
    a ValueError (e.g. rank below 3, or x/y extents not divisible by the
    mesh), the array is reconstructed unsharded, matching the NumPy backend,
    which accepts any shape.
    """
    if hasJAX:
        import jax
        try:
            u = jax.device_put(u, sharding["z"])
        except ValueError:
            pass  # shape incompatible with the mesh: run unsharded
    return WENO5(u, axis=2)
|
|
153
|
+
|
|
154
|
+
# ===== WENO7 =====
|
|
155
|
+
|
|
156
|
+
def coreWENO7(u: ArrayLike, axis: int) -> tuple[Any, Any]:
    """WENO7-Z reconstruction along a single axis.

    Neighbouring cells are fetched with ``xp.roll``, so the data is treated as
    periodic along ``axis``. The array backend is JAX when ``hasJAX`` is set,
    NumPy otherwise.

    Returns
    -------
    (uL, uR) : uL is the value at i - 1/2 and uR the value at i + 1/2, both
        reconstructed from cell i's side.
    """
    if hasJAX:
        import jax.numpy as jnp
        xp = jnp
    else:
        xp = np
    u = xp.asarray(u)

    # Seven-cell stencil: uL3 = u[i-3], ..., u0 = u[i], ..., uR3 = u[i+3]
    uR3 = xp.roll(u, -3, axis=axis)
    uR2 = xp.roll(u, -2, axis=axis)
    uR1 = xp.roll(u, -1, axis=axis)
    u0 = u
    uL1 = xp.roll(u, 1, axis=axis)
    uL2 = xp.roll(u, 2, axis=axis)
    uL3 = xp.roll(u, 3, axis=axis)

    # Smoothness indicators of the 4-cell stencils
    # 1 -> uL3..u0, 2 -> uL2..uR1, 3 -> uL1..uR2, 4 -> u0..uR3
    # (same coefficients as the Balsara & Shu indicators used in TENO.py)
    beta1 = 1 / 240 * (
        uL3 * ( 547 * uL3 - 3882 * uL2 + 4642 * uL1 - 1854 * u0) + \
        uL2 * ( 7043 * uL2 - 17246 * uL1 + 7042 * u0) + \
        uL1 * (11003 * uL1 - 9402 * u0) + \
        u0 * ( 2107 * u0)
    )
    beta2 = 1 / 240 * (
        uL2 * ( 267 * uL2 - 1642 * uL1 + 1602 * u0 - 494 * uR1) + \
        uL1 * (2843 * uL1 - 5966 * u0 + 1922 * uR1) + \
        u0 * (3443 * u0 - 2522 * uR1) + \
        uR1 * ( 547 * uR1)
    )
    beta3 = 1 / 240 * (
        uL1 * ( 547 * uL1 - 2522 * u0 + 1922 * uR1 - 494 * uR2) + \
        u0 * (3443 * u0 - 5966 * uR1 + 1602 * uR2) + \
        uR1 * (2843 * uR1 - 1642 * uR2) + \
        uR2 * ( 267 * uR2)
    )
    beta4 = 1 / 240 * (
        u0 * ( 2107 * u0 - 9402 * uR1 + 7042 * uR2 - 1854 * uR3) + \
        uR1 * (11003 * uR1 - 17246 * uR2 + 4642 * uR3) + \
        uR2 * ( 7043 * uR2 - 3882 * uR3) + \
        uR3 * ( 547 * uR3)
    )

    # Z-type global indicator from the two outer stencils
    tau = xp.abs(beta1 - beta4)
    epsilon = 1e-6
    p = 4

    # ---- value at i + 1/2 (uR) ----
    # Linear (optimal) weights of the four stencils for i + 1/2
    gamma1, gamma2, gamma3, gamma4 = 1/35, 12/35, 18/35, 4/35

    # NOTE: despite their names, the seven coefficients per stencil multiply
    # uL3, uL2, uL1, u0, uR1, uR2, uR3 in that order (see the uR line below).
    bL1, aL1, c1, aR1, bR1, cR1, dR1 = -1/4, 13/12, -23/12, 25/12, 0, 0, 0
    bL2, aL2, c2, aR2, bR2, cR2, dR2 = 0, 1/12, -5/12, 13/12, 1/4, 0, 0
    bL3, aL3, c3, aR3, bR3, cR3, dR3 = 0, 0, -1/12, 7/12, 7/12, -1/12, 0
    bL4, aL4, c4, aR4, bR4, cR4, dR4 = 0, 0, 0, 1/4, 13/12, -5/12, 1/12

    weights1 = gamma1 * (1 + (tau/(beta1 + epsilon)) ** p)
    weights2 = gamma2 * (1 + (tau/(beta2 + epsilon)) ** p)
    weights3 = gamma3 * (1 + (tau/(beta3 + epsilon)) ** p)
    weights4 = gamma4 * (1 + (tau/(beta4 + epsilon)) ** p)

    w1 = weights1 / (weights1 + weights2 + weights3 + weights4)
    w2 = weights2 / (weights1 + weights2 + weights3 + weights4)
    w3 = weights3 / (weights1 + weights2 + weights3 + weights4)
    w4 = weights4 / (weights1 + weights2 + weights3 + weights4)

    # Blend the stencil coefficients first, then apply them once to the data
    bL = w1 * bL1 + w2 * bL2 + w3 * bL3 + w4 * bL4
    aL = w1 * aL1 + w2 * aL2 + w3 * aL3 + w4 * aL4
    c = w1 * c1 + w2 * c2 + w3 * c3 + w4 * c4
    aR = w1 * aR1 + w2 * aR2 + w3 * aR3 + w4 * aR4
    bR = w1 * bR1 + w2 * bR2 + w3 * bR3 + w4 * bR4
    cR = w1 * cR1 + w2 * cR2 + w3 * cR3 + w4 * cR4
    dR = w1 * dR1 + w2 * dR2 + w3 * dR3 + w4 * dR4
    uR = bL * uL3 + aL * uL2 + c * uL1 + aR * u0 + bR * uR1 + cR * uR2 + dR * uR3

    # ---- value at i - 1/2 (uL): mirrored linear weights and coefficients ----
    gamma1, gamma2, gamma3, gamma4 = 4/35, 18/35, 12/35, 1/35

    bL1, aL1, c1, aR1, bR1, cR1, dR1 = 1/12, -5/12, 13/12, 1/4, 0, 0, 0
    bL2, aL2, c2, aR2, bR2, cR2, dR2 = 0, -1/12, 7/12, 7/12, -1/12, 0, 0
    bL3, aL3, c3, aR3, bR3, cR3, dR3 = 0, 0, 1/4, 13/12, -5/12, 1/12, 0
    bL4, aL4, c4, aR4, bR4, cR4, dR4 = 0, 0, 0, 25/12, -23/12, 13/12, -1/4

    weights1 = gamma1 * (1 + (tau/(beta1 + epsilon)) ** p)
    weights2 = gamma2 * (1 + (tau/(beta2 + epsilon)) ** p)
    weights3 = gamma3 * (1 + (tau/(beta3 + epsilon)) ** p)
    weights4 = gamma4 * (1 + (tau/(beta4 + epsilon)) ** p)

    w1 = weights1 / (weights1 + weights2 + weights3 + weights4)
    w2 = weights2 / (weights1 + weights2 + weights3 + weights4)
    w3 = weights3 / (weights1 + weights2 + weights3 + weights4)
    w4 = weights4 / (weights1 + weights2 + weights3 + weights4)

    bL = w1 * bL1 + w2 * bL2 + w3 * bL3 + w4 * bL4
    aL = w1 * aL1 + w2 * aL2 + w3 * aL3 + w4 * aL4
    c = w1 * c1 + w2 * c2 + w3 * c3 + w4 * c4
    aR = w1 * aR1 + w2 * aR2 + w3 * aR3 + w4 * aR4
    bR = w1 * bR1 + w2 * bR2 + w3 * bR3 + w4 * bR4
    cR = w1 * cR1 + w2 * cR2 + w3 * cR3 + w4 * cR4
    dR = w1 * dR1 + w2 * dR2 + w3 * dR3 + w4 * dR4
    uL = bL * uL3 + aL * uL2 + c * uL1 + aR * u0 + bR * uR1 + cR * uR2 + dR * uR3

    return uL, uR
|
|
255
|
+
|
|
256
|
+
if hasJAX:
    import jax

    # `axis` selects the roll axis at trace time, so it must be static.
    coreWENO7 = jax.jit(
        coreWENO7,
        static_argnames=["axis"],
    )
|
|
259
|
+
|
|
260
|
+
def WENO7(u: ArrayLike, axis: int) -> tuple[np.ndarray, np.ndarray]:
    """WENO7-Z reconstruction on a 3D array along the given axis.

    Returns the (i - 1/2, i + 1/2) interface values from coreWENO7 as NumPy
    arrays. For axis 2 the data is reconstructed along a leading axis and the
    results are moved back into place.
    """
    xp: Any = np
    if hasJAX:
        import jax.numpy as jnp
        xp = jnp
    arr = xp.asarray(u)

    if axis != 2:
        left, right = coreWENO7(arr, axis=axis)
        return np.asarray(left), np.asarray(right)

    left, right = coreWENO7(xp.moveaxis(arr, 2, 0), axis=0)
    return (
        np.asarray(xp.moveaxis(left, 0, 2)),
        np.asarray(xp.moveaxis(right, 0, 2)),
    )
|
|
276
|
+
|
|
277
|
+
def WENO7x(u: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
    """WENO7-Z reconstruction in x direction.

    With JAX, ``u`` is first placed on the device mesh using ``sharding["x"]``
    (x unsplit, y/z split across devices). If JAX rejects that placement with
    a ValueError (e.g. rank below 3, or y/z extents not divisible by the
    mesh), the array is reconstructed unsharded, matching the NumPy backend,
    which accepts any shape.
    """
    if hasJAX:
        import jax
        try:
            u = jax.device_put(u, sharding["x"])
        except ValueError:
            pass  # shape incompatible with the mesh: run unsharded
    return WENO7(u, axis=0)
|
|
283
|
+
|
|
284
|
+
def WENO7y(u: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
    """WENO7-Z reconstruction in y direction.

    With JAX, ``u`` is first placed on the device mesh using ``sharding["y"]``
    (y unsplit, x/z split across devices). If JAX rejects that placement with
    a ValueError (e.g. rank below 3, or x/z extents not divisible by the
    mesh), the array is reconstructed unsharded, matching the NumPy backend,
    which accepts any shape.
    """
    if hasJAX:
        import jax
        try:
            u = jax.device_put(u, sharding["y"])
        except ValueError:
            pass  # shape incompatible with the mesh: run unsharded
    return WENO7(u, axis=1)
|
|
290
|
+
|
|
291
|
+
def WENO7z(u: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
    """WENO7-Z reconstruction in z direction.

    With JAX, ``u`` is first placed on the device mesh using ``sharding["z"]``
    (z unsplit, x/y split across devices). If JAX rejects that placement with
    a ValueError (e.g. rank below 3, or x/y extents not divisible by the
    mesh), the array is reconstructed unsharded, matching the NumPy backend,
    which accepts any shape.
    """
    if hasJAX:
        import jax
        try:
            u = jax.device_put(u, sharding["z"])
        except ValueError:
            pass  # shape incompatible with the mesh: run unsharded
    return WENO7(u, axis=2)
|
|
297
|
+
|
|
298
|
+
def WENOx(u: ArrayLike, stencils: int = 7) -> tuple[np.ndarray, np.ndarray]:
    """Dispatch the x-direction WENO-Z reconstruction by stencil width.

    ``stencils`` selects WENO5-Z (5) or WENO7-Z (7, the default); any other
    value raises ValueError.
    """
    if stencils == 5:
        scheme = WENO5x
    elif stencils == 7:
        scheme = WENO7x
    else:
        raise ValueError(f"Invalid stencils: {stencils}. Supported stencils: 5, 7.")
    return scheme(u)
|
|
306
|
+
|
|
307
|
+
def WENOy(u: ArrayLike, stencils: int = 7) -> tuple[np.ndarray, np.ndarray]:
    """Dispatch the y-direction WENO-Z reconstruction by stencil width.

    ``stencils`` selects WENO5-Z (5) or WENO7-Z (7, the default); any other
    value raises ValueError.
    """
    if stencils == 5:
        scheme = WENO5y
    elif stencils == 7:
        scheme = WENO7y
    else:
        raise ValueError(f"Invalid stencils: {stencils}. Supported stencils: 5, 7.")
    return scheme(u)
|
|
315
|
+
|
|
316
|
+
def WENOz(u: ArrayLike, stencils: int = 7) -> tuple[np.ndarray, np.ndarray]:
    """Dispatch the z-direction WENO-Z reconstruction by stencil width.

    ``stencils`` selects WENO5-Z (5) or WENO7-Z (7, the default); any other
    value raises ValueError.
    """
    if stencils == 5:
        scheme = WENO5z
    elif stencils == 7:
        scheme = WENO7z
    else:
        raise ValueError(f"Invalid stencils: {stencils}. Supported stencils: 5, 7.")
    return scheme(u)
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# PyMHD: Python for Magnetohydrodynamic Turbulence.
|
|
2
|
+
# Copyright (c) 2026 Yuyang Hua (华宇阳)
|
|
3
|
+
# License: MIT
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
pymhd/derivatives/__init__.py
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from .derivative import (
|
|
10
|
+
Field, Algorithm,
|
|
11
|
+
Dx, Dy, Dz,
|
|
12
|
+
grad, div, curl, laplacian,
|
|
13
|
+
average2center,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
from . import derivative
|
|
17
|
+
|
|
18
|
+
# Public API: the ``derivative`` submodule itself plus the names re-exported
# from it above.
__all__ = [
    "derivative",
    "Field",
    "Algorithm",
    "Dx",
    "Dy",
    "Dz",
    "grad",
    "div",
    "curl",
    "laplacian",
    "average2center",
]
|