sdf-xarray 0.2.0__cp312-cp312-win_amd64.whl → 0.5.0__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sdf_xarray/plotting.py CHANGED
@@ -1,6 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import TYPE_CHECKING
3
+ import warnings
4
+ from collections.abc import Callable
5
+ from dataclasses import dataclass
6
+ from types import MethodType
7
+ from typing import TYPE_CHECKING, Any
4
8
 
5
9
  import numpy as np
6
10
  import xarray as xr
@@ -10,154 +14,528 @@ if TYPE_CHECKING:
10
14
  from matplotlib.animation import FuncAnimation
11
15
 
12
16
 
13
- def get_frame_title(data: xr.DataArray, frame: int, display_sdf_name: bool) -> str:
14
- """Generate the title for a frame"""
15
- sdf_name = f", {frame:04d}.sdf" if display_sdf_name else ""
16
- time = data["time"][frame].to_numpy()
17
- return f"t = {time:.2e}s{sdf_name}"
17
+ @dataclass
18
+ class AnimationUnit:
19
+ update: Callable[[int], object]
20
+ n_frames: int
18
21
 
19
22
 
20
- def calculate_window_velocity_and_edges(
21
- data: xr.DataArray, x_axis_coord: str
22
- ) -> tuple[float, tuple[float, float], np.ndarray]:
23
- """Calculate the moving window's velocity and initial edges.
23
+ def get_frame_title(
24
+ data: xr.DataArray,
25
+ frame: int,
26
+ display_sdf_name: bool = False,
27
+ title_custom: str | None = None,
28
+ t: str = "time",
29
+ ) -> str:
30
+ """Generate the title for a frame
24
31
 
25
- 1. Finds a lineout of the target atribute in the x coordinate of the first frame
26
- 2. Removes the NaN values to isolate the simulation window
27
- 3. Produces the index size of the window, indexed at zero
28
- 4. Uses distance moved and final time of the simulation to calculate velocity and initial xlims
32
+ Parameters
33
+ ----------
34
+ data
35
+ DataArray containing the target data
36
+ frame
37
+ Frame number
38
+ display_sdf_name
39
+ Display the sdf file name in the animation title
40
+ title_custom
41
+ Custom title to add to the plot
42
+ t
43
+ Time coordinate
29
44
  """
30
- time_since_start = data["time"].values - data["time"].values[0]
31
- initial_window_edge = (0, 0)
32
- target_lineout = data.values[0, :, 0]
33
- target_lineout_window = target_lineout[~np.isnan(target_lineout)]
34
- x_grid = data[x_axis_coord].values
35
- window_size_index = target_lineout_window.size - 1
36
45
 
37
- velocity_window = (x_grid[-1] - x_grid[window_size_index]) / time_since_start[-1]
38
- initial_window_edge = (x_grid[0], x_grid[window_size_index])
39
- return velocity_window, initial_window_edge, time_since_start
46
+ # Adds custom text to the start of the title, if specified
47
+ title_custom = "" if title_custom is None else f"{title_custom}, "
48
+ # Adds the time axis and associated units to the title
49
+ t_axis_value = data[t][frame].values
50
+
51
+ t_axis_units = data[t].attrs.get("units", False)
52
+ t_axis_units_formatted = f" [{t_axis_units}]" if t_axis_units else ""
53
+ title_t_axis = f"{data[t].long_name} = {t_axis_value:.2e}{t_axis_units_formatted}"
54
+
55
+ # Adds sdf name to the title, if specifed
56
+ title_sdf = f", {frame:04d}.sdf" if display_sdf_name else ""
57
+ return f"{title_custom}{title_t_axis}{title_sdf}"
58
+
40
59
 
60
+ def calculate_window_boundaries(
61
+ data: xr.DataArray,
62
+ xlim: tuple[float, float] | None = None,
63
+ x_axis_name: str = "X_Grid_mid",
64
+ t: str = "time",
65
+ ) -> np.ndarray:
66
+ """Calculate the boundaries a moving window frame. If the user specifies xlim, this will
67
+ be used as the initial boundaries and the window will move along acordingly.
68
+
69
+ Parameters
70
+ ----------
71
+ data
72
+ DataArray containing the target data
73
+ xlim
74
+ x limits
75
+ x_axis_name
76
+ Name of coordinate to assign to the x-axis
77
+ t
78
+ Time coordinate
79
+ """
80
+ x_grid = data[x_axis_name].values
81
+ x_half_cell = (x_grid[1] - x_grid[0]) / 2
82
+ n_frames = data[t].size
83
+
84
+ # Find the window boundaries by finding the first and last non-NaN values in the 0th lineout
85
+ # along the x-axis.
86
+ window_boundaries = np.zeros((n_frames, 2))
87
+ for i in range(n_frames):
88
+ # Check if data is 1D
89
+ if data.ndim == 2:
90
+ target_lineout = data[i].values
91
+ # Check if data is 2D
92
+ if data.ndim == 3:
93
+ target_lineout = data[i, :, 0].values
94
+ x_grid_non_nan = x_grid[~np.isnan(target_lineout)]
95
+ window_boundaries[i, 0] = x_grid_non_nan[0] - x_half_cell
96
+ window_boundaries[i, 1] = x_grid_non_nan[-1] + x_half_cell
97
+
98
+ # User's choice for initial window edge supercedes the one calculated
99
+ if xlim is not None:
100
+ window_boundaries = window_boundaries + xlim - window_boundaries[0]
101
+ return window_boundaries
102
+
103
+
104
+ def compute_global_limits(
105
+ data: xr.DataArray,
106
+ min_percentile: float = 0,
107
+ max_percentile: float = 100,
108
+ ) -> tuple[float, float]:
109
+ """Remove all NaN values from the target data to calculate the global minimum and maximum of the data.
110
+ User defined percentiles can remove extreme outliers.
41
111
 
42
- def compute_global_limits(data: xr.DataArray) -> tuple[float, float]:
43
- """Remove all NaN values from the target data to calculate the 1st and 99th percentiles,
44
- excluding extreme outliers.
112
+ Parameters
113
+ ----------
114
+ data
115
+ DataArray containing the target data
116
+ min_percentile
117
+ Minimum percentile of the data
118
+ max_percentile
119
+ Maximum percentile of the data
45
120
  """
121
+
122
+ # Removes NaN values, needed for moving windows
46
123
  values_no_nan = data.values[~np.isnan(data.values)]
47
- global_min = np.percentile(values_no_nan, 1)
48
- global_max = np.percentile(values_no_nan, 99)
124
+
125
+ # Finds the global minimum and maximum of the plot, based on the percentile of the data
126
+ global_min = np.percentile(values_no_nan, min_percentile)
127
+ global_max = np.percentile(values_no_nan, max_percentile)
49
128
  return global_min, global_max
50
129
 
51
130
 
52
- def is_1d(data: xr.DataArray) -> bool:
53
- """Check if the data is 1D."""
54
- return len(data.shape) == 2
131
+ def _set_axes_labels(ax: plt.Axes, axis_kwargs: dict) -> None:
132
+ """Set the labels for the x and y axes"""
133
+ if "xlabel" in axis_kwargs:
134
+ ax.set_xlabel(axis_kwargs["xlabel"])
135
+ if "ylabel" in axis_kwargs:
136
+ ax.set_ylabel(axis_kwargs["ylabel"])
55
137
 
56
138
 
57
- def is_2d(data: xr.DataArray) -> bool:
58
- """Check if the data is 2D or 3D."""
59
- return len(data.shape) == 3
139
+ def _setup_2d_plot(
140
+ data: xr.DataArray,
141
+ ax: plt.Axes,
142
+ coord_names: list[str],
143
+ kwargs: dict,
144
+ axis_kwargs: dict,
145
+ min_percentile: float,
146
+ max_percentile: float,
147
+ t: str,
148
+ ) -> tuple[float, float]:
149
+ """Setup 2D plot initialization."""
150
+
151
+ kwargs.setdefault("x", coord_names[0])
152
+
153
+ data.isel({t: 0}).plot(ax=ax, **kwargs)
154
+
155
+ global_min, global_max = compute_global_limits(data, min_percentile, max_percentile)
156
+
157
+ _set_axes_labels(ax, axis_kwargs)
158
+
159
+ if "ylim" not in kwargs:
160
+ ax.set_ylim(global_min, global_max)
161
+
162
+ return global_min, global_max
60
163
 
61
164
 
62
- def generate_animation(
165
+ def _setup_3d_plot(
63
166
  data: xr.DataArray,
167
+ ax: plt.Axes,
168
+ coord_names: list[str],
169
+ kwargs: dict,
170
+ kwargs_original: dict,
171
+ axis_kwargs: dict,
172
+ min_percentile: float,
173
+ max_percentile: float,
174
+ t: str,
175
+ ) -> None:
176
+ """Setup 3D plot initialization."""
177
+ import matplotlib.pyplot as plt # noqa: PLC0415
178
+
179
+ if "norm" not in kwargs:
180
+ global_min, global_max = compute_global_limits(
181
+ data, min_percentile, max_percentile
182
+ )
183
+ kwargs["norm"] = plt.Normalize(vmin=global_min, vmax=global_max)
184
+
185
+ kwargs["add_colorbar"] = False
186
+ kwargs.setdefault("x", coord_names[0])
187
+ kwargs.setdefault("y", coord_names[1])
188
+
189
+ argmin_time = np.unravel_index(np.argmin(data.values), data.shape)[0]
190
+ plot = data.isel({t: argmin_time}).plot(ax=ax, **kwargs)
191
+ kwargs["cmap"] = plot.cmap
192
+
193
+ _set_axes_labels(ax, axis_kwargs)
194
+
195
+ if kwargs_original.get("add_colorbar", True):
196
+ long_name = data.attrs.get("long_name")
197
+ units = data.attrs.get("units")
198
+ fig = plot.get_figure()
199
+ fig.colorbar(plot, ax=ax, label=f"{long_name} [{units}]")
200
+
201
+
202
+ def _generate_animation(
203
+ data: xr.DataArray,
204
+ clear_axes: bool = False,
205
+ min_percentile: float = 0,
206
+ max_percentile: float = 100,
207
+ title: str | None = None,
64
208
  display_sdf_name: bool = False,
65
- fps: int = 10,
66
209
  move_window: bool = False,
210
+ t: str | None = None,
211
+ ax: plt.Axes | None = None,
212
+ kwargs: dict | None = None,
213
+ ) -> AnimationUnit:
214
+ """
215
+ Internal function for generating the plotting logic required for animations.
216
+
217
+ Parameters
218
+ ---------
219
+ data
220
+ DataArray containing the target data
221
+ clear_axes
222
+ Decide whether to run ``ax.clear()`` in every update
223
+ min_percentile
224
+ Minimum percentile of the data
225
+ max_percentile
226
+ Maximum percentile of the data
227
+ title
228
+ Custom title to add to the plot
229
+ display_sdf_name
230
+ Display the sdf file name in the animation title
231
+ move_window
232
+ Update the ``xlim`` to be only values that are not NaNs at each time interval
233
+ t
234
+ Coordinate for t axis (the coordinate which will be animated over).
235
+ If ``None``, use ``data.dims[0]``
236
+ ax
237
+ Matplotlib axes on which to plot
238
+ kwargs
239
+ Keyword arguments to be passed to matplotlib
240
+
241
+ Examples
242
+ --------
243
+ >>> anim = animate(ds["Derived_Number_Density_Electron"])
244
+ >>> anim.save("animation.gif")
245
+ """
246
+
247
+ if kwargs is None:
248
+ kwargs = {}
249
+ kwargs_original = kwargs.copy()
250
+
251
+ axis_kwargs = {}
252
+ for key in ("xlabel", "ylabel"):
253
+ if key in kwargs:
254
+ axis_kwargs[key] = kwargs.pop(key)
255
+
256
+ # Sets the animation coordinate (t) for iteration. If time is in the coords
257
+ # then it will set time to be t. If it is not it will fallback to the last
258
+ # coordinate passed in. By default coordinates are passed in from xarray in
259
+ # the form x, y, z so in order to preserve the x and y being on their
260
+ # respective axes we animate over the final coordinate that is passed in
261
+ # which in this example is z
262
+ coord_names = list(data.dims)
263
+ if t is None:
264
+ t = "time" if "time" in coord_names else coord_names[-1]
265
+ coord_names.remove(t)
266
+
267
+ N_frames = data[t].size
268
+
269
+ global_min = global_max = None
270
+ if data.ndim == 2:
271
+ global_min, global_max = _setup_2d_plot(
272
+ data=data,
273
+ ax=ax,
274
+ coord_names=coord_names,
275
+ kwargs=kwargs,
276
+ axis_kwargs=axis_kwargs,
277
+ min_percentile=min_percentile,
278
+ max_percentile=max_percentile,
279
+ t=t,
280
+ )
281
+ elif data.ndim == 3:
282
+ _setup_3d_plot(
283
+ data=data,
284
+ ax=ax,
285
+ coord_names=coord_names,
286
+ kwargs=kwargs,
287
+ kwargs_original=kwargs_original,
288
+ axis_kwargs=axis_kwargs,
289
+ min_percentile=min_percentile,
290
+ max_percentile=max_percentile,
291
+ t=t,
292
+ )
293
+
294
+ ax.set_title(get_frame_title(data, 0, display_sdf_name, title, t))
295
+
296
+ window_boundaries = None
297
+ if move_window:
298
+ window_boundaries = calculate_window_boundaries(
299
+ data, kwargs.get("xlim"), kwargs["x"]
300
+ )
301
+
302
+ def update(frame):
303
+ if clear_axes:
304
+ ax.clear()
305
+ # Set the xlim for each frame in the case of a moving window
306
+ if move_window:
307
+ kwargs["xlim"] = window_boundaries[frame]
308
+
309
+ plot = data.isel({t: frame}).plot(ax=ax, **kwargs)
310
+ ax.set_title(get_frame_title(data, frame, display_sdf_name, title, t))
311
+ _set_axes_labels(ax, axis_kwargs)
312
+
313
+ if data.ndim == 2 and "ylim" not in kwargs and global_min is not None:
314
+ ax.set_ylim(global_min, global_max)
315
+
316
+ return plot
317
+
318
+ return AnimationUnit(
319
+ update=update,
320
+ n_frames=N_frames,
321
+ )
322
+
323
+
324
+ def animate(
325
+ data: xr.DataArray,
326
+ fps: float = 10,
327
+ min_percentile: float = 0,
328
+ max_percentile: float = 100,
329
+ title: str | None = None,
330
+ display_sdf_name: bool = False,
331
+ move_window: bool = False,
332
+ t: str | None = None,
67
333
  ax: plt.Axes | None = None,
68
334
  **kwargs,
69
335
  ) -> FuncAnimation:
70
- """Generate an animation
336
+ """
337
+ Generate an animation using an `xarray.DataArray`. The intended use
338
+ of this function is via `sdf_xarray.plotting.EpochAccessor.animate`.
71
339
 
72
340
  Parameters
73
341
  ---------
74
- dataset
75
- The dataset containing the simulation data
76
- target_attribute
77
- The attribute to plot for each timestep
342
+ data
343
+ DataArray containing the target data
344
+ fps
345
+ Frames per second for the animation
346
+ min_percentile
347
+ Minimum percentile of the data
348
+ max_percentile
349
+ Maximum percentile of the data
350
+ title
351
+ Custom title to add to the plot
78
352
  display_sdf_name
79
353
  Display the sdf file name in the animation title
80
- fps
81
- Frames per second for the animation (default: 10)
82
354
  move_window
83
- If the simulation has a moving window, the animation will move along
84
- with it (default: False)
355
+ Update the ``xlim`` to be only values that are not NaNs at each time interval
356
+ t
357
+ Coordinate for t axis (the coordinate which will be animated over).
358
+ If ``None``, use ``data.dims[0]``
85
359
  ax
86
- Matplotlib axes on which to plot.
360
+ Matplotlib axes on which to plot
87
361
  kwargs
88
- Keyword arguments to be passed to matplotlib.
362
+ Keyword arguments to be passed to matplotlib
89
363
 
90
364
  Examples
91
365
  --------
92
- >>> generate_animation(dataset["Derived_Number_Density_Electron"])
366
+ >>> anim = animate(ds["Derived_Number_Density_Electron"])
367
+ >>> anim.save("animation.gif")
93
368
  """
94
- import matplotlib.pyplot as plt
95
- from matplotlib.animation import FuncAnimation
369
+ import matplotlib.pyplot as plt # noqa: PLC0415
370
+ from matplotlib.animation import FuncAnimation # noqa: PLC0415
96
371
 
372
+ # Create plot if no ax is provided
97
373
  if ax is None:
98
- _, ax = plt.subplots()
374
+ fig, ax = plt.subplots()
375
+ # Prevents figure from prematurely displaying in Jupyter notebook
376
+ plt.close(fig)
377
+
378
+ animation = _generate_animation(
379
+ data,
380
+ clear_axes=True,
381
+ min_percentile=min_percentile,
382
+ max_percentile=max_percentile,
383
+ title=title,
384
+ display_sdf_name=display_sdf_name,
385
+ move_window=move_window,
386
+ t=t,
387
+ ax=ax,
388
+ kwargs=kwargs,
389
+ )
99
390
 
100
- N_frames = data["time"].size
101
- global_min, global_max = compute_global_limits(data)
391
+ return FuncAnimation(
392
+ ax.get_figure(),
393
+ animation.update,
394
+ frames=range(animation.n_frames),
395
+ interval=1000 / fps,
396
+ repeat=True,
397
+ )
102
398
 
103
- if is_2d(data):
104
- kwargs["norm"] = plt.Normalize(vmin=global_min, vmax=global_max)
105
- kwargs["add_colorbar"] = False
106
- # Set default x and y coordinates for 2D data if not provided
107
- kwargs.setdefault("x", "X_Grid_mid")
108
- kwargs.setdefault("y", "Y_Grid_mid")
109
399
 
110
- # Initialize the plot with the first timestep
111
- plot = data.isel(time=0).plot(ax=ax, **kwargs)
112
- ax.set_title(get_frame_title(data, 0, display_sdf_name))
400
+ def animate_multiple(
401
+ *datasets: xr.DataArray,
402
+ datasets_kwargs: list[dict[str, Any]] | None = None,
403
+ fps: float = 10,
404
+ min_percentile: float = 0,
405
+ max_percentile: float = 100,
406
+ title: str | None = None,
407
+ display_sdf_name: bool = False,
408
+ move_window: bool = False,
409
+ t: str | None = None,
410
+ ax: plt.Axes | None = None,
411
+ **common_kwargs,
412
+ ) -> FuncAnimation:
413
+ """
414
+ Generate an animation using multiple `xarray.DataArray`. The intended use
415
+ of this function is via `sdf_xarray.dataset_accessor.EpochAccessor.animate_multiple`.
113
416
 
114
- # Add colorbar
115
- long_name = data.attrs.get("long_name")
116
- units = data.attrs.get("units")
117
- plt.colorbar(plot, ax=ax, label=f"{long_name} [${units}$]")
417
+ Parameters
418
+ ---------
419
+ datasets
420
+ `xarray.DataArray` objects containing the data to be animated
421
+ datasets_kwargs
422
+ A list of dictionaries, following the same order as ``datasets``, containing
423
+ per-dataset matplotlib keyword arguments. The list does not need to be the same
424
+ length as ``datasets``; missing entries are initialised as empty dictionaries
425
+ fps
426
+ Frames per second for the animation
427
+ min_percentile
428
+ Minimum percentile of the data
429
+ max_percentile
430
+ Maximum percentile of the data
431
+ title
432
+ Custom title to add to the plot
433
+ display_sdf_name
434
+ Display the sdf file name in the animation title
435
+ move_window
436
+ Update the ``xlim`` to be only values that are not NaNs at each time interval
437
+ t
438
+ Coordinate for t axis (the coordinate which will be animated over). If ``None``,
439
+ use ``data.dims[0]``
440
+ ax
441
+ Matplotlib axes on which to plot
442
+ common_kwargs
443
+ Matplotlib keyword arguments applied to all datasets. These are overridden by
444
+ per-dataset entries in ``datasets_kwargs``
118
445
 
119
- # Initialise plo and set y-limits for 1D data
120
- if is_1d(data):
121
- plot = data.isel(time=0).plot(ax=ax, **kwargs)
122
- ax.set_title(get_frame_title(data, 0, display_sdf_name))
123
- ax.set_ylim(global_min, global_max)
446
+ Examples
447
+ --------
448
+ >>> anim = animate_multiple(
449
+ ds["Derived_Number_Density_Electron"],
450
+ ds["Derived_Number_Density_Ion"],
451
+ datasets_kwargs=[{"label": "Electron"}, {"label": "Ion"}],
452
+ ylim=(0e27,4e27),
453
+ display_sdf_name=True,
454
+ ylabel="Derived Number Density [1/m$^3$]"
455
+ )
456
+ >>> anim.save("animation.gif")
457
+ """
458
+ import matplotlib.pyplot as plt # noqa: PLC0415
459
+ from matplotlib.animation import FuncAnimation # noqa: PLC0415
124
460
 
125
- if move_window:
126
- window_velocity, window_initial_edge, time_since_start = (
127
- calculate_window_velocity_and_edges(data, kwargs["x"])
461
+ if not datasets:
462
+ raise ValueError("At least one dataset must be provided")
463
+
464
+ # Create plot if no ax is provided
465
+ if ax is None:
466
+ fig, ax = plt.subplots()
467
+ # Prevents figure from prematurely displaying in Jupyter notebook
468
+ plt.close(fig)
469
+
470
+ n_datasets = len(datasets)
471
+ if datasets_kwargs is None:
472
+ # Initialise an empty series of dicts the same size as the number of datasets
473
+ datasets_kwargs = [{} for _ in range(n_datasets)]
474
+ else:
475
+ # The user might only want to use kwargs on some of the datasets so we make sure
476
+ # to initialise additional empty dicts and append them to the list
477
+ datasets_kwargs.extend({} for _ in range(n_datasets - len(datasets_kwargs)))
478
+
479
+ animations: list[AnimationUnit] = []
480
+ for da, kw in zip(datasets, datasets_kwargs):
481
+ animations.append(
482
+ _generate_animation(
483
+ da,
484
+ ax=ax,
485
+ min_percentile=min_percentile,
486
+ max_percentile=max_percentile,
487
+ title=title,
488
+ display_sdf_name=display_sdf_name,
489
+ move_window=move_window,
490
+ t=t,
491
+ # Per-dataset kwargs override common matplotlib kwargs
492
+ kwargs={**common_kwargs, **kw},
493
+ )
128
494
  )
129
495
 
130
- # User's choice for initial window edge supercides the one calculated
131
- if "xlim" in kwargs:
132
- window_initial_edge = kwargs["xlim"]
496
+ lengths = [anim.n_frames for anim in animations]
497
+ n_frames = min(lengths)
133
498
 
134
- def update(frame):
135
- # Set the xlim for each frame in the case of a moving window
136
- if move_window:
137
- kwargs["xlim"] = (
138
- window_initial_edge[0] + window_velocity * time_since_start[frame],
139
- window_initial_edge[1] * 0.99
140
- + window_velocity * time_since_start[frame],
141
- )
499
+ if len(set(lengths)) > 1:
500
+ warnings.warn(
501
+ "Datasets have different frame counts; truncating to the shortest",
502
+ stacklevel=2,
503
+ )
142
504
 
143
- # Update plot for the new frame
144
- ax.clear()
145
- data.isel(time=frame).plot(ax=ax, **kwargs)
146
- ax.set_title(get_frame_title(data, frame, display_sdf_name))
505
+ # Render the legend if a label exists for any 2D dataset
506
+ show_legend = any(
507
+ "label" in kw and da.ndim == 2 for da, kw in zip(datasets, datasets_kwargs)
508
+ )
147
509
 
148
- # # Update y-limits for 1D data
149
- if is_1d(data):
150
- ax.set_ylim(global_min, global_max)
510
+ def update(frame):
511
+ ax.clear()
512
+ for anim in animations:
513
+ anim.update(frame)
514
+ if show_legend:
515
+ ax.legend(loc="upper right")
151
516
 
152
517
  return FuncAnimation(
153
518
  ax.get_figure(),
154
519
  update,
155
- frames=range(N_frames),
520
+ frames=range(n_frames),
156
521
  interval=1000 / fps,
157
522
  repeat=True,
158
523
  )
159
524
 
160
525
 
526
+ def show(anim):
527
+ """Shows the FuncAnimation in a Jupyter notebook.
528
+
529
+ Parameters
530
+ ----------
531
+ anim
532
+ `matplotlib.animation.FuncAnimation`
533
+ """
534
+ from IPython.display import HTML # noqa: PLC0415
535
+
536
+ return HTML(anim.to_jshtml())
537
+
538
+
161
539
  @xr.register_dataarray_accessor("epoch")
162
540
  class EpochAccessor:
163
541
  def __init__(self, xarray_obj):
@@ -169,16 +547,21 @@ class EpochAccessor:
169
547
  Parameters
170
548
  ----------
171
549
  args
172
- Positional arguments passed to :func:`generate_animation`.
550
+ Positional arguments passed to :func:`animation`.
173
551
  kwargs
174
- Keyword arguments passed to :func:`generate_animation`.
552
+ Keyword arguments passed to :func:`animation`.
175
553
 
176
554
  Examples
177
555
  --------
178
- >>> import xarray as xr
179
- >>> from sdf_xarray import SDFPreprocess
180
- >>> ds = xr.open_mfdataset("*.sdf", preprocess=SDFPreprocess())
181
- >>> ani = ds["Electric_Field_Ey"].epoch.animate()
182
- >>> ani.save("myfile.mp4")
556
+ >>> anim = ds["Electric_Field_Ey"].epoch.animate()
557
+ >>> anim.save("animation.gif")
558
+ >>> # Or in a jupyter notebook:
559
+ >>> anim.show()
183
560
  """
184
- return generate_animation(self._obj, *args, **kwargs)
561
+
562
+ # Add anim.show() functionality
563
+ # anim.show() will display the animation in a jupyter notebook
564
+ anim = animate(self._obj, *args, **kwargs)
565
+ anim.show = MethodType(show, anim)
566
+
567
+ return anim