sdf-xarray 0.5.0__cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sdf_xarray/plotting.py ADDED
@@ -0,0 +1,567 @@
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+ from collections.abc import Callable
5
+ from dataclasses import dataclass
6
+ from types import MethodType
7
+ from typing import TYPE_CHECKING, Any
8
+
9
+ import numpy as np
10
+ import xarray as xr
11
+
12
+ if TYPE_CHECKING:
13
+ import matplotlib.pyplot as plt
14
+ from matplotlib.animation import FuncAnimation
15
+
16
+
17
@dataclass
class AnimationUnit:
    """Per-DataArray animation state consumed by ``FuncAnimation``.

    One instance is produced for every DataArray that takes part in an
    animation; the caller drives ``update`` once per frame index.
    """

    # Callback taking a frame index and redrawing that frame; its return value
    # is passed through unchanged.
    update: Callable[[int], object]
    # Total number of frames available for this DataArray.
    n_frames: int
21
+
22
+
23
def get_frame_title(
    data: xr.DataArray,
    frame: int,
    display_sdf_name: bool = False,
    title_custom: str | None = None,
    t: str = "time",
) -> str:
    """Generate the title for a single animation frame.

    The title has the form
    ``"[<title_custom>, ]<long_name> = <value>[ [<units>]][, NNNN.sdf]"``.

    Parameters
    ----------
    data
        DataArray containing the target data
    frame
        Frame number, i.e. the positional index along ``t``
    display_sdf_name
        Append the sdf file name (``frame`` zero-padded to four digits) to the title
    title_custom
        Custom text to prepend to the title
    t
        Name of the time (animated) coordinate

    Returns
    -------
    str
        The formatted title.
    """
    t_coord = data[t]

    # Custom text goes first, separated from the rest by a comma
    prefix = "" if title_custom is None else f"{title_custom}, "

    # Fall back to the coordinate name when there is no ``long_name`` attribute
    # (e.g. a dimension without an attached coordinate variable) instead of
    # failing with an AttributeError.
    t_label = t_coord.attrs.get("long_name", t)
    t_units = t_coord.attrs.get("units")
    units_suffix = f" [{t_units}]" if t_units else ""
    t_value = t_coord[frame].values
    title_t_axis = f"{t_label} = {t_value:.2e}{units_suffix}"

    # The sdf file name is taken to be the zero-padded frame index
    sdf_suffix = f", {frame:04d}.sdf" if display_sdf_name else ""
    return f"{prefix}{title_t_axis}{sdf_suffix}"
58
+
59
+
60
def calculate_window_boundaries(
    data: xr.DataArray,
    xlim: tuple[float, float] | None = None,
    x_axis_name: str = "X_Grid_mid",
    t: str = "time",
) -> np.ndarray:
    """Calculate the x boundaries of a moving window for every frame.

    For each frame the window spans the first to the last non-NaN value of a
    lineout along ``x_axis_name``. For data with extra dimensions, the lineout
    is taken at index 0 of every dimension that is neither ``t`` nor x. The
    window is padded on each side by half the spacing of the first two grid
    points (this presumably assumes a uniform grid). A frame with no non-NaN
    values falls back to the full extent of the grid.

    If the user specifies ``xlim``, it is used as the boundaries of the first
    frame and the window will move along accordingly.

    Parameters
    ----------
    data
        DataArray containing the target data
    xlim
        x limits of the first frame
    x_axis_name
        Name of coordinate to assign to the x-axis
    t
        Time coordinate

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_frames, 2)`` holding ``[left, right]`` per frame.
    """
    x_coord = data[x_axis_name]
    x_grid = x_coord.values
    x_dim = x_coord.dims[0]
    # Half a grid cell, so the window edges enclose the outermost cell centres
    x_half_cell = (x_grid[1] - x_grid[0]) / 2 if x_grid.size > 1 else 0.0

    # Stack one lineout per frame into an (n_frames, n_x) array. Selecting and
    # ordering by dimension *name* (rather than position) means this works
    # regardless of the DataArray's dimension order or number of dimensions.
    other_dims = {dim: 0 for dim in data.dims if dim not in (t, x_dim)}
    lineouts = data.isel(other_dims).transpose(t, x_dim).values
    valid = ~np.isnan(lineouts)

    # Index of the first / last non-NaN value in each lineout
    n_x = valid.shape[1]
    first = valid.argmax(axis=1)
    last = n_x - 1 - valid[:, ::-1].argmax(axis=1)

    window_boundaries = np.empty((lineouts.shape[0], 2))
    window_boundaries[:, 0] = x_grid[first] - x_half_cell
    window_boundaries[:, 1] = x_grid[last] + x_half_cell

    # ``argmax`` returns 0 for all-False rows, so frames with no valid data
    # would otherwise get a bogus one-cell window; use the full grid instead.
    empty_frames = ~valid.any(axis=1)
    window_boundaries[empty_frames] = (
        x_grid[0] - x_half_cell,
        x_grid[-1] + x_half_cell,
    )

    # User's choice for initial window edge supersedes the one calculated
    if xlim is not None:
        window_boundaries = (
            window_boundaries + np.asarray(xlim, dtype=float) - window_boundaries[0]
        )
    return window_boundaries
102
+
103
+
104
def compute_global_limits(
    data: xr.DataArray,
    min_percentile: float = 0,
    max_percentile: float = 100,
) -> tuple[float, float]:
    """Calculate the global minimum and maximum of the data, ignoring NaN values.

    User defined percentiles can remove extreme outliers.

    Parameters
    ----------
    data
        DataArray containing the target data
    min_percentile
        Minimum percentile of the data
    max_percentile
        Maximum percentile of the data

    Returns
    -------
    tuple[float, float]
        ``(global_min, global_max)`` at the requested percentiles.

    Raises
    ------
    ValueError
        If ``data`` contains no non-NaN values.
    """
    # Read ``data.values`` once instead of twice
    values = np.asarray(data.values)

    # Removes NaN values, needed for moving windows
    values_no_nan = values[~np.isnan(values)]
    if values_no_nan.size == 0:
        raise ValueError("Cannot compute plot limits: data contains only NaN values")

    # Both percentiles from a single call over the filtered data
    global_min, global_max = np.percentile(
        values_no_nan, [min_percentile, max_percentile]
    )
    return float(global_min), float(global_max)
129
+
130
+
131
+ def _set_axes_labels(ax: plt.Axes, axis_kwargs: dict) -> None:
132
+ """Set the labels for the x and y axes"""
133
+ if "xlabel" in axis_kwargs:
134
+ ax.set_xlabel(axis_kwargs["xlabel"])
135
+ if "ylabel" in axis_kwargs:
136
+ ax.set_ylabel(axis_kwargs["ylabel"])
137
+
138
+
139
def _setup_2d_plot(
    data: xr.DataArray,
    ax: plt.Axes,
    coord_names: list[str],
    kwargs: dict,
    axis_kwargs: dict,
    min_percentile: float,
    max_percentile: float,
    t: str,
) -> tuple[float, float]:
    """Draw the first frame of an animation of 2D data (``t`` plus one other dimension).

    ``kwargs`` is updated in place: ``x`` defaults to ``coord_names[0]``.
    Unless the user supplied ``ylim``, the y-axis is fixed to the global
    percentile limits of the whole dataset.

    Returns
    -------
    tuple[float, float]
        The global ``(min, max)`` computed from all frames.
    """
    if "x" not in kwargs:
        kwargs["x"] = coord_names[0]

    first_frame = data.isel({t: 0})
    first_frame.plot(ax=ax, **kwargs)

    lower, upper = compute_global_limits(data, min_percentile, max_percentile)

    _set_axes_labels(ax, axis_kwargs)

    # Keep the y range identical across frames unless the user pinned it
    user_pinned_ylim = "ylim" in kwargs
    if not user_pinned_ylim:
        ax.set_ylim(lower, upper)

    return lower, upper
163
+
164
+
165
def _setup_3d_plot(
    data: xr.DataArray,
    ax: plt.Axes,
    coord_names: list[str],
    kwargs: dict,
    kwargs_original: dict,
    axis_kwargs: dict,
    min_percentile: float,
    max_percentile: float,
    t: str,
) -> None:
    """Draw the initial frame of an animation of 3D data (``t`` plus two dimensions).

    ``kwargs`` is updated in place so that every later frame reuses the same
    settings:

    - ``norm`` defaults to ``plt.Normalize`` over the global percentile limits;
    - ``add_colorbar`` is forced to ``False`` so redrawn frames add no colorbar;
    - ``x``/``y`` default to ``coord_names[0]``/``coord_names[1]``;
    - ``cmap`` is set to the colormap of the initially drawn frame.

    A single colorbar is added here unless the caller passed
    ``add_colorbar=False`` (read from ``kwargs_original``).
    """
    import matplotlib.pyplot as plt  # noqa: PLC0415

    if "norm" not in kwargs:
        global_min, global_max = compute_global_limits(
            data, min_percentile, max_percentile
        )
        kwargs["norm"] = plt.Normalize(vmin=global_min, vmax=global_max)

    kwargs["add_colorbar"] = False
    kwargs.setdefault("x", coord_names[0])
    kwargs.setdefault("y", coord_names[1])

    # Draw the frame containing the global (non-NaN) minimum first; its colormap
    # is then stored and reused for every frame (presumably so the default
    # colormap choice reflects the full data range). ``nanargmin`` skips the NaN
    # padding of moving windows, where plain ``argmin`` would return the first
    # NaN. The frame index is taken along the ``t`` axis rather than assuming
    # time is the first dimension.
    try:
        flat_index = np.nanargmin(data.values)
    except ValueError:  # all-NaN data: no preferred frame, use the first one
        argmin_time = 0
    else:
        argmin_time = np.unravel_index(flat_index, data.shape)[data.get_axis_num(t)]
    plot = data.isel({t: argmin_time}).plot(ax=ax, **kwargs)
    kwargs["cmap"] = plot.cmap

    _set_axes_labels(ax, axis_kwargs)

    if kwargs_original.get("add_colorbar", True):
        # Avoid labels such as "None [None]" when attributes are missing
        long_name = data.attrs.get("long_name", data.name)
        units = data.attrs.get("units")
        label = "" if long_name is None else str(long_name)
        if units:
            label = f"{label} [{units}]"
        fig = plot.get_figure()
        fig.colorbar(plot, ax=ax, label=label)
200
+
201
+
202
def _generate_animation(
    data: xr.DataArray,
    clear_axes: bool = False,
    min_percentile: float = 0,
    max_percentile: float = 100,
    title: str | None = None,
    display_sdf_name: bool = False,
    move_window: bool = False,
    t: str | None = None,
    ax: plt.Axes | None = None,
    kwargs: dict | None = None,
) -> AnimationUnit:
    """
    Internal function for generating the plotting logic required for animations.

    Draws the initial frame of ``data`` on ``ax``. Returns an
    :class:`AnimationUnit` whose ``update(frame)`` callback redraws the frame
    at positional index ``frame`` along ``t``. The callback is meant to drive
    ``matplotlib.animation.FuncAnimation``.

    Parameters
    ----------
    data
        DataArray containing the target data
    clear_axes
        Decide whether to run ``ax.clear()`` in every update
    min_percentile
        Minimum percentile of the data
    max_percentile
        Maximum percentile of the data
    title
        Custom title to add to the plot
    display_sdf_name
        Display the sdf file name in the animation title
    move_window
        Update the ``xlim`` to be only values that are not NaNs at each time interval
    t
        Dimension for t axis (the dimension which will be animated over).
        If ``None``, use ``"time"`` when it is a dimension of ``data``,
        otherwise the last dimension.
    ax
        Matplotlib axes on which to plot. If ``None``, the current axes
        (``plt.gca()``) are used.
    kwargs
        Keyword arguments to be passed to matplotlib. ``xlabel`` and ``ylabel``
        are applied to the axes; everything else is forwarded to
        ``DataArray.plot``. The dict is copied, so the caller's dict is left
        unmodified.

    Returns
    -------
    AnimationUnit
        The per-frame update callback and the number of frames along ``t``.

    Raises
    ------
    ValueError
        If ``t`` is not a dimension of ``data``.
    """
    # Work on a copy: the setup helpers and ``update`` fill in defaults
    # (``x``, ``norm``, ``cmap``, ``xlim``, ...) in place.
    kwargs = {} if kwargs is None else dict(kwargs)
    kwargs_original = dict(kwargs)

    if ax is None:
        import matplotlib.pyplot as plt  # noqa: PLC0415

        ax = plt.gca()

    # Axis labels are applied to the axes after every redraw rather than being
    # forwarded to ``DataArray.plot``.
    axis_kwargs = {}
    for key in ("xlabel", "ylabel"):
        if key in kwargs:
            axis_kwargs[key] = kwargs.pop(key)

    # Choose the animated dimension. Prefer "time"; otherwise fall back to the
    # last dimension, so that the leading dimensions (typically x, y) stay on
    # their respective plot axes.
    coord_names = list(data.dims)
    if t is None:
        t = "time" if "time" in coord_names else coord_names[-1]
    if t not in coord_names:
        raise ValueError(
            f"Cannot animate over {t!r}: it is not a dimension of the data "
            f"(dimensions: {coord_names})"
        )
    coord_names.remove(t)

    n_frames = data[t].size

    global_min = global_max = None
    if data.ndim == 2:
        global_min, global_max = _setup_2d_plot(
            data=data,
            ax=ax,
            coord_names=coord_names,
            kwargs=kwargs,
            axis_kwargs=axis_kwargs,
            min_percentile=min_percentile,
            max_percentile=max_percentile,
            t=t,
        )
    elif data.ndim == 3:
        _setup_3d_plot(
            data=data,
            ax=ax,
            coord_names=coord_names,
            kwargs=kwargs,
            kwargs_original=kwargs_original,
            axis_kwargs=axis_kwargs,
            min_percentile=min_percentile,
            max_percentile=max_percentile,
            t=t,
        )

    ax.set_title(get_frame_title(data, 0, display_sdf_name, title, t))

    window_boundaries = None
    if move_window:
        # Pass the animated dimension explicitly; relying on the default
        # "time" breaks when animating over any other dimension.
        window_boundaries = calculate_window_boundaries(
            data, kwargs.get("xlim"), kwargs["x"], t=t
        )

    def update(frame):
        if clear_axes:
            ax.clear()
        # Set the xlim for each frame in the case of a moving window
        if move_window:
            kwargs["xlim"] = window_boundaries[frame]

        plot = data.isel({t: frame}).plot(ax=ax, **kwargs)
        ax.set_title(get_frame_title(data, frame, display_sdf_name, title, t))
        _set_axes_labels(ax, axis_kwargs)

        # Keep the y range fixed across frames for line plots
        if data.ndim == 2 and "ylim" not in kwargs and global_min is not None:
            ax.set_ylim(global_min, global_max)

        return plot

    return AnimationUnit(
        update=update,
        n_frames=n_frames,
    )
322
+
323
+
324
def animate(
    data: xr.DataArray,
    fps: float = 10,
    min_percentile: float = 0,
    max_percentile: float = 100,
    title: str | None = None,
    display_sdf_name: bool = False,
    move_window: bool = False,
    t: str | None = None,
    ax: plt.Axes | None = None,
    **kwargs,
) -> FuncAnimation:
    """
    Generate an animation using an `xarray.DataArray`. The intended use
    of this function is via `sdf_xarray.plotting.EpochAccessor.animate`.

    Parameters
    ----------
    data
        DataArray containing the target data
    fps
        Frames per second for the animation; must be positive
    min_percentile
        Minimum percentile of the data
    max_percentile
        Maximum percentile of the data
    title
        Custom title to add to the plot
    display_sdf_name
        Display the sdf file name in the animation title
    move_window
        Update the ``xlim`` to be only values that are not NaNs at each time interval
    t
        Coordinate for t axis (the coordinate which will be animated over).
        If ``None``, use ``"time"`` when it is a dimension of ``data``,
        otherwise the last dimension.
    ax
        Matplotlib axes on which to plot. If ``None``, a new figure is created.
    kwargs
        Keyword arguments to be passed to matplotlib

    Returns
    -------
    matplotlib.animation.FuncAnimation
        Looping animation with one frame per index along ``t``.

    Raises
    ------
    ValueError
        If ``fps`` is not positive.

    Examples
    --------
    >>> anim = animate(ds["Derived_Number_Density_Electron"])
    >>> anim.save("animation.gif")
    """
    import matplotlib.pyplot as plt  # noqa: PLC0415
    from matplotlib.animation import FuncAnimation  # noqa: PLC0415

    # The frame interval is 1000 / fps ms: zero would divide by zero and a
    # negative value would give a meaningless negative interval.
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps!r}")

    # Create plot if no ax is provided
    if ax is None:
        fig, ax = plt.subplots()
        # Prevents figure from prematurely displaying in Jupyter notebook
        plt.close(fig)

    animation = _generate_animation(
        data,
        clear_axes=True,
        min_percentile=min_percentile,
        max_percentile=max_percentile,
        title=title,
        display_sdf_name=display_sdf_name,
        move_window=move_window,
        t=t,
        ax=ax,
        kwargs=kwargs,
    )

    return FuncAnimation(
        ax.get_figure(),
        animation.update,
        frames=range(animation.n_frames),
        interval=1000 / fps,
        repeat=True,
    )
397
+ )
398
+
399
+
400
def animate_multiple(
    *datasets: xr.DataArray,
    datasets_kwargs: list[dict[str, Any]] | None = None,
    fps: float = 10,
    min_percentile: float = 0,
    max_percentile: float = 100,
    title: str | None = None,
    display_sdf_name: bool = False,
    move_window: bool = False,
    t: str | None = None,
    ax: plt.Axes | None = None,
    **common_kwargs,
) -> FuncAnimation:
    """
    Generate an animation using multiple `xarray.DataArray`. The intended use
    of this function is via `sdf_xarray.dataset_accessor.EpochAccessor.animate_multiple`.

    Parameters
    ----------
    datasets
        `xarray.DataArray` objects containing the data to be animated
    datasets_kwargs
        A list of dictionaries, following the same order as ``datasets``, containing
        per-dataset matplotlib keyword arguments. The list does not need to be the same
        length as ``datasets``; missing entries are treated as empty dictionaries and
        extra entries are ignored. The list passed in is not modified.
    fps
        Frames per second for the animation; must be positive
    min_percentile
        Minimum percentile of the data
    max_percentile
        Maximum percentile of the data
    title
        Custom title to add to the plot
    display_sdf_name
        Display the sdf file name in the animation title
    move_window
        Update the ``xlim`` to be only values that are not NaNs at each time interval
    t
        Coordinate for t axis (the coordinate which will be animated over). If ``None``,
        use ``"time"`` when it is a dimension of the data, otherwise the last dimension.
    ax
        Matplotlib axes on which to plot. If ``None``, a new figure is created.
    common_kwargs
        Matplotlib keyword arguments applied to all datasets. These are overridden by
        per-dataset entries in ``datasets_kwargs``

    Returns
    -------
    matplotlib.animation.FuncAnimation
        Looping animation truncated to the shortest dataset's frame count.

    Raises
    ------
    ValueError
        If no datasets are given or ``fps`` is not positive.

    Notes
    -----
    For line (2D) data without an explicit ``ylim``, each dataset's update sets
    the y-limits to that dataset's own global range, so the last dataset drawn
    determines the limits. Pass ``ylim`` to cover all datasets.

    Examples
    --------
    >>> anim = animate_multiple(
    ...     ds["Derived_Number_Density_Electron"],
    ...     ds["Derived_Number_Density_Ion"],
    ...     datasets_kwargs=[{"label": "Electron"}, {"label": "Ion"}],
    ...     ylim=(0e27, 4e27),
    ...     display_sdf_name=True,
    ...     ylabel="Derived Number Density [1/m$^3$]",
    ... )
    >>> anim.save("animation.gif")
    """
    import matplotlib.pyplot as plt  # noqa: PLC0415
    from matplotlib.animation import FuncAnimation  # noqa: PLC0415

    if not datasets:
        raise ValueError("At least one dataset must be provided")
    # The frame interval is 1000 / fps ms, so fps must be strictly positive
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps!r}")

    # Create plot if no ax is provided
    if ax is None:
        fig, ax = plt.subplots()
        # Prevents figure from prematurely displaying in Jupyter notebook
        plt.close(fig)

    # Pad with empty dicts in a *new* list so the caller's list is not mutated;
    # entries beyond the number of datasets are dropped by ``zip`` below.
    per_dataset_kwargs = list(datasets_kwargs or [])
    per_dataset_kwargs.extend(
        {} for _ in range(len(datasets) - len(per_dataset_kwargs))
    )

    animations: list[AnimationUnit] = [
        _generate_animation(
            da,
            ax=ax,
            min_percentile=min_percentile,
            max_percentile=max_percentile,
            title=title,
            display_sdf_name=display_sdf_name,
            move_window=move_window,
            t=t,
            # Per-dataset kwargs override common matplotlib kwargs
            kwargs={**common_kwargs, **kw},
        )
        for da, kw in zip(datasets, per_dataset_kwargs)
    ]

    lengths = [anim.n_frames for anim in animations]
    n_frames = min(lengths)

    if len(set(lengths)) > 1:
        warnings.warn(
            "Datasets have different frame counts; truncating to the shortest",
            stacklevel=2,
        )

    # Render the legend if a label (per-dataset or common) exists for any 2D dataset
    show_legend = any(
        da.ndim == 2 and ("label" in kw or "label" in common_kwargs)
        for da, kw in zip(datasets, per_dataset_kwargs)
    )

    def update(frame):
        # Clear once per frame, then let every dataset draw on the shared axes
        ax.clear()
        for anim in animations:
            anim.update(frame)
        if show_legend:
            ax.legend(loc="upper right")

    return FuncAnimation(
        ax.get_figure(),
        update,
        frames=range(n_frames),
        interval=1000 / fps,
        repeat=True,
    )
524
+
525
+
526
def show(anim):
    """Display a ``FuncAnimation`` inline in a Jupyter notebook.

    Parameters
    ----------
    anim
        `matplotlib.animation.FuncAnimation` to render

    Returns
    -------
    IPython.display.HTML
        HTML object wrapping the output of ``anim.to_jshtml()``.
    """
    from IPython.display import HTML  # noqa: PLC0415

    jshtml_markup = anim.to_jshtml()
    return HTML(jshtml_markup)
537
+
538
+
539
@xr.register_dataarray_accessor("epoch")
class EpochAccessor:
    """Plotting helpers available on every DataArray as ``DataArray.epoch``."""

    def __init__(self, xarray_obj):
        # The DataArray this accessor is attached to
        self._obj = xarray_obj

    def animate(self, *args, **kwargs) -> FuncAnimation:
        """Generate animations of Epoch data.

        Calls :func:`animate` on the wrapped DataArray. The returned animation
        also gets a ``show()`` method that renders it as HTML in a Jupyter
        notebook.

        Parameters
        ----------
        args
            Positional arguments passed to :func:`animate`.
        kwargs
            Keyword arguments passed to :func:`animate`.

        Returns
        -------
        matplotlib.animation.FuncAnimation
            The animation, with an added ``show`` method.

        Examples
        --------
        >>> anim = ds["Electric_Field_Ey"].epoch.animate()
        >>> anim.save("animation.gif")
        >>> # Or in a jupyter notebook:
        >>> anim.show()
        """

        # Bind ``show`` to this animation instance so that ``anim.show()``
        # displays the animation in a jupyter notebook
        anim = animate(self._obj, *args, **kwargs)
        anim.show = MethodType(show, anim)

        return anim