roms-tools 2.5.0__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
roms_tools/plot.py CHANGED
@@ -33,9 +33,15 @@ def _plot(
33
33
  kwargs : dict, optional
34
34
  Additional keyword arguments to pass to `pcolormesh` (e.g., colormap or color limits).
35
35
 
36
- Notes
37
- -----
38
- The function raises a `NotImplementedError` if the domain contains the North or South Pole.
36
+ Returns
37
+ -------
38
+ matplotlib.figure.Figure
39
+ The generated figure with the plotted data.
40
+
41
+ Raises
42
+ ------
43
+ NotImplementedError
44
+ If the domain contains the North or South Pole.
39
45
  """
40
46
 
41
47
  field = field.squeeze()
@@ -84,6 +90,348 @@ def _plot(
84
90
 
85
91
  ax.set_title(title)
86
92
 
93
+ return fig
94
+
95
+
96
def _plot_nesting(parent_grid_ds, child_grid_ds, parent_straddle, with_dim_names=False):
    """Draw the outlines of a parent and a child grid over the parent mask.

    Parameters
    ----------
    parent_grid_ds : xarray.Dataset
        Parent grid; must provide `lon_rho`, `lat_rho` and `mask_rho`.
    child_grid_ds : xarray.Dataset
        Child grid; must provide `lon_rho` and `lat_rho`.
    parent_straddle : bool
        If True, longitudes greater than 180 in BOTH grids are shifted down by
        360 before plotting, so a domain crossing the 180-degree meridian is
        drawn contiguously.
    with_dim_names : bool, optional
        Forwarded to the boundary-drawing helper to annotate the outlines with
        dimension names. Defaults to False.

    Returns
    -------
    matplotlib.figure.Figure
        Figure with the parent outline (red), child outline (green), shaded
        parent mask, coastlines, labelled gridlines and a legend.
    """

    def _shift_lon(lon):
        # Values above 180 become value - 360; everything else is unchanged.
        return xr.where(lon > 180, lon - 360, lon)

    parent_lon = parent_grid_ds["lon_rho"]
    parent_lat = parent_grid_ds["lat_rho"]
    child_lon = child_grid_ds["lon_rho"]
    child_lat = child_grid_ds["lat_rho"]

    if parent_straddle:
        parent_lon = _shift_lon(parent_lon)
        child_lon = _shift_lon(child_lon)

    # The map projection is derived from the (possibly shifted) parent grid.
    projection = _get_projection(parent_lon, parent_lat)

    # The drawing helpers below receive plain numpy arrays.
    parent_lon, parent_lat, child_lon, child_lat = [
        da.values for da in (parent_lon, parent_lat, child_lon, child_lat)
    ]

    fig, ax = plt.subplots(
        1, 1, figsize=(13, 7), subplot_kw={"projection": projection}
    )

    outlines = (
        (parent_lon, parent_lat, "r", "parent grid"),
        (child_lon, child_lat, "g", "child grid"),
    )
    for lon, lat, colour, name in outlines:
        _add_boundary_to_ax(
            ax,
            lon,
            lat,
            projection,
            c=colour,
            label=name,
            with_dim_names=with_dim_names,
        )

    # Shade the parent mask. NOTE(review): presumably mask_rho holds 0/1, so a
    # vmax of 3 keeps the "Blues" shading pale — confirm.
    mask_style = {
        "vmax": 3,
        "vmin": 0,
        "cmap": plt.colormaps.get_cmap("Blues"),
    }
    _add_field_to_ax(
        ax,
        parent_lon,
        parent_lat,
        parent_grid_ds.mask_rho,
        add_colorbar=False,
        kwargs=mask_style,
    )

    ax.coastlines(resolution="50m", linewidth=0.5, color="black")

    # Dashed lat/lon gridlines, labelled only along the bottom and left edges.
    gl = ax.gridlines(
        draw_labels=True, linewidth=0.5, color="gray", alpha=0.7, linestyle="--"
    )
    gl.top_labels = False
    gl.right_labels = False
    gl.xlabel_style = {"size": 10, "color": "black"}
    gl.ylabel_style = {"size": 10, "color": "black"}

    ax.legend(loc="best")

    return fig
197
+
198
+
199
def _section_plot(field, interface_depth=None, title="", kwargs=None, ax=None):
    """Plots a vertical section of a field with optional interface depths.

    Parameters
    ----------
    field : xarray.DataArray
        The field to plot, typically representing a vertical section of ocean data.
        It needs one horizontal dimension (see Raises) and a depth coordinate whose
        name starts with ``layer`` or ``interface``.
    interface_depth : xarray.DataArray, optional
        Interface depths to overlay as black lines, one line per entry along its
        ``s_rho`` dimension if present, otherwise along ``s_w``. Defaults to None.
    title : str, optional
        Title of the plot. Defaults to an empty string.
    kwargs : dict, optional
        Additional keyword arguments to pass to `xarray.DataArray.plot`.
        Defaults to None (no extra arguments).
    ax : matplotlib.axes.Axes, optional
        Pre-existing axes to draw the plot on. If None, a new figure and axes are created.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plotted section (the figure owning `ax` when
        `ax` is supplied).

    Raises
    ------
    ValueError
        If no dimension in `field.dims` starts with any of the recognized horizontal dimension
        prefixes (`eta_rho`, `eta_v`, `xi_rho`, `xi_u`, `lat`, `lon`).
    ValueError
        If no coordinate in `field.coords` starts with either `layer` or `interface`.

    Notes
    -----
    - Labels whose depth coordinate is entirely NaN are dropped before plotting
      (``where(..., drop=True)``), e.g. land points at the horizontal ends.
    """
    # A `None` default avoids sharing one mutable dict between calls.
    if kwargs is None:
        kwargs = {}

    horizontal_prefixes = ("eta_rho", "eta_v", "xi_rho", "xi_u", "lat", "lon")
    xdim = next(
        (dim for dim in field.dims if dim.startswith(horizontal_prefixes)), None
    )
    if xdim is None:
        raise ValueError(
            "None of the dimensions found in field.dims starts with (eta_rho, eta_v, xi_rho, xi_u, lat, lon)"
        )

    depth_prefixes = ("layer", "interface")
    depth_label = next(
        (name for name in field.coords if name.startswith(depth_prefixes)), None
    )
    if depth_label is None:
        raise ValueError(
            "None of the coordinates found in field.coords starts with (layer, interface)"
        )

    # Inputs are validated before any figure is created, so a bad call does not
    # leave an empty figure behind. When the caller supplies `ax`, return its
    # figure (previously `fig` was undefined on that path).
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(9, 5))
    else:
        fig = ax.get_figure()

    # Handle NaNs on either horizontal end
    field = field.where(~field[depth_label].isnull(), drop=True)

    field.plot(**kwargs, x=xdim, y=depth_label, yincrease=False, ax=ax)

    if interface_depth is not None:
        layer_key = "s_rho" if "s_rho" in interface_depth.dims else "s_w"
        for i in range(interface_depth.sizes[layer_key]):
            ax.plot(
                interface_depth[xdim], interface_depth.isel({layer_key: i}), color="k"
            )

    ax.set_title(title)
    ax.set_ylabel("Depth [m]")
    ax.set_xlabel(
        {"lon": "Longitude [°E]", "lat": "Latitude [°N]"}.get(xdim, xdim)
    )

    return fig
291
+
292
+
293
def _profile_plot(field, title="", ax=None):
    """Plots a vertical profile of the given field against depth.

    This function generates a profile plot by plotting the field values against
    depth. It automatically detects the appropriate depth coordinate and
    reverses the y-axis to follow the convention of increasing depth downward.

    Parameters
    ----------
    field : xarray.DataArray
        The field to plot, typically representing vertical profile data.
    title : str, optional
        Title of the plot. Defaults to an empty string.
    ax : matplotlib.axes.Axes, optional
        Pre-existing axes to draw the plot on. If None, a new figure and axes are created.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plotted profile (the figure owning `ax` when
        `ax` is supplied).

    Raises
    ------
    ValueError
        If no coordinate in `field.coords` starts with either `layer_depth` or `interface_depth`.

    Notes
    -----
    - The y-axis is inverted to ensure that depth increases downward.
    """
    depth_prefixes = ("layer_depth", "interface_depth")
    depth_label = next(
        (name for name in field.coords if name.startswith(depth_prefixes)), None
    )
    if depth_label is None:
        raise ValueError(
            "None of the coordinates found in field.coords starts with (layer_depth, interface_depth)"
        )

    # When the caller supplies `ax`, return its figure (previously `fig` was
    # undefined on that path).
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(4, 7))
    else:
        fig = ax.get_figure()

    # Pass `ax` explicitly. Without it, xarray draws on the *current* axes, which
    # need not be the axes the caller handed in.
    field.plot(y=depth_label, yincrease=False, linewidth=2, ax=ax)
    ax.set_title(title)
    ax.set_ylabel("Depth [m]")
    ax.grid()

    return fig
348
+
349
+
350
def _line_plot(field, title="", ax=None):
    """Plots a line graph of the given field with grey vertical bars indicating NaN
    regions.

    Parameters
    ----------
    field : xarray.DataArray
        The field to plot. NOTE(review): the NaN shading indexes the horizontal
        coordinate with positions from the field's first axis, so this assumes a
        1D field along one spatial dimension — confirm for other shapes.
    title : str, optional
        Title of the plot. Defaults to an empty string.
    ax : matplotlib.axes.Axes, optional
        Pre-existing axes to draw the plot on. If None, a new figure and axes are created.

    Returns
    -------
    matplotlib.figure.Figure
        The figure with the plotted data and highlighted NaN regions (the figure
        owning `ax` when `ax` is supplied).

    Raises
    ------
    ValueError
        If none of the dimensions in `field.dims` starts with one of the expected
        prefixes: `eta_rho`, `eta_v`, `xi_rho`, `xi_u`, `lat`, or `lon`.

    Notes
    -----
    - Each run of consecutive NaNs is shaded grey with `axvspan`. The shading runs
      from the first NaN of the run to the next sample after it, clamped to the
      last sample of the axis. Every run, including the final one, uses the same
      rule.
    """
    # Validate before creating/drawing anything so a bad call leaves no figure.
    horizontal_prefixes = ("eta_rho", "eta_v", "xi_rho", "xi_u", "lat", "lon")
    xdim = next(
        (dim for dim in field.dims if dim.startswith(horizontal_prefixes)), None
    )
    if xdim is None:
        raise ValueError(
            "None of the dimensions found in field.dims starts with (eta_rho, eta_v, xi_rho, xi_u, lat, lon)"
        )

    # When the caller supplies `ax`, return its figure (previously `fig` was
    # undefined on that path).
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(7, 4))
    else:
        fig = ax.get_figure()

    field.plot(ax=ax, linewidth=2)

    x = field[xdim].values
    last = len(x) - 1
    nan_indices = np.where(np.isnan(field.values))[0]

    if nan_indices.size > 0:
        # A new NaN run starts wherever consecutive NaN indices are not adjacent.
        breaks = np.flatnonzero(np.diff(nan_indices) > 1)
        run_starts = np.concatenate(([nan_indices[0]], nan_indices[breaks + 1]))
        run_ends = np.concatenate((nan_indices[breaks], [nan_indices[-1]]))
        for start, end in zip(run_starts, run_ends):
            # Extend one sample past the run so single-NaN runs get a visible
            # width. Clamp at the axis end to stay in bounds.
            ax.axvspan(x[start], x[min(end + 1, last)], color="gray", alpha=0.3)

    # Set plot title and grid
    ax.set_title(title)
    ax.grid()
    ax.set_xlim([x[0], x[-1]])
    ax.set_xlabel(
        {"lon": "Longitude [°E]", "lat": "Latitude [°N]"}.get(xdim, xdim)
    )

    return fig
434
+
87
435
 
88
436
  def _add_boundary_to_ax(
89
437
  ax, lon_deg, lat_deg, trans, c="red", label="", with_dim_names=False
@@ -238,214 +586,3 @@ def _get_projection(lon, lat):
238
586
  return ccrs.NearsidePerspective(
239
587
  central_longitude=lon.mean().values, central_latitude=lat.mean().values
240
588
  )
241
-
242
-
243
- def _section_plot(field, interface_depth=None, title="", kwargs={}, ax=None):
244
-
245
- if ax is None:
246
- fig, ax = plt.subplots(1, 1, figsize=(9, 5))
247
-
248
- dims_to_check = ["eta_rho", "eta_u", "eta_v", "xi_rho", "xi_u", "xi_v"]
249
- try:
250
- xdim = next(
251
- dim
252
- for dim in field.dims
253
- if any(dim.startswith(prefix) for prefix in dims_to_check)
254
- )
255
- except StopIteration:
256
- raise ValueError(
257
- "None of the dimensions found in field.dims starts with (eta_rho, eta_u, eta_v, xi_rho, xi_u, xi_v)"
258
- )
259
-
260
- depths_to_check = [
261
- "layer_depth",
262
- "interface_depth",
263
- ]
264
- try:
265
- depth_label = next(
266
- depth_label
267
- for depth_label in field.coords
268
- if any(depth_label.startswith(prefix) for prefix in depths_to_check)
269
- )
270
- except StopIteration:
271
- raise ValueError(
272
- "None of the coordinates found in field.coords starts with (layer_depth_rho, layer_depth_u, layer_depth_v, interface_depth_rho, interface_depth_u, interface_depth_v)"
273
- )
274
-
275
- more_kwargs = {"x": xdim, "y": depth_label, "yincrease": False}
276
- field.plot(**kwargs, **more_kwargs, ax=ax)
277
-
278
- if interface_depth is not None:
279
- layer_key = "s_rho" if "s_rho" in interface_depth.dims else "s_w"
280
-
281
- for i in range(len(interface_depth[layer_key])):
282
- ax.plot(
283
- interface_depth[xdim], interface_depth.isel({layer_key: i}), color="k"
284
- )
285
-
286
- ax.set_title(title)
287
- ax.set_ylabel("Depth [m]")
288
-
289
-
290
- def _profile_plot(field, title="", ax=None):
291
- """Plots a profile of the given field against depth.
292
-
293
- Parameters
294
- ----------
295
- field : xarray.DataArray
296
- Data to plot.
297
- title : str, optional
298
- Title of the plot.
299
- ax : matplotlib.axes.Axes, optional
300
- Axes to plot on. If None, a new figure is created.
301
-
302
- Raises
303
- ------
304
- ValueError
305
- If no expected depth coordinate is found in the field.
306
- """
307
-
308
- depths_to_check = [
309
- "layer_depth",
310
- "interface_depth",
311
- ]
312
- try:
313
- depth_label = next(
314
- depth_label
315
- for depth_label in field.coords
316
- if any(depth_label.startswith(prefix) for prefix in depths_to_check)
317
- )
318
- except StopIteration:
319
- raise ValueError(
320
- "None of the expected coordinates (layer_depth_rho, layer_depth_u, layer_depth_v, interface_depth_rho, interface_depth_u, interface_depth_v) found in field.coords"
321
- )
322
-
323
- if ax is None:
324
- fig, ax = plt.subplots(1, 1, figsize=(4, 7))
325
- kwargs = {"y": depth_label, "yincrease": False}
326
- field.plot(**kwargs)
327
- ax.set_title(title)
328
- ax.set_ylabel("Depth [m]")
329
- ax.grid()
330
-
331
-
332
- def _line_plot(field, title="", ax=None):
333
- """Plots a line graph of the given field, with grey vertical bars where NaNs are
334
- located.
335
-
336
- Parameters
337
- ----------
338
- field : xarray.DataArray
339
- Data to plot.
340
- title : str, optional
341
- Title of the plot.
342
- ax : matplotlib.axes.Axes, optional
343
- Axes to plot on. If None, a new figure is created.
344
-
345
- Returns
346
- -------
347
- None
348
- Modifies the plot in-place.
349
- """
350
- if ax is None:
351
- fig, ax = plt.subplots(1, 1, figsize=(7, 4))
352
- field.plot(ax=ax)
353
-
354
- # Loop through the NaNs in the field and add grey vertical bars
355
- nan_mask = np.isnan(field.values)
356
- nan_indices = np.where(nan_mask)[0]
357
-
358
- if len(nan_indices) > 0:
359
- # Add grey vertical bars for each NaN region
360
- start_idx = nan_indices[0]
361
- for idx in range(1, len(nan_indices)):
362
- if nan_indices[idx] != nan_indices[idx - 1] + 1:
363
- ax.axvspan(start_idx, nan_indices[idx - 1] + 1, color="gray", alpha=0.3)
364
- start_idx = nan_indices[idx]
365
- # Add the last region of NaNs
366
- ax.axvspan(start_idx, nan_indices[-1] + 1, color="gray", alpha=0.3)
367
-
368
- # Set plot title and grid
369
- ax.set_title(title)
370
- ax.grid()
371
-
372
-
373
- def _plot_nesting(parent_grid_ds, child_grid_ds, parent_straddle, with_dim_names=False):
374
-
375
- parent_lon_deg = parent_grid_ds["lon_rho"]
376
- parent_lat_deg = parent_grid_ds["lat_rho"]
377
-
378
- child_lon_deg = child_grid_ds["lon_rho"]
379
- child_lat_deg = child_grid_ds["lat_rho"]
380
-
381
- if parent_straddle:
382
- parent_lon_deg = xr.where(
383
- parent_lon_deg > 180, parent_lon_deg - 360, parent_lon_deg
384
- )
385
- child_lon_deg = xr.where(
386
- child_lon_deg > 180, child_lon_deg - 360, child_lon_deg
387
- )
388
-
389
- trans = _get_projection(parent_lon_deg, parent_lat_deg)
390
-
391
- parent_lon_deg = parent_lon_deg.values
392
- parent_lat_deg = parent_lat_deg.values
393
- child_lon_deg = child_lon_deg.values
394
- child_lat_deg = child_lat_deg.values
395
-
396
- fig, ax = plt.subplots(1, 1, figsize=(13, 7), subplot_kw={"projection": trans})
397
-
398
- _add_boundary_to_ax(
399
- ax,
400
- parent_lon_deg,
401
- parent_lat_deg,
402
- trans,
403
- c="r",
404
- label="parent grid",
405
- with_dim_names=with_dim_names,
406
- )
407
-
408
- _add_boundary_to_ax(
409
- ax,
410
- child_lon_deg,
411
- child_lat_deg,
412
- trans,
413
- c="g",
414
- label="child grid",
415
- with_dim_names=with_dim_names,
416
- )
417
-
418
- vmax = 3
419
- vmin = 0
420
- cmap = plt.colormaps.get_cmap("Blues")
421
- kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
422
-
423
- _add_field_to_ax(
424
- ax,
425
- parent_lon_deg,
426
- parent_lat_deg,
427
- parent_grid_ds.mask_rho,
428
- add_colorbar=False,
429
- kwargs=kwargs,
430
- )
431
-
432
- ax.coastlines(
433
- resolution="50m", linewidth=0.5, color="black"
434
- ) # add map of coastlines
435
-
436
- # Add gridlines with labels for latitude and longitude
437
- gridlines = ax.gridlines(
438
- draw_labels=True, linewidth=0.5, color="gray", alpha=0.7, linestyle="--"
439
- )
440
- gridlines.top_labels = False # Hide top labels
441
- gridlines.right_labels = False # Hide right labels
442
- gridlines.xlabel_style = {
443
- "size": 10,
444
- "color": "black",
445
- } # Customize longitude label style
446
- gridlines.ylabel_style = {
447
- "size": 10,
448
- "color": "black",
449
- } # Customize latitude label style
450
-
451
- ax.legend(loc="best")