PyMHD 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pymhd/plot/slc.py ADDED
@@ -0,0 +1,648 @@
1
+ # PyMHD: Python for Magnetohydrodynamic Turbulence.
2
+ # Copyright (c) 2026 Yuyang Hua (华宇阳)
3
+ # License: MIT
4
+
5
+ """
6
+ pymhd/plot/slc.py
7
+ -----------------
8
+
9
+ Implements the tools for plotting 2D slices of turbulence.
10
+
11
+ Currently supports MRI-driven turbulence, forced MHD turbulence, and hydrodynamic turbulence.
12
+ - Shearing-box simulations: the box ratio is hard coded to be Lx : Ly : Lz = 2 : 4 : 1
13
+ - Forced turbulence: the box ratio is hard coded to be Lx : Ly : Lz = 1 : 1 : 1
14
+
15
+ TODO: Support arbitrary box ratio.
16
+ """
17
+
18
+ import numpy as np
19
+ import matplotlib.pyplot as plt
20
+ from matplotlib.colors import LogNorm, Normalize
21
+
22
+ from pathlib import Path
23
+
24
+ from ..turbulence import Turbulence
25
+
26
# Global Matplotlib styling applied when this module is imported.
# Text uses Computer Modern (the LaTeX default face) for both regular text and
# mathtext. The cmr10 font is commonly paired with ASCII minus signs and
# mathtext-rendered tick labels, presumably because it lacks some glyphs such as
# the Unicode minus — that is why the two 'axes.*' options below are set.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['cmr10'],
    'mathtext.fontset': 'cm',  # Computer Modern for math text
    'axes.unicode_minus': False,
    'axes.formatter.use_mathtext': True,
    # Font sizes in points.
    'font.size': 18,
    'axes.labelsize': 18,
    'axes.titlesize': 18,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'figure.titlesize': 16,
})
39
+
40
def getrange(
    variables: list[tuple[str, np.ndarray]],
    fraction : float = 1.0,
) -> dict[str, tuple[float, float]]:
    """Compute the colormap range of each variable.

    Parameters
    ----------
    variables : list of (name, data) tuples. ``data`` may be an array of any
        shape, or anything ``np.asarray`` accepts; it is flattened before use.
    fraction : float in (0, 1]. Proportion of data to include in [vmin, vmax].
        Default 1.0 = full range. Only affects variables other than 'rho'.

    Returns
    -------
    ranges : dict mapping each variable name to its colormap range (vmin, vmax),
        as Python floats.
        - 'rho': centred on the mean, mean +/- max(max - mean, mean - min), so the
          whole data range is always covered (``fraction`` is ignored).
        - any other name: symmetric about zero, +/- percentile(|data|, fraction*100).
        If a name appears more than once, its last occurrence wins.

    Raises
    ------
    ValueError
        If ``fraction`` is not in (0, 1].
    """
    if not (0 < fraction <= 1.0):
        raise ValueError("fraction must be in (0, 1]")

    ranges: dict[str, tuple[float, float]] = {}

    for varname, vardata in variables:

        # ravel() returns a view when possible, avoiding flatten()'s unconditional copy.
        data = np.asarray(vardata).ravel()

        if varname == 'rho':
            # Density is positive and fluctuates around its mean, so centre the
            # range on the mean and make it wide enough for both extremes.
            mean = np.mean(data)
            delta = max(np.max(data) - mean, mean - np.min(data))
            ranges[varname] = (float(mean - delta), float(mean + delta))
        else:
            vmax = float(np.percentile(np.abs(data), fraction * 100))
            ranges[varname] = (-vmax, vmax)

    return ranges
74
+
75
+
76
def plotForcedTurbulence(
    turbulence: Turbulence,
    fraction : float = 1.0,
) -> None:
    """Plot 2D mid-plane slices of forced MHD turbulence.

    For every snapshot in ``turbulence.times`` this writes, under ``./slices``:

    - ``rho/``, ``V/``, ``B/``, ``J/``: one figure per quantity with three panels,
      the mid-plane slices normal to x, y and z, drawn in the (y, z), (z, x) and
      (x, y) planes (horizontal axis first). For V, B and J the panel normal to
      axis i shows the i-th component.
    - ``all/``: the z mid-plane slices of rho, u_z and B_z side by side.

    Files are named ``t=<time with 2 decimals>.pdf``. Density uses the 'Blues'
    colormap over its [min, max], on a log scale when max/min > 10. Every other
    quantity uses 'RdBu' with a range symmetric about zero. Hydrodynamic runs
    (no B or J) are handled by ``plotHydroTurbulence``.

    Parameters
    ----------
    turbulence: Turbulence object providing ``times`` and the per-snapshot
        sequences ``rhos``, ``Vs``, ``Bs`` and ``Js``.
    fraction : float in (0, 1]. For signed quantities the colour range is
        [-v, v] with v = percentile(|data|, fraction*100); default 1.0 = full
        range. Density always uses its full range.

    Raises
    ------
    ValueError
        If ``fraction`` is not in (0, 1], or a density slice has negative values.
    """
    if not (0 < fraction <= 1.0):
        raise ValueError("fraction must be in (0, 1]")
    pct = fraction * 100

    basedir = Path('slices')
    outputdirs = {key: basedir / key for key in ('rho', 'V', 'B', 'J', 'all')}
    for outdir in outputdirs.values():
        outdir.mkdir(parents=True, exist_ok=True)

    # (horizontal, vertical) axis labels of the panels normal to x, y and z.
    axis_labels = ((r'$y$', r'$z$'), (r'$z$', r'$x$'), (r'$x$', r'$y$'))

    def mid_planes(fx, fy, fz, centre):
        """Return the x-, y- and z-mid-plane slices of fx, fy, fz oriented for imshow.

        Field arrays are indexed (x, y, z). imshow puts array rows on the vertical
        axis, so the slices are arranged as (z, y), (x, z) and (y, x) to match the
        panel axes (y, z), (z, x) and (x, y).
        """
        ix, iy, iz = centre
        return fx[ix, :, :].T, fy[:, iy, :], fz[:, :, iz].T

    def density_norm(values: np.ndarray) -> Normalize:
        """Norm spanning [min, max]; logarithmic when positive and max/min > 10."""
        if np.any(values < 0):
            raise ValueError("rho slice data contains negative values, cannot use LogNorm.")
        vmin = float(np.min(values))
        vmax = float(np.max(values))
        if vmin > 0 and vmax / vmin > 10:
            return LogNorm(vmin=vmin, vmax=vmax)
        return Normalize(vmin=vmin, vmax=vmax)

    def signed_norm(values: np.ndarray) -> Normalize:
        """Linear norm symmetric about zero, covering ``fraction`` of |values|."""
        vmax = float(np.percentile(np.abs(values), pct))
        return Normalize(vmin=-vmax, vmax=vmax)

    def style_axes(ax) -> None:
        """Inward ticks, square panel and thick frame."""
        ax.tick_params(direction='in', width=1.5, pad=7)
        ax.set_box_aspect(1)
        for spine in ax.spines.values():
            spine.set_linewidth(1.5)

    def add_top_colorbar(fig, ax, image, label: str) -> None:
        """Attach a horizontal colorbar just above ``ax``, exactly as wide as the panel."""
        cax = ax.inset_axes((0, 1.035, 1, 0.06))
        cbar = fig.colorbar(image, cax=cax, orientation='horizontal')
        cbar.set_label(label, labelpad=8)
        cbar.ax.xaxis.set_ticks_position('top')
        cbar.ax.xaxis.set_label_position('top')
        cbar.ax.tick_params(labelsize=12, pad=2)
        outline_spine = cbar.ax.spines.get("outline")
        if outline_spine is not None:
            outline_spine.set_linewidth(1.5)

    for index, time in enumerate(turbulence.times):

        rho = turbulence.rhos[index]
        V = turbulence.Vs[index]
        B = turbulence.Bs[index]
        J = turbulence.Js[index]

        Nx, Ny, Nz = rho.data.shape
        Lx, Ly, Lz = rho.box
        centre = (Nx // 2, Ny // 2, Nz // 2)

        # imshow extents (left, right, bottom, top) of the (y, z), (z, x), (x, y) panels.
        extents = (
            (-Ly / 2, Ly / 2, -Lz / 2, Lz / 2),
            (-Lz / 2, Lz / 2, -Lx / 2, Lx / 2),
            (-Lx / 2, Lx / 2, -Ly / 2, Ly / 2),
        )

        # Each entry: (output dir key, oriented x/y/z mid-plane slices, colorbar labels).
        groups = (
            ('rho', mid_planes(rho.data, rho.data, rho.data, centre), (r'$\rho$',) * 3),
            ('V', mid_planes(V.x, V.y, V.z, centre), (r'$u_x$', r'$u_y$', r'$u_z$')),
            ('B', mid_planes(B.x, B.y, B.z, centre), (r'$B_x$', r'$B_y$', r'$B_z$')),
            ('J', mid_planes(J.x, J.y, J.z, centre), (r'$J_x$', r'$J_y$', r'$J_z$')),
        )

        for outkey, planes, cbarlabels in groups:

            # One norm for all three panels so they share a colour scale. It is
            # computed before the figure exists so a ValueError cannot leak a figure.
            values = np.concatenate([plane.ravel() for plane in planes])
            if outkey == 'rho':
                cmap, norm = 'Blues', density_norm(values)
            else:
                cmap, norm = 'RdBu', signed_norm(values)

            fig, axes = plt.subplots(1, 3, figsize=(16, 6), constrained_layout=True)
            for ax, plane, extent, (xlabel, ylabel), cbarlabel in zip(
                axes, planes, extents, axis_labels, cbarlabels
            ):
                image = ax.imshow(
                    plane, origin='lower', cmap=cmap, norm=norm,
                    extent=extent, aspect='auto',
                )
                ax.set_xlabel(xlabel)
                ax.set_ylabel(ylabel)
                style_axes(ax)
                add_top_colorbar(fig, ax, image, cbarlabel)

            fig.savefig(outputdirs[outkey] / f't={time:.2f}.pdf', bbox_inches='tight')
            plt.close(fig)

        # Overview figure: z mid-plane of density, u_z and B_z, each with its own scale.
        iz = centre[2]
        rho_z = rho.data[:, :, iz]
        vz_z = V.z[:, :, iz]
        bz_z = B.z[:, :, iz]
        overview = (
            (rho_z, r'$\rho$', 'Blues', density_norm(rho_z)),
            (vz_z, r'$u_z$', 'RdBu', signed_norm(vz_z)),
            (bz_z, r'$B_z$', 'RdBu', signed_norm(bz_z)),
        )

        fig, axes = plt.subplots(1, 3, figsize=(16, 6), constrained_layout=True)
        for ax, (plane, cbarlabel, cmap, norm) in zip(axes, overview):
            image = ax.imshow(
                plane.T, origin='lower', cmap=cmap, norm=norm,
                extent=extents[2], aspect='auto',
            )
            ax.set_xlabel(r'$x$')
            ax.set_ylabel(r'$y$')
            style_axes(ax)
            add_top_colorbar(fig, ax, image, cbarlabel)

        fig.savefig(outputdirs['all'] / f't={time:.2f}.pdf', bbox_inches='tight')
        plt.close(fig)
288
+
289
+
290
def plotHydroTurbulence(
    turbulence: Turbulence,
    fraction: float = 1.0,
) -> None:
    """Plot 2D mid-plane slices of hydrodynamic turbulence.

    For every snapshot in ``turbulence.times`` this writes, under ``./slices``:

    - ``rho/``, ``V/``: one figure per quantity with three panels, the mid-plane
      slices normal to x, y and z, drawn in the (y, z), (z, x) and (x, y) planes
      (horizontal axis first). For V the panel normal to axis i shows u_i.
    - ``all/``: the z mid-plane slices of rho and u_z side by side.

    Files are named ``t=<time with 2 decimals>.pdf``. Density uses the 'Blues'
    colormap over its [min, max], on a log scale when max/min > 10. Velocity
    uses 'RdBu' with a range symmetric about zero.

    Parameters
    ----------
    turbulence: Turbulence object providing ``times`` and the per-snapshot
        sequences ``rhos`` and ``Vs``.
    fraction : float in (0, 1]. For velocity the colour range is [-v, v] with
        v = percentile(|data|, fraction*100); default 1.0 = full range.
        Density always uses its full range.

    Raises
    ------
    ValueError
        If ``fraction`` is not in (0, 1], or a density slice has negative values.
    """
    if not (0 < fraction <= 1.0):
        raise ValueError("fraction must be in (0, 1]")
    pct = fraction * 100

    basedir = Path('slices')
    outputdirs = {key: basedir / key for key in ('rho', 'V', 'all')}
    for outdir in outputdirs.values():
        outdir.mkdir(parents=True, exist_ok=True)

    # (horizontal, vertical) axis labels of the panels normal to x, y and z.
    axis_labels = ((r'$y$', r'$z$'), (r'$z$', r'$x$'), (r'$x$', r'$y$'))

    def mid_planes(fx, fy, fz, centre):
        """Return the x-, y- and z-mid-plane slices of fx, fy, fz oriented for imshow.

        Field arrays are indexed (x, y, z). imshow puts array rows on the vertical
        axis, so the slices are arranged as (z, y), (x, z) and (y, x) to match the
        panel axes (y, z), (z, x) and (x, y).
        """
        ix, iy, iz = centre
        return fx[ix, :, :].T, fy[:, iy, :], fz[:, :, iz].T

    def density_norm(values: np.ndarray) -> Normalize:
        """Norm spanning [min, max]; logarithmic when positive and max/min > 10."""
        if np.any(values < 0):
            raise ValueError("rho slice data contains negative values, cannot use LogNorm.")
        vmin = float(np.min(values))
        vmax = float(np.max(values))
        if vmin > 0 and vmax / vmin > 10:
            return LogNorm(vmin=vmin, vmax=vmax)
        return Normalize(vmin=vmin, vmax=vmax)

    def signed_norm(values: np.ndarray) -> Normalize:
        """Linear norm symmetric about zero, covering ``fraction`` of |values|."""
        vmax = float(np.percentile(np.abs(values), pct))
        return Normalize(vmin=-vmax, vmax=vmax)

    def style_axes(ax) -> None:
        """Inward ticks, square panel and thick frame."""
        ax.tick_params(direction='in', width=1.5, pad=7)
        ax.set_box_aspect(1)
        for spine in ax.spines.values():
            spine.set_linewidth(1.5)

    def add_top_colorbar(fig, ax, image, label: str) -> None:
        """Attach a horizontal colorbar just above ``ax``, exactly as wide as the panel."""
        cax = ax.inset_axes((0, 1.035, 1, 0.06))
        cbar = fig.colorbar(image, cax=cax, orientation='horizontal')
        cbar.set_label(label, labelpad=8)
        cbar.ax.xaxis.set_ticks_position('top')
        cbar.ax.xaxis.set_label_position('top')
        cbar.ax.tick_params(labelsize=12, pad=2)
        outline_spine = cbar.ax.spines.get("outline")
        if outline_spine is not None:
            outline_spine.set_linewidth(1.5)

    for index, time in enumerate(turbulence.times):

        rho = turbulence.rhos[index]
        V = turbulence.Vs[index]

        Nx, Ny, Nz = rho.data.shape
        Lx, Ly, Lz = rho.box
        centre = (Nx // 2, Ny // 2, Nz // 2)

        # imshow extents (left, right, bottom, top) of the (y, z), (z, x), (x, y) panels.
        extents = (
            (-Ly / 2, Ly / 2, -Lz / 2, Lz / 2),
            (-Lz / 2, Lz / 2, -Lx / 2, Lx / 2),
            (-Lx / 2, Lx / 2, -Ly / 2, Ly / 2),
        )

        # Each entry: (output dir key, oriented x/y/z mid-plane slices, colorbar labels).
        groups = (
            ('rho', mid_planes(rho.data, rho.data, rho.data, centre), (r'$\rho$',) * 3),
            ('V', mid_planes(V.x, V.y, V.z, centre), (r'$u_x$', r'$u_y$', r'$u_z$')),
        )

        for outkey, planes, cbarlabels in groups:

            # One norm for all three panels so they share a colour scale. It is
            # computed before the figure exists so a ValueError cannot leak a figure.
            values = np.concatenate([plane.ravel() for plane in planes])
            if outkey == 'rho':
                cmap, norm = 'Blues', density_norm(values)
            else:
                cmap, norm = 'RdBu', signed_norm(values)

            fig, axes = plt.subplots(1, 3, figsize=(16, 6), constrained_layout=True)
            for ax, plane, extent, (xlabel, ylabel), cbarlabel in zip(
                axes, planes, extents, axis_labels, cbarlabels
            ):
                image = ax.imshow(
                    plane, origin='lower', cmap=cmap, norm=norm,
                    extent=extent, aspect='auto',
                )
                ax.set_xlabel(xlabel)
                ax.set_ylabel(ylabel)
                style_axes(ax)
                add_top_colorbar(fig, ax, image, cbarlabel)

            fig.savefig(outputdirs[outkey] / f't={time:.2f}.pdf', bbox_inches='tight')
            plt.close(fig)

        # Overview figure: z mid-plane of density (left) and u_z (right).
        iz = centre[2]
        rho_z = rho.data[:, :, iz]
        vz_z = V.z[:, :, iz]
        overview = (
            (rho_z, r'$\rho$', 'Blues', density_norm(rho_z)),
            (vz_z, r'$u_z$', 'RdBu', signed_norm(vz_z)),
        )

        fig, axes = plt.subplots(1, 2, figsize=(12, 7), constrained_layout=True)
        for ax, (plane, cbarlabel, cmap, norm) in zip(axes, overview):
            image = ax.imshow(
                plane.T, origin='lower', cmap=cmap, norm=norm,
                extent=extents[2], aspect='auto',
            )
            ax.set_xlabel(r'$x$')
            ax.set_ylabel(r'$y$')
            style_axes(ax)
            add_top_colorbar(fig, ax, image, cbarlabel)

        fig.savefig(outputdirs['all'] / f't={time:.2f}.pdf', bbox_inches='tight')
        plt.close(fig)
487
+
488
+
489
def plotMRITurbulence(
    turbulence: Turbulence,
    fraction: float = 1.0,
) -> None:
    """Plot 2D mid-plane slices of MRI-driven (shearing-box) turbulence.

    For every snapshot and each of rho, Vx, Vy, Vz, Bx, By, Bz, write one figure
    ``slices/<name>/t=<time with 2 decimals>.pdf`` containing three slices:
    - upper left : yx plane (z=0), y horizontal, x increasing downwards
    - upper right: zx plane (y=0), z horizontal, x increasing downwards
    - lower left : yz plane (x=0), y horizontal, z vertical
    and one vertical colorbar on the right shared by all three panels. Every
    panel uses the 'RdBu' colormap with the range ``getrange`` computes from the
    full 3D data of that snapshot.

    The grid columns are sized Ly : Lz and the rows Lx : Lz, so with equal aspect
    the panels line up along their shared coordinates. The figure size (14 x 8 in)
    is tuned for the box ratio Lx : Ly : Lz = 2 : 4 : 1.

    Parameters
    ----------
    turbulence: Turbulence object providing ``times`` and the per-snapshot
        sequences ``rhos``, ``Vs`` and ``Bs``.
    fraction : float in (0, 1]. Passed to getrange; default 1.0 = full range.
    """
    basedir = Path('slices')
    basedir.mkdir(parents=True, exist_ok=True)

    cbarlabels = {
        'rho': r'$\rho$',
        'Vx' : r'$V_x$',
        'Vy' : r'$V_y$',
        'Vz' : r'$V_z$',
        'Bx' : r'$B_x$',
        'By' : r'$B_y$',
        'Bz' : r'$B_z$',
    }

    linewidth = 1.5  # width of axes frames, ticks and the colorbar outline
    pad = 7          # tick label padding

    for index, time in enumerate(turbulence.times):

        rho = turbulence.rhos[index]
        V = turbulence.Vs[index]
        B = turbulence.Bs[index]

        Nx, Ny, Nz = rho.data.shape
        Lx, Ly, Lz = rho.box

        variables: list[tuple[str, np.ndarray]] = [
            ('rho', rho.data),
            ('Vx', V.x), ('Vy', V.y), ('Vz', V.z),
            ('Bx', B.x), ('By', B.y), ('Bz', B.z),
        ]

        ranges = getrange(variables, fraction=fraction)

        for varname, vardata in variables:

            vardir = basedir / varname
            vardir.mkdir(parents=True, exist_ok=True)

            fig = plt.figure(figsize=(14, 8), constrained_layout=False)

            # Column widths Ly : Lz (plus a narrow colorbar column) and row heights
            # Lx : Lz. With equal aspect, ax1/ax2 then have the same height and
            # ax1/ax3 the same width. For a 2 : 4 : 1 box this gives widths
            # 4 : 1 : 0.2 and heights 2 : 1.
            gs = plt.GridSpec(
                2, 3, figure=fig,
                width_ratios=[Ly / Lz, 1, 0.2],
                height_ratios=[Lx / Lz, 1],
                left=0.1,
                right=0.9,
                top=0.9,
                bottom=0.1,
                wspace=0.08,
                hspace=0.08,
            )

            ax1 = fig.add_subplot(gs[0, 0])  # yx plane
            ax2 = fig.add_subplot(gs[0, 1])  # zx plane
            ax3 = fig.add_subplot(gs[1, 0])  # yz plane
            cax = fig.add_subplot(gs[:, 2])  # colorbar spanning both rows

            vmin, vmax = ranges[varname]
            cmap = 'RdBu'

            # Arrays are indexed (x, y, z) and imshow puts rows on the vertical axis.
            # In the top row, extent bottom > top makes x increase downwards.
            im1 = ax1.imshow(
                vardata[:, :, Nz // 2], origin='upper',
                cmap=cmap, vmin=vmin, vmax=vmax,
                extent=(-Ly / 2, Ly / 2, Lx / 2, -Lx / 2),
            )
            ax2.imshow(
                vardata[:, Ny // 2, :], origin='upper',
                cmap=cmap, vmin=vmin, vmax=vmax,
                extent=(-Lz / 2, Lz / 2, Lx / 2, -Lx / 2),
            )
            ax3.imshow(
                vardata[Nx // 2, :, :].T, origin='lower',
                cmap=cmap, vmin=vmin, vmax=vmax,
                extent=(-Ly / 2, Ly / 2, -Lz / 2, Lz / 2),
            )

            for ax in (ax1, ax2, ax3):
                ax.set_aspect('equal')
                ax.tick_params(direction='in', width=linewidth, pad=pad)
                for spine in ax.spines.values():
                    spine.set_linewidth(linewidth)

            # ax1 spans the same y range as ax3 below it and the same x range as ax2
            # beside it, so hide its bottom tick labels and ax2's left tick labels.
            ax1.tick_params(labelbottom=False)
            ax2.tick_params(labelleft=False)

            ax1.set_ylabel(r'$x$', labelpad=0)
            ax2.set_xlabel(r'$z$', labelpad=10)
            ax3.set_xlabel(r'$y$', labelpad=10)
            ax3.set_ylabel(r'$z$', labelpad=0)

            # All three images share vmin/vmax, so one colorbar describes them all.
            cbar = fig.colorbar(im1, cax=cax)
            cbar.set_label(cbarlabels[varname], labelpad=10)

            outline_spine = cbar.ax.spines.get("outline")
            if outline_spine is not None:
                outline_spine.set_linewidth(linewidth)  # width of the colorbar border

            fig.savefig(vardir / f't={time:.2f}.pdf', bbox_inches='tight')
            plt.close(fig)
627
+
628
def plot2dslice(
    turbulence: Turbulence,
    fraction: float = 1.0,
) -> None:
    """Plot 2D slices of a turbulence run, choosing the plotter from its type.

    ``turbulence.type`` selects the implementation:
    - 'MRI'               -> plotMRITurbulence
    - 'SSD', 'Bx' or 'Bz' -> plotForcedTurbulence
    - 'hydro'             -> plotHydroTurbulence

    Parameters
    ----------
    turbulence: Turbulence object
    fraction : float in (0, 1]. Proportion of data in color range; default 1.0 = full range.

    Raises
    ------
    ValueError
        If ``turbulence.type`` matches none of the types above.
    """
    kind = turbulence.type
    routes = (
        (('MRI',), plotMRITurbulence),
        (('SSD', 'Bx', 'Bz'), plotForcedTurbulence),
        (('hydro',), plotHydroTurbulence),
    )
    for accepted, plotter in routes:
        if kind in accepted:
            plotter(turbulence, fraction=fraction)
            return
    raise ValueError(f"Unsupported turbulence type: {kind}")