PyMHD 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,926 @@
1
+ # PyMHD: Python for Magnetohydrodynamic Turbulence.
2
+ # Copyright (c) 2026 Yuyang Hua (华宇阳)
3
+ # License: MIT
4
+
5
+ """
6
+ pymhd/derivatives/derivative.py
7
+ -------------------------------
8
+
9
+ Implements the numerical schemes for derivatives of ScalarField and VectorField, including:
10
+ - Dx, Dy, Dz: partial derivatives of ScalarField and VectorField;
11
+ - grad : gradient of ScalarField;
12
+ - div : divergence of VectorField;
13
+ - curl : curl of VectorField;
14
+ - laplacian : Laplacian of ScalarField and VectorField.
15
+
16
+ Supported algorithms:
17
+ - 'WENO' : WENO5-Z and WENO7-Z;
18
+ - 'TENO' : TENO7-M;
19
+ - 'TCS' : targeted compact scheme (TCS7-M);
20
+ - 'CENTRAL' : 2nd, 4th, and 8th order central difference;
21
+ - 'SPECTRAL': Fourier spectral difference.
22
+ """
23
+
24
+ import numpy as np
25
+
26
+ from typing import TypeVar, Literal
27
+
28
+ from ..turbulence import ScalarField, VectorField
29
+
30
# Constrained type variable for operators that return the same kind of field
# they receive (ScalarField -> ScalarField, VectorField -> VectorField).
Field = TypeVar("Field", ScalarField, VectorField)
32
+
33
+ # ========== Algorithm class ==========
34
+ # supported algorithms and parameters:
35
+ # 'WENO' : stencil width (5 or 7)
36
+ # 'TENO' : CT (discontinuity threshold)
37
+ # 'TCS' : CT (discontinuity threshold)
38
+ # 'CENTRAL' : stencil width (accuracy order + 1, i.e. order = stencil - 1)
39
+ # 'SPECTRAL': no parameters
40
+
41
+ class Algorithm:
42
+ """Algorithm class
43
+
44
+ Attributes
45
+ ----------
46
+ method : type of algorithm, options: 'WENO', 'TENO', 'TCS', 'CENTRAL', 'SPECTRAL'
47
+ stencil : stencil width for 'WENO', 'TENO', and 'CENTRAL'
48
+ CT : discontinuity threshold for 'TENO' and 'TCS'
49
+ """
50
+ method : str
51
+ stencil : int | None
52
+ CT : float | None
53
+
54
+ def __init__(self, method: str, stencil: int | None = None, CT: float | None = None):
55
+ self.method = method.upper()
56
+
57
+ if self.method == 'WENO':
58
+ self.stencil = 7 if stencil is None else stencil
59
+ self.CT = None
60
+ if self.stencil not in (5, 7):
61
+ raise ValueError("WENO stencil must be 5 or 7.")
62
+
63
+ elif self.method == 'CENTRAL':
64
+ # CENTRAL uses stencil = order + 1, supporting orders of 2, 4, and 8.
65
+ self.stencil = 9 if stencil is None else stencil
66
+ self.CT = None
67
+ if self.stencil not in (3, 5, 9):
68
+ raise ValueError("CENTRAL stencil must be 3, 5, or 9 (order 2, 4, 8).")
69
+
70
+ elif self.method == 'TENO':
71
+ if stencil not in (None, 7):
72
+ raise ValueError(f"{self.method} stencil must be 7 or None.")
73
+ self.stencil = 7 if stencil is None else stencil
74
+ self.CT = 0.01 if CT is None else CT
75
+
76
+ elif self.method == 'TCS':
77
+ if stencil not in (None, 7):
78
+ raise ValueError(f"{self.method} stencil must be 7 or None.")
79
+ self.stencil = None
80
+ self.CT = 0.01 if CT is None else CT
81
+
82
+ elif self.method == 'SPECTRAL':
83
+ self.stencil = None
84
+ self.CT = None
85
+
86
+ else:
87
+ raise ValueError(f"Unsupported method: {self.method}.")
88
+
89
+
90
+ # ========== WENO derivatives ==========
91
+
92
+ from .WENO import WENOx, WENOy, WENOz
93
+
94
def wenoDx(field: Field, stencil: int = 7) -> Field:
    """Differentiate a field along x with the WENO-Z scheme.

    The derivative is the difference of the right and left interface values
    returned by ``WENOx`` divided by ``field.dx``. A VectorField is handled
    component by component.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField to differentiate.
    stencil : int
        WENO stencil width (5 or 7), default: 7.

    Returns
    -------
    Field
        x derivative, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    """
    if isinstance(field, ScalarField):
        values, spacing = field.data, field.dx
        left, right = WENOx(values, stencil)
        return ScalarField((right - left) / spacing, field.box)

    if isinstance(field, VectorField):
        parts = [
            wenoDx(ScalarField(component, field.box), stencil=stencil).data
            for component in (field.x, field.y, field.z)
        ]
        return VectorField(*parts, field.box)

    raise TypeError("Input of 'wenoDx()' must be ScalarField or VectorField.")
124
+
125
+
126
def wenoDy(field: Field, stencil: int = 7) -> Field:
    """Differentiate a field along y with the WENO-Z scheme.

    The derivative is the difference of the right and left interface values
    returned by ``WENOy`` divided by ``field.dy``. A VectorField is handled
    component by component.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField to differentiate.
    stencil : int
        WENO stencil width (5 or 7), default: 7.

    Returns
    -------
    Field
        y derivative, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    """
    if isinstance(field, ScalarField):
        values, spacing = field.data, field.dy
        left, right = WENOy(values, stencil)
        return ScalarField((right - left) / spacing, field.box)

    if isinstance(field, VectorField):
        parts = [
            wenoDy(ScalarField(component, field.box), stencil=stencil).data
            for component in (field.x, field.y, field.z)
        ]
        return VectorField(*parts, field.box)

    raise TypeError("Input of 'wenoDy()' must be ScalarField or VectorField.")
156
+
157
+
158
def wenoDz(field: Field, stencil: int = 7) -> Field:
    """Differentiate a field along z with the WENO-Z scheme.

    The derivative is the difference of the right and left interface values
    returned by ``WENOz`` divided by ``field.dz``. A VectorField is handled
    component by component.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField to differentiate.
    stencil : int
        WENO stencil width (5 or 7), default: 7.

    Returns
    -------
    Field
        z derivative, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    """
    if isinstance(field, ScalarField):
        values, spacing = field.data, field.dz
        left, right = WENOz(values, stencil)
        return ScalarField((right - left) / spacing, field.box)

    if isinstance(field, VectorField):
        parts = [
            wenoDz(ScalarField(component, field.box), stencil=stencil).data
            for component in (field.x, field.y, field.z)
        ]
        return VectorField(*parts, field.box)

    raise TypeError("Input of 'wenoDz()' must be ScalarField or VectorField.")
188
+
189
+ # ========== TENO derivatives ==========
190
+
191
+ from .TENO import TENO7Mx, TENO7My, TENO7Mz
192
+
193
def tenoDx(field: Field, CT: float = 0.01) -> Field:
    """Differentiate a field along x with the TENO7-M scheme.

    TENO-M is a TENO scheme with a multi-stencil discontinuity detector. The
    interface values come from ``TENO7Mx`` in ``'hybrid'`` mode; their
    difference divided by ``field.dx`` is the derivative. A VectorField is
    handled component by component.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField to differentiate.
    CT : float
        Smoothness threshold forwarded to ``TENO7Mx``, default: 0.01.

    Returns
    -------
    Field
        x derivative, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    """
    if isinstance(field, ScalarField):
        values, spacing = field.data, field.dx
        left, right = TENO7Mx(values, mode='hybrid', CT=CT)
        derivative = np.asarray((right - left) / spacing)
        return ScalarField(derivative, field.box)

    if isinstance(field, VectorField):
        parts = [
            tenoDx(ScalarField(component, field.box), CT).data
            for component in (field.x, field.y, field.z)
        ]
        return VectorField(*parts, field.box)

    raise TypeError("Input of 'tenoDx()' must be ScalarField or VectorField.")
226
+
227
+
228
def tenoDy(field: Field, CT: float = 0.01) -> Field:
    """Differentiate a field along y with the TENO7-M scheme.

    TENO-M is a TENO scheme with a multi-stencil discontinuity detector. The
    interface values come from ``TENO7My`` in ``'hybrid'`` mode; their
    difference divided by ``field.dy`` is the derivative. A VectorField is
    handled component by component.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField to differentiate.
    CT : float
        Smoothness threshold forwarded to ``TENO7My``, default: 0.01.

    Returns
    -------
    Field
        y derivative, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    """
    if isinstance(field, ScalarField):
        values, spacing = field.data, field.dy
        left, right = TENO7My(values, mode='hybrid', CT=CT)
        derivative = np.asarray((right - left) / spacing)
        return ScalarField(derivative, field.box)

    if isinstance(field, VectorField):
        parts = [
            tenoDy(ScalarField(component, field.box), CT).data
            for component in (field.x, field.y, field.z)
        ]
        return VectorField(*parts, field.box)

    raise TypeError("Input of 'tenoDy()' must be ScalarField or VectorField.")
261
+
262
+
263
def tenoDz(field: Field, CT: float = 0.01) -> Field:
    """Differentiate a field along z with the TENO7-M scheme.

    TENO-M is a TENO scheme with a multi-stencil discontinuity detector. The
    interface values come from ``TENO7Mz`` in ``'hybrid'`` mode; their
    difference divided by ``field.dz`` is the derivative. A VectorField is
    handled component by component.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField to differentiate.
    CT : float
        Smoothness threshold forwarded to ``TENO7Mz``, default: 0.01.

    Returns
    -------
    Field
        z derivative, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    """
    if isinstance(field, ScalarField):
        values, spacing = field.data, field.dz
        left, right = TENO7Mz(values, mode='hybrid', CT=CT)
        derivative = np.asarray((right - left) / spacing)
        return ScalarField(derivative, field.box)

    if isinstance(field, VectorField):
        parts = [
            tenoDz(ScalarField(component, field.box), CT).data
            for component in (field.x, field.y, field.z)
        ]
        return VectorField(*parts, field.box)

    raise TypeError("Input of 'tenoDz()' must be ScalarField or VectorField.")
296
+
297
+
298
+ # ========== TCS (Targeted Compact Scheme) derivatives ==========
299
+
300
+ from .compact import TCS7Mx, TCS7My, TCS7Mz
301
+
302
def tcsDx(field: Field, CT: float = 0.01) -> Field:
    """Differentiate a field along x with the TCS7-M scheme.

    TCS7-M is a targeted compact scheme with a multi-stencil discontinuity
    detector. The whole derivative is computed by ``TCS7Mx``, which receives
    ``field.Lx`` as its ``L`` argument. A VectorField is handled component by
    component.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField to differentiate.
    CT : float
        Smoothness threshold forwarded to ``TCS7Mx``, default: 0.01.

    Returns
    -------
    Field
        x derivative, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    """
    if isinstance(field, ScalarField):
        raw = TCS7Mx(field.data, CT=CT, L=field.Lx)
        return ScalarField(np.asarray(raw), field.box)

    if isinstance(field, VectorField):
        parts = [
            tcsDx(ScalarField(component, field.box), CT).data
            for component in (field.x, field.y, field.z)
        ]
        return VectorField(*parts, field.box)

    raise TypeError("Input of 'tcsDx()' must be ScalarField or VectorField.")
328
+
329
+
330
def tcsDy(field: Field, CT: float = 0.01) -> Field:
    """Differentiate a field along y with the TCS7-M scheme.

    TCS7-M is a targeted compact scheme with a multi-stencil discontinuity
    detector. The whole derivative is computed by ``TCS7My``, which receives
    ``field.Ly`` as its ``L`` argument. A VectorField is handled component by
    component.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField to differentiate.
    CT : float
        Smoothness threshold forwarded to ``TCS7My``, default: 0.01.

    Returns
    -------
    Field
        y derivative, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    """
    if isinstance(field, ScalarField):
        raw = TCS7My(field.data, CT=CT, L=field.Ly)
        return ScalarField(np.asarray(raw), field.box)

    if isinstance(field, VectorField):
        parts = [
            tcsDy(ScalarField(component, field.box), CT).data
            for component in (field.x, field.y, field.z)
        ]
        return VectorField(*parts, field.box)

    raise TypeError("Input of 'tcsDy()' must be ScalarField or VectorField.")
356
+
357
+
358
def tcsDz(field: Field, CT: float = 0.01) -> Field:
    """Differentiate a field along z with the TCS7-M scheme.

    TCS7-M is a targeted compact scheme with a multi-stencil discontinuity
    detector. The whole derivative is computed by ``TCS7Mz``, which receives
    ``field.Lz`` as its ``L`` argument. A VectorField is handled component by
    component.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField to differentiate.
    CT : float
        Smoothness threshold forwarded to ``TCS7Mz``, default: 0.01.

    Returns
    -------
    Field
        z derivative, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    """
    if isinstance(field, ScalarField):
        raw = TCS7Mz(field.data, CT=CT, L=field.Lz)
        return ScalarField(np.asarray(raw), field.box)

    if isinstance(field, VectorField):
        parts = [
            tcsDz(ScalarField(component, field.box), CT).data
            for component in (field.x, field.y, field.z)
        ]
        return VectorField(*parts, field.box)

    raise TypeError("Input of 'tcsDz()' must be ScalarField or VectorField.")
384
+
385
+
386
# ========== Central difference derivatives ==========

# One-sided weights c_k of the antisymmetric central first-derivative stencils
#     f'(x) ~ sum_{k=1..m} c_k * (f(x + k h) - f(x - k h)) / h,
# keyed by order of accuracy (order = 2 m).
_CENTRAL_COEFFS = {
    2: (1 / 2,),
    4: (2 / 3, -1 / 12),
    6: (3 / 4, -3 / 20, 1 / 60),
    8: (4 / 5, -1 / 5, 4 / 105, -1 / 280),
}


def central(
    u: np.ndarray, dx: float, direction: Literal['x', 'y', 'z'], order: int = 8
) -> np.ndarray:
    """Derivative of an array using an explicit central difference scheme.

    Neighbours are obtained with ``np.roll``, so the array is treated as
    periodic along the differentiated axis.

    Parameters
    ----------
    u : np.ndarray
        Input array; 'x', 'y', 'z' select axes 0, 1, 2 respectively.
    dx : float
        Grid spacing along the differentiated axis.
    direction : {'x', 'y', 'z'}
        Derivative direction.
    order : int
        Order of accuracy (2, 4, 6 or 8), default: 8.

    Returns
    -------
    np.ndarray
        Derivative of ``u`` along ``direction``, same shape as ``u``.

    Raises
    ------
    KeyError
        If ``direction`` is not 'x', 'y' or 'z'.
    ValueError
        If ``order`` is not a supported order of accuracy.
    """
    # determine the axis of roll based on direction
    axis = {'x': 0, 'y': 1, 'z': 2}[direction]

    if order not in _CENTRAL_COEFFS:
        raise ValueError(
            f"Unsupported order: {order}. Supported orders are 2, 4, 6, 8."
        )
    weights = list(enumerate(_CENTRAL_COEFFS[order], start=1))

    # (roll shift, weight) pairs. np.roll(u, s) gives f(x - s h), so the left
    # neighbours use positive shifts with weight -c_k and the right neighbours
    # negative shifts with weight +c_k. Terms are summed in the order
    # f(x - m h), ..., f(x - h), f(x + h), ..., f(x + m h), which matches the
    # hand-written formulas term by term.
    terms = [(k, -c) for k, c in reversed(weights)] + [(-k, c) for k, c in weights]

    DuDx = None
    for shift, weight in terms:
        contribution = weight * np.roll(u, shift, axis=axis)
        # start from the first term itself (not 0) to keep the exact sum
        DuDx = contribution if DuDx is None else DuDx + contribution

    return DuDx / dx
450
+
451
def centralDx(field: Field, order: int = 8) -> Field:
    """Differentiate a field along x with a central difference scheme.

    Delegates to ``central`` with ``field.dx`` and direction 'x'. A
    VectorField is handled component by component.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField to differentiate.
    order : int
        Order of accuracy (2, 4, 8), default: 8.

    Returns
    -------
    Field
        x derivative, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    """
    if isinstance(field, ScalarField):
        derivative = central(field.data, field.dx, 'x', order)
        return ScalarField(derivative, field.box)

    if isinstance(field, VectorField):
        parts = [
            centralDx(ScalarField(component, field.box), order).data
            for component in (field.x, field.y, field.z)
        ]
        return VectorField(*parts, field.box)

    raise TypeError("Input of 'centralDx()' must be ScalarField or VectorField.")
481
+
482
+
483
def centralDy(field: Field, order: int = 8) -> Field:
    """Differentiate a field along y with a central difference scheme.

    Delegates to ``central`` with ``field.dy`` and direction 'y'. A
    VectorField is handled component by component.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField to differentiate.
    order : int
        Order of accuracy (2, 4, 8), default: 8.

    Returns
    -------
    Field
        y derivative, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    """
    if isinstance(field, ScalarField):
        derivative = central(field.data, field.dy, 'y', order)
        return ScalarField(derivative, field.box)

    if isinstance(field, VectorField):
        parts = [
            centralDy(ScalarField(component, field.box), order).data
            for component in (field.x, field.y, field.z)
        ]
        return VectorField(*parts, field.box)

    raise TypeError("Input of 'centralDy()' must be ScalarField or VectorField.")
513
+
514
+
515
def centralDz(field: Field, order: int = 8) -> Field:
    """Differentiate a field along z with a central difference scheme.

    Delegates to ``central`` with ``field.dz`` and direction 'z'. A
    VectorField is handled component by component.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField to differentiate.
    order : int
        Order of accuracy (2, 4, 8), default: 8.

    Returns
    -------
    Field
        z derivative, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    """
    if isinstance(field, ScalarField):
        derivative = central(field.data, field.dz, 'z', order)
        return ScalarField(derivative, field.box)

    if isinstance(field, VectorField):
        parts = [
            centralDz(ScalarField(component, field.box), order).data
            for component in (field.x, field.y, field.z)
        ]
        return VectorField(*parts, field.box)

    raise TypeError("Input of 'centralDz()' must be ScalarField or VectorField.")
545
+
546
+
547
# ========== Spectral difference ==========

def spectralDx(field: Field) -> Field:
    """Differentiate a field along x with the Fourier spectral method.

    The data is transformed with a 1D FFT along axis 0, multiplied by
    ``1j * k`` (angular wavenumbers from ``fftfreq`` with spacing
    ``field.dx``), transformed back, and the real part is kept. A VectorField
    is handled component by component.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField to differentiate.

    Returns
    -------
    Field
        x derivative, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    """
    if isinstance(field, ScalarField):
        values = field.data
        wavenumbers = 2 * np.pi * np.fft.fftfreq(field.Nx, d=field.dx)
        # broadcast the wavenumbers along axis 0 of the 3D data
        multiplier = 1j * wavenumbers[:, None, None]
        spectrum = np.fft.fft(values, axis=0)
        derivative = np.fft.ifft(multiplier * spectrum, axis=0).real
        return ScalarField(derivative, field.box)

    if isinstance(field, VectorField):
        parts = [
            spectralDx(ScalarField(component, field.box)).data
            for component in (field.x, field.y, field.z)
        ]
        return VectorField(*parts, field.box)

    raise TypeError("Input of 'spectralDx()' must be ScalarField or VectorField.")
580
+
581
+
582
def spectralDy(field: Field) -> Field:
    """Differentiate a field along y with the Fourier spectral method.

    The data is transformed with a 1D FFT along axis 1, multiplied by
    ``1j * k`` (angular wavenumbers from ``fftfreq`` with spacing
    ``field.dy``), transformed back, and the real part is kept. A VectorField
    is handled component by component.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField to differentiate.

    Returns
    -------
    Field
        y derivative, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    """
    if isinstance(field, ScalarField):
        values = field.data
        wavenumbers = 2 * np.pi * np.fft.fftfreq(field.Ny, d=field.dy)
        # broadcast the wavenumbers along axis 1 of the 3D data
        multiplier = 1j * wavenumbers[None, :, None]
        spectrum = np.fft.fft(values, axis=1)
        derivative = np.fft.ifft(multiplier * spectrum, axis=1).real
        return ScalarField(derivative, field.box)

    if isinstance(field, VectorField):
        parts = [
            spectralDy(ScalarField(component, field.box)).data
            for component in (field.x, field.y, field.z)
        ]
        return VectorField(*parts, field.box)

    raise TypeError("Input of 'spectralDy()' must be ScalarField or VectorField.")
613
+
614
def spectralDz(field: Field) -> Field:
    """Differentiate a field along z with the Fourier spectral method.

    The data is transformed with a 1D FFT along axis 2, multiplied by
    ``1j * k`` (angular wavenumbers from ``fftfreq`` with spacing
    ``field.dz``), transformed back, and the real part is kept. A VectorField
    is handled component by component.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField to differentiate.

    Returns
    -------
    Field
        z derivative, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    """
    if isinstance(field, ScalarField):
        values = field.data
        wavenumbers = 2 * np.pi * np.fft.fftfreq(field.Nz, d=field.dz)
        # broadcast the wavenumbers along axis 2 of the 3D data
        multiplier = 1j * wavenumbers[None, None, :]
        spectrum = np.fft.fft(values, axis=2)
        derivative = np.fft.ifft(multiplier * spectrum, axis=2).real
        return ScalarField(derivative, field.box)

    if isinstance(field, VectorField):
        parts = [
            spectralDz(ScalarField(component, field.box)).data
            for component in (field.x, field.y, field.z)
        ]
        return VectorField(*parts, field.box)

    raise TypeError("Input of 'spectralDz()' must be ScalarField or VectorField.")
644
+
645
+
646
# ========== Derivatives calculation ==========

def Dx(field: Field, algorithm: Algorithm) -> Field:
    """x derivative of a field with the scheme selected by ``algorithm``.

    Dispatches on ``algorithm.method`` to the matching ``*Dx`` implementation.
    For 'CENTRAL', the order of accuracy is ``algorithm.stencil - 1``.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField to differentiate.
    algorithm : Algorithm
        Scheme configuration.

    Returns
    -------
    Field
        x derivative, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    ValueError
        If the method is unknown or its required parameter is missing.
    """
    if not isinstance(field, (ScalarField, VectorField)):
        raise TypeError("Input of 'Dx()' must be ScalarField or VectorField.")

    method = algorithm.method

    # stencil-based schemes
    if method in ('WENO', 'CENTRAL'):
        if algorithm.stencil is None:
            raise ValueError(f"{method} requires stencil.")
        if method == 'WENO':
            return wenoDx(field, stencil=algorithm.stencil)
        return centralDx(field, order=algorithm.stencil - 1)

    # threshold-based schemes
    if method in ('TENO', 'TCS'):
        if algorithm.CT is None:
            raise ValueError(f"{method} requires CT.")
        if method == 'TENO':
            return tenoDx(field, CT=algorithm.CT)
        return tcsDx(field, CT=algorithm.CT)

    if method == 'SPECTRAL':
        return spectralDx(field)

    raise ValueError(f"Unsupported method: {method}.")
687
+
688
+
689
def Dy(field: Field, algorithm: Algorithm) -> Field:
    """y derivative of a field with the scheme selected by ``algorithm``.

    Dispatches on ``algorithm.method`` to the matching ``*Dy`` implementation.
    For 'CENTRAL', the order of accuracy is ``algorithm.stencil - 1``.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField to differentiate.
    algorithm : Algorithm
        Scheme configuration.

    Returns
    -------
    Field
        y derivative, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    ValueError
        If the method is unknown or its required parameter is missing.
    """
    if not isinstance(field, (ScalarField, VectorField)):
        raise TypeError("Input of 'Dy()' must be ScalarField or VectorField.")

    method = algorithm.method

    # stencil-based schemes
    if method in ('WENO', 'CENTRAL'):
        if algorithm.stencil is None:
            raise ValueError(f"{method} requires stencil.")
        if method == 'WENO':
            return wenoDy(field, stencil=algorithm.stencil)
        return centralDy(field, order=algorithm.stencil - 1)

    # threshold-based schemes
    if method in ('TENO', 'TCS'):
        if algorithm.CT is None:
            raise ValueError(f"{method} requires CT.")
        if method == 'TENO':
            return tenoDy(field, CT=algorithm.CT)
        return tcsDy(field, CT=algorithm.CT)

    if method == 'SPECTRAL':
        return spectralDy(field)

    raise ValueError(f"Unsupported method: {method}.")
728
+
729
def Dz(field: Field, algorithm: Algorithm) -> Field:
    """z derivative of a field with the scheme selected by ``algorithm``.

    Dispatches on ``algorithm.method`` to the matching ``*Dz`` implementation.
    For 'CENTRAL', the order of accuracy is ``algorithm.stencil - 1``.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField to differentiate.
    algorithm : Algorithm
        Scheme configuration.

    Returns
    -------
    Field
        z derivative, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    ValueError
        If the method is unknown or its required parameter is missing.
    """
    if not isinstance(field, (ScalarField, VectorField)):
        raise TypeError("Input of 'Dz()' must be ScalarField or VectorField.")

    method = algorithm.method

    # stencil-based schemes
    if method in ('WENO', 'CENTRAL'):
        if algorithm.stencil is None:
            raise ValueError(f"{method} requires stencil.")
        if method == 'WENO':
            return wenoDz(field, stencil=algorithm.stencil)
        return centralDz(field, order=algorithm.stencil - 1)

    # threshold-based schemes
    if method in ('TENO', 'TCS'):
        if algorithm.CT is None:
            raise ValueError(f"{method} requires CT.")
        if method == 'TENO':
            return tenoDz(field, CT=algorithm.CT)
        return tcsDz(field, CT=algorithm.CT)

    if method == 'SPECTRAL':
        return spectralDz(field)

    raise ValueError(f"Unsupported method: {method}.")
768
+
769
+
770
# ========== nabla operators ==========

def grad(field: ScalarField, algorithm: Algorithm) -> VectorField:
    """Gradient of a ScalarField.

    Gradient: ∇f = (∂f/∂x, ∂f/∂y, ∂f/∂z)

    Parameters
    ----------
    field : ScalarField
        Scalar field to differentiate.
    algorithm : Algorithm
        Scheme used for each partial derivative.

    Returns
    -------
    VectorField
        Gradient of ``field`` on the same box.

    Raises
    ------
    TypeError
        If ``field`` is not a ScalarField.
    """
    if not isinstance(field, ScalarField):
        raise TypeError("Input of 'grad()' must be ScalarField.")

    components = [derivative(field, algorithm).data for derivative in (Dx, Dy, Dz)]
    return VectorField(*components, field.box)
794
+
795
+
796
def div(field: VectorField, algorithm: Algorithm) -> ScalarField:
    """Divergence of a VectorField.

    Divergence: ∇·V = ∂Vx/∂x + ∂Vy/∂y + ∂Vz/∂z

    Parameters
    ----------
    field : VectorField
        Vector field to differentiate.
    algorithm : Algorithm
        Scheme used for each partial derivative.

    Returns
    -------
    ScalarField
        Divergence of ``field``.

    Raises
    ------
    TypeError
        If ``field`` is not a VectorField.
    """
    if not isinstance(field, VectorField):
        raise TypeError("Input of 'div()' must be VectorField.")

    pairs = ((Dx, field.x), (Dy, field.y), (Dz, field.z))
    dVx_dx, dVy_dy, dVz_dz = (
        derivative(ScalarField(component, field.box), algorithm)
        for derivative, component in pairs
    )
    return dVx_dx + dVy_dy + dVz_dz
818
+
819
+
820
def curl(field: VectorField, algorithm: Algorithm) -> VectorField:
    """Curl of a VectorField.

    Curl: ∇ × V = (∂Vz/∂y - ∂Vy/∂z, ∂Vx/∂z - ∂Vz/∂x, ∂Vy/∂x - ∂Vx/∂y)

    Parameters
    ----------
    field : VectorField
        Vector field to differentiate.
    algorithm : Algorithm
        Scheme used for each partial derivative.

    Returns
    -------
    VectorField
        Curl of ``field`` on the same box.

    Raises
    ------
    TypeError
        If ``field`` is not a VectorField.
    """
    if not isinstance(field, VectorField):
        raise TypeError("Input of 'curl()' must be VectorField.")

    box = field.box
    vx, vy, vz = (ScalarField(component, box) for component in (field.x, field.y, field.z))

    wx = Dy(vz, algorithm) - Dz(vy, algorithm)
    wy = Dz(vx, algorithm) - Dx(vz, algorithm)
    wz = Dx(vy, algorithm) - Dy(vx, algorithm)

    return VectorField(wx.data, wy.data, wz.data, field.box)
846
+
847
def laplacian(field: Field, algorithm: Algorithm) -> Field:
    """Laplacian of a field.

    Laplacian: ∇²f = ∂²f/∂x² + ∂²f/∂y² + ∂²f/∂z²

    Each second derivative is computed by applying the first-derivative
    operator twice. A VectorField is handled component by component.

    Parameters
    ----------
    field : Field
        ScalarField or VectorField.
    algorithm : Algorithm
        Scheme used for every first derivative.

    Returns
    -------
    Field
        Laplacian of ``field``, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    """
    if isinstance(field, ScalarField):
        def second(derivative):
            return derivative(derivative(field, algorithm), algorithm)

        return second(Dx) + second(Dy) + second(Dz)

    if isinstance(field, VectorField):
        parts = [
            laplacian(ScalarField(component, field.box), algorithm).data
            for component in (field.x, field.y, field.z)
        ]
        return VectorField(*parts, field.box)

    raise TypeError("Input of 'laplacian()' must be ScalarField or VectorField.")
881
+
882
# ========== cell averages to cell centers ==========

def average2center(field: Field, algorithm: Algorithm) -> Field:
    """Convert cell averages to cell-center values.

    For 3D finite-volume data, the sixth-order conversion is
        u_ctr = u_avg
                - (1/24)  [dx²∂²u_avg/∂x² + dy²∂²u_avg/∂y² + dz²∂²u_avg/∂z²]
                + (7/5760)[dx⁴∂⁴u_avg/∂x⁴ + dy⁴∂⁴u_avg/∂y⁴ + dz⁴∂⁴u_avg/∂z⁴]
                + (1/576) [(dx*dy)²∂⁴u_avg/∂x²∂y² + (dx*dz)²∂⁴u_avg/∂x²∂z²
                           + (dy*dz)²∂⁴u_avg/∂y²∂z²]
    Every higher derivative is built by repeated application of Dx/Dy/Dz.

    Parameters
    ----------
    field : Field
        Cell averages (ScalarField or VectorField).
    algorithm : Algorithm
        Scheme used for every first derivative.

    Returns
    -------
    Field
        Cell-center values, of the same kind as ``field``.

    Raises
    ------
    TypeError
        If ``field`` is neither a ScalarField nor a VectorField.
    """
    if not isinstance(field, (ScalarField, VectorField)):
        raise TypeError("Input of 'average2center()' must be ScalarField or VectorField.")

    dx, dy, dz = field.dx, field.dy, field.dz

    def twice(derivative, target):
        # second derivative of `target` along the axis of `derivative`
        return derivative(derivative(target, algorithm), algorithm)

    u_xx = twice(Dx, field)
    u_yy = twice(Dy, field)
    u_zz = twice(Dz, field)

    u_xxxx = twice(Dx, u_xx)
    u_yyyy = twice(Dy, u_yy)
    u_zzzz = twice(Dz, u_zz)

    u_xxyy = twice(Dx, u_yy)
    u_xxzz = twice(Dx, u_zz)
    u_yyzz = twice(Dy, u_zz)

    second_order = dx**2 * u_xx + dy**2 * u_yy + dz**2 * u_zz
    fourth_pure = dx**4 * u_xxxx + dy**4 * u_yyyy + dz**4 * u_zzzz
    fourth_mixed = (dx*dy)**2 * u_xxyy + (dx*dz)**2 * u_xxzz + (dy*dz)**2 * u_yyzz

    return field - (1/24) * second_order + (7/5760) * fourth_pure + (1/576) * fourth_mixed