PyMHD 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,365 @@
1
+ # PyMHD: Python for Magnetohydrodynamic Turbulence.
2
+ # Copyright (c) 2026 Yuyang Hua (华宇阳)
3
+ # License: MIT
4
+
5
+ """
6
+ pymhd/derivatives/compact.py
7
+ ----------------------------
8
+
9
+ Implements the TCS-M scheme:
10
+ - TCS-M: Targeted Compact Scheme with Multi-stencil Discontinuity Detectors (MSDD)
11
+ - Ref: Lele 1992 JCP, Jiang 2000 IJCFD
12
+ - Currently only supports JAX implementation
13
+ """
14
+
15
+ import jax
16
+ jax.config.update("jax_enable_x64", True)
17
+
18
+ import jax.numpy as jnp
19
+ from jax.scipy.sparse.linalg import bicgstab
20
+
21
+ from jax import Array
22
+ from jax.typing import ArrayLike
23
+
24
+ from functools import partial
25
+
26
def MSDD7(
    u: ArrayLike, axis: int, CT: float = 0.01
) -> list[Array]:
    """Classify each grid point by where a discontinuity lies in its 7-point window.

    For point ``i`` along ``axis`` four candidate 4-point stencils are used::

        S0 = {i-3..i}   S1 = {i-2..i+1}   S2 = {i-1..i+2}   S3 = {i..i+3}

    Each stencil gets the Balsara & Shu (2000) smoothness indicator ``beta``
    and a normalized weight ``chi = g / sum(g)`` with
    ``g = 1 / (beta + 1e-6)**6``.  The same four cells are a candidate stencil
    of four different points.  A stencil is declared non-smooth only if its
    ``chi`` is below ``CT`` at all four of those points.  The smooth/non-smooth
    pattern of ``S0..S3`` is then matched against seven configurations.

    Shifts use ``jnp.roll``, so the data are treated as periodic along ``axis``.

    Parameters
    ----------
    u : ArrayLike
        Field values (a 3D array in the intended use); only ``axis`` is
        differenced.
    axis : int
        Direction of the stencils (0, 1, 2 correspond to x, y, z).
    CT : float, default 0.01
        Threshold below which a normalized weight ``chi`` counts as non-smooth.

    Returns
    -------
    list[Array]
        Seven boolean masks with the shape of ``u``:

        0. all four stencils smooth;
        1. single discontinuity between i-3 and i-2;
        2. single discontinuity between i-2 and i-1;
        3. single discontinuity between i-1 and i;
        4. single discontinuity between i and i+1;
        5. single discontinuity between i+1 and i+2;
        6. single discontinuity between i+2 and i+3.

        The masks are mutually exclusive.  A point whose pattern matches none
        of them (e.g. all four stencils flagged) is False in every mask.

    References
    ----------
    [1] Balsara & Shu, JCP 160, no. 2 (May 2000): 405-52
    """
    u = jnp.asarray(u)

    def shifted(arr, offset):
        # shifted(arr, k)[i] == arr[i + k], wrapping periodically along `axis`
        return jnp.roll(arr, -offset, axis=axis)

    # m3 = u[i-3], ..., c = u[i], ..., p3 = u[i+3]
    m3, m2, m1, c, p1, p2, p3 = [shifted(u, k) for k in range(-3, 4)]

    # Balsara & Shu (2000) smoothness indicators of S0..S3
    beta_s0 = 1/240 * (
        m3 * (547 * m3 - 3882 * m2 + 4642 * m1 - 1854 * c)
        + m2 * (7043 * m2 - 17246 * m1 + 7042 * c)
        + m1 * (11003 * m1 - 9402 * c)
        + c * (2107 * c)
    )
    beta_s1 = 1/240 * (
        m2 * (267 * m2 - 1642 * m1 + 1602 * c - 494 * p1)
        + m1 * (2843 * m1 - 5966 * c + 1922 * p1)
        + c * (3443 * c - 2522 * p1)
        + p1 * (547 * p1)
    )
    beta_s2 = 1/240 * (
        m1 * (547 * m1 - 2522 * c + 1922 * p1 - 494 * p2)
        + c * (3443 * c - 5966 * p1 + 1602 * p2)
        + p1 * (2843 * p1 - 1642 * p2)
        + p2 * (267 * p2)
    )
    beta_s3 = 1/240 * (
        c * (2107 * c - 9402 * p1 + 7042 * p2 - 1854 * p3)
        + p1 * (11003 * p1 - 17246 * p2 + 4642 * p3)
        + p2 * (7043 * p2 - 3882 * p3)
        + p3 * (547 * p3)
    )

    eps = 1e-6
    q = 6
    gammas = [1 / (b + eps) ** q for b in (beta_s0, beta_s1, beta_s2, beta_s3)]
    total = gammas[0] + gammas[1] + gammas[2] + gammas[3]
    # below[k][i]: normalized weight of stencil k at point i is under CT
    below = [(g / total) < CT for g in gammas]

    # Stencil k of point i covers cells i-3+k .. i+k.  The same cells form
    # stencil j of point i + (k - j), so that point's weight for stencil j must
    # also be below CT before stencil k is flagged as non-smooth.
    rough = []
    for k in range(4):
        flagged = below[k]
        for j in range(4):
            if j != k:
                flagged = flagged & shifted(below[j], k - j)
        rough.append(flagged)
    smooth = [~r for r in rough]

    # Required smoothness of (S0, S1, S2, S3) for each configuration.
    patterns = (
        (True, True, True, True),      # 0: no discontinuity
        (False, True, True, True),     # 1: between i-3 and i-2
        (False, False, True, True),    # 2: between i-2 and i-1
        (False, False, False, True),   # 3: between i-1 and i
        (True, False, False, False),   # 4: between i and i+1
        (True, True, False, False),    # 5: between i+1 and i+2
        (True, True, True, False),     # 6: between i+2 and i+3
    )

    conditions = []
    for pattern in patterns:
        mask = smooth[0] if pattern[0] else rough[0]
        for is_smooth, is_rough, want in zip(smooth[1:], rough[1:], pattern[1:]):
            mask = mask & (is_smooth if want else is_rough)
        conditions.append(mask)
    return conditions
160
+
161
+
162
def TENO7M(
    u: ArrayLike, axis: int, dx: float, CT: float = 0.01
) -> Array:
    """Explicit, discontinuity-aware first-derivative estimate along ``axis``.

    :func:`TCS7M` uses this as the initial guess for BiCGSTAB.  Each point
    takes the explicit finite difference that matches its :func:`MSDD7`
    configuration:

    - 0, smooth: 7-point central stencil with optimized coefficients.
    - 1, jump between i-3 and i-2: 5-point central stencil (4th order).
    - 2, jump between i-2 and i-1: 3-point central stencil.
    - 3, jump between i-1 and i: forward difference ``(u[i+1] - u[i]) / dx``.
    - 4, jump between i and i+1: backward difference ``(u[i] - u[i-1]) / dx``.
    - 5, jump between i+1 and i+2: 3-point central stencil.
    - 6, jump between i+2 and i+3: 5-point central stencil (4th order).

    Points matching no configuration use the smooth stencil.  Shifts use
    ``jnp.roll``, i.e. the grid is periodic along ``axis``.

    Parameters
    ----------
    u : ArrayLike, 3D array
    axis : int, axis direction (0, 1, 2 corresponds to x, y, z)
    dx : float, grid spacing along the axis
    CT : float, smoothness threshold forwarded to MSDD7, default: 0.01

    Returns
    -------
    dudx : Array, estimated derivative (initial guess), same shape as ``u``
    """
    u = jnp.asarray(u)
    # m3 = u[i-3], ..., c = u[i], ..., p3 = u[i+3]
    m3, m2, m1, c, p1, p2, p3 = [jnp.roll(u, s, axis=axis) for s in range(3, -4, -1)]

    conditions = MSDD7(u, axis, CT=CT)

    # symmetric differences u[i+k] - u[i-k], shared by the central stencils
    d1 = p1 - m1
    d2 = p2 - m2
    d3 = p3 - m3

    # Optimized 7-point central coefficients; the standard sixth-order values
    # are 3/4, -3/20, 1/60.  2*(a1 + 2*a2 + 3*a3) ~= 1 keeps it exact for
    # linear data.
    a1 = 0.77088238051822552
    a2 = -0.166705904414580469
    a3 = 0.02084314277031176
    central7 = (a1 * d1 + a2 * d2 + a3 * d3) / dx

    # 4th-order 5-point central stencil: (1/12, -2/3, 0, 2/3, -1/12)
    central5 = (2/3 * d1 - 1/12 * d2) / dx
    # 3-point central stencil
    central3 = d1 / (2 * dx)
    # one-sided two-point stencils
    forward = (p1 - c) / dx
    backward = (c - m1) / dx

    # order matches the MSDD7 configurations 0..6
    candidates = [central7, central5, central3, forward, backward, central3, central5]
    return jnp.select(conditions, candidates, default=central7)
218
+
219
+
220
@partial(jax.jit, static_argnums=(1, 2, 3))
def TCS7M(
    u: ArrayLike, axis: int, CT: float = 0.01, L: float = 1.0
) -> Array:
    """Targeted Compact Scheme with MSDD (TCS-M): first derivative along ``axis``.

    At every point ``i`` the derivative ``f'`` satisfies the pentadiagonal
    relation::

        betaL*f'[i-2] + alphaL*f'[i-1] + f'[i] + alphaR*f'[i+1] + betaR*f'[i+2] = B[i]

    The coefficients and right-hand side ``B`` come from one of seven compact
    schemes.  The scheme is chosen per point by :func:`MSDD7` according to
    where, if anywhere, a discontinuity lies in the 7-point window.  Points
    matching no configuration use the smooth scheme.

    The grid is periodic (``jnp.roll``) with ``dx = L / u.shape[axis]``.  The
    system is solved with BiCGSTAB (``tol=1e-8``, ``atol=0``, ``maxiter=200``),
    starting from the :func:`TENO7M` estimate.  The solver's info output is
    discarded, so non-convergence within 200 iterations is not reported.

    ``axis``, ``CT`` and ``L`` are static arguments of ``jax.jit``.  They must
    be hashable Python values, and each new combination triggers a
    recompilation.

    Parameters
    ----------
    u : ArrayLike, 3D array
    axis : int, axis direction (0, 1, 2 corresponds to x, y, z)
    CT : float, smoothness threshold passed to MSDD7, default: 0.01
    L : float, domain length along the axis, default: 1.0

    Returns
    -------
    dudx : Array, derivative along the given axis, same shape as ``u``

    References
    ----------
    [1] Lele, 1992, JCP, 10.1016/0021-9991(92)90324-R
    [2] Jiang, 2001, IJCFD, 10.1080/10618560108970024
    """
    u = jnp.asarray(u)
    dx = L / u.shape[axis]

    uL3, uL2, uL1, u0, uR1, uR2, uR3 = [
        jnp.roll(u, shift, axis=axis) for shift in (3, 2, 1, 0, -1, -2, -3)
    ]

    conditions = MSDD7(u, axis, CT=CT)
    guess = TENO7M(u, axis, dx, CT=CT)

    # ===== right-hand side B: one candidate per MSDD7 configuration =====
    # bL, aL, c, aR, bR multiply u[i-2], u[i-1], u[i], u[i+1], u[i+2].

    # 0: smooth stencils -- Lele 1992 (3.1.6)
    a0, b0, c0 = 1.3025166, 0.9935500, 0.03750245
    B0 = c0 / (6 * dx) * (uR3 - uL3) + b0 / (4 * dx) * (uR2 - uL2) + a0 / (2 * dx) * (uR1 - uL1)

    # 1: discontinuity between u[i-3] and u[i-2] -- Lele 1992 (2.1.12)
    a1, b1 = 40/27, 25/54
    B1 = b1 / (4 * dx) * (uR2 - uL2) + a1 / (2 * dx) * (uR1 - uL1)

    # 2: discontinuity between u[i-2] and u[i-1] -- Jiang 2001, Lele 1992 (theta = 3/10)
    bL2, aL2, c2, aR2, bR2 = 0, -0.75 * 44/49, -2.65 * 5/49, 0.75 * 44/49 + 1.4 * 5/49, 1.25 * 5/49
    B2 = (bL2 * uL2 + aL2 * uL1 + c2 * u0 + aR2 * uR1 + bR2 * uR2) / dx

    # 3: discontinuity between u[i-1] and u[i] -- Jiang 2001, Lele 1992 (theta = 3/10)
    bL3, aL3, c3, aR3, bR3 = 0, 0, -2.65, 1.4, 1.25
    B3 = (bL3 * uL2 + aL3 * uL1 + c3 * u0 + aR3 * uR1 + bR3 * uR2) / dx

    # 4: discontinuity between u[i] and u[i+1] -- Jiang 2001, Lele 1992 (theta = 3/10)
    bL4, aL4, c4, aR4, bR4 = -1.25, -1.4, 2.65, 0, 0
    B4 = (bL4 * uL2 + aL4 * uL1 + c4 * u0 + aR4 * uR1 + bR4 * uR2) / dx

    # 5: discontinuity between u[i+1] and u[i+2] -- Jiang 2001, Lele 1992 (theta = 3/10)
    bL5, aL5, c5, aR5, bR5 = -1.25 * 5/49, -1.4 * 5/49 + -0.75 * 44/49, 2.65 * 5/49, 0.75 * 44/49, 0
    B5 = (bL5 * uL2 + aL5 * uL1 + c5 * u0 + aR5 * uR1 + bR5 * uR2) / dx

    # 6: discontinuity between u[i+2] and u[i+3] -- Lele 1992 (2.1.12)
    a6, b6 = 40/27, 25/54
    B6 = b6 / (4 * dx) * (uR2 - uL2) + a6 / (2 * dx) * (uR1 - uL1)

    B = jnp.select(conditions, [B0, B1, B2, B3, B4, B5, B6], default=B0)

    # ===== left-hand side coefficients =====
    # (beta^-, alpha^-, alpha^+, beta^+) per configuration, in the order of
    # `conditions`; the diagonal coefficient is 1 in every scheme.
    lhs_coefficients = (
        (0.0896406, 0.5771439, 0.5771439, 0.0896406),                # 0: Lele 1992 (3.1.6)
        (1/36, 4/9, 4/9, 1/36),                                      # 1: Lele 1992 (2.1.12)
        (0, 0.25 * 44/49, 0.25 * 44/49 + 2.6 * 5/49, 0.3 * 5/49),    # 2: Jiang 2001 (theta = 3/10)
        (0, 0, 2.6, 0.3),                                            # 3: Jiang 2001 (theta = 3/10)
        (0.3, 2.6, 0, 0),                                            # 4: Jiang 2001 (theta = 3/10)
        (0.3 * 5/49, 2.6 * 5/49 + 0.25 * 44/49, 0.25 * 44/49, 0),    # 5: Jiang 2001 (theta = 3/10)
        (1/36, 4/9, 4/9, 1/36),                                      # 6: Lele 1992 (2.1.12)
    )
    # The scheme at each point depends only on `conditions`, so resolve the
    # per-point coefficients once, outside the solver loop.  Each
    # matrix-vector product then evaluates one 5-point stencil instead of
    # seven stencils plus a 7-way select.  The cast mirrors how the scalar
    # coefficients adopt the iterate's dtype.
    betaL, alphaL, alphaR, betaR = [
        jnp.select(conditions, list(column), default=column[0]).astype(guess.dtype)
        for column in zip(*lhs_coefficients)
    ]

    def matvec(x: ArrayLike) -> Array:
        # x is the vector of unknown derivatives (not coordinates)
        x = jnp.asarray(x)
        xL2, xL1, xC, xR1, xR2 = [jnp.roll(x, shift, axis=axis) for shift in (2, 1, 0, -1, -2)]
        return betaL * xL2 + alphaL * xL1 + xC + alphaR * xR1 + betaR * xR2

    dudx, _ = bicgstab(matvec, B, x0=guess, tol=1e-8, atol=0, maxiter=200)

    return dudx
345
+
346
+
347
+ # Direction wrappers
348
+
349
def TCS7Mx(
    u: ArrayLike, CT: float = 0.01, L: float = 1.0
) -> Array:
    """Derivative along axis 0 (x) computed with :func:`TCS7M`.

    ``CT`` (smoothness threshold) and ``L`` (domain length along x) are
    forwarded unchanged.
    """
    x_axis = 0
    return TCS7M(u, x_axis, CT, L)
354
+
355
def TCS7My(
    u: ArrayLike, CT: float = 0.01, L: float = 1.0
) -> Array:
    """Derivative along axis 1 (y) computed with :func:`TCS7M`.

    ``CT`` (smoothness threshold) and ``L`` (domain length along y) are
    forwarded unchanged.
    """
    y_axis = 1
    return TCS7M(u, y_axis, CT, L)
360
+
361
def TCS7Mz(
    u: ArrayLike, CT: float = 0.01, L: float = 1.0
) -> Array:
    """Derivative along axis 2 (z) computed with :func:`TCS7M`.

    ``CT`` (smoothness threshold) and ``L`` (domain length along z) are
    forwarded unchanged.
    """
    z_axis = 2
    return TCS7M(u, z_axis, CT, L)