roms-tools 0.1.0__py3-none-any.whl → 0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,528 @@
1
+ import xarray as xr
2
+ import numpy as np
3
+ import yaml
4
+ import importlib.metadata
5
+ from dataclasses import dataclass, field, asdict
6
+ from roms_tools.setup.grid import Grid
7
+ from roms_tools.setup.vertical_coordinate import VerticalCoordinate
8
+ from datetime import datetime
9
+ from roms_tools.setup.datasets import Dataset
10
+ from roms_tools.setup.fill import fill_and_interpolate
11
+ from roms_tools.setup.utils import (
12
+ nan_check,
13
+ interpolate_from_rho_to_u,
14
+ interpolate_from_rho_to_v,
15
+ extrapolate_deepest_to_bottom,
16
+ )
17
+ from roms_tools.setup.plot import _plot, _section_plot, _profile_plot, _line_plot
18
+ import matplotlib.pyplot as plt
19
+
20
+
21
@dataclass(frozen=True, kw_only=True)
class InitialConditions:
    """
    Represents initial conditions for ROMS.

    Parameters
    ----------
    grid : Grid
        Object representing the grid information.
    vertical_coordinate : VerticalCoordinate
        Object representing the vertical coordinate information.
    ini_time : datetime
        Desired initialization time.
    model_reference_date : datetime, optional
        Reference date for the model. Default is January 1, 2000.
    source : str, optional
        Source of the initial condition data. Default is "GLORYS".
    filename : str
        Path to the source data file. Can contain wildcards.

    Attributes
    ----------
    ds : xr.Dataset
        Xarray Dataset containing the initial condition data.

    """

    grid: Grid
    vertical_coordinate: VerticalCoordinate
    ini_time: datetime
    model_reference_date: datetime = datetime(2000, 1, 1)
    source: str = "GLORYS"
    # kw_only=True lets this required field follow fields with defaults.
    filename: str
    # Populated in __post_init__; excluded from __init__ and repr.
    ds: xr.Dataset = field(init=False, repr=False)
55
+
56
    def __post_init__(self):
        """Build the initial-conditions dataset and store it in ``self.ds``.

        Loads the GLORYS source data, restricts it to a subdomain around
        the ROMS grid, fills and interpolates all fields onto the model
        grid, rotates velocities to the grid orientation, derives the
        barotropic velocities, and assembles the final ``xr.Dataset``.

        Raises
        ------
        ValueError
            If ``source`` is not "GLORYS".
        """

        # Check that the source is "GLORYS"
        if self.source != "GLORYS":
            raise ValueError('Only "GLORYS" is a valid option for source.')
        if self.source == "GLORYS":
            # Dimension and variable names as they appear in GLORYS output.
            dims = {
                "longitude": "longitude",
                "latitude": "latitude",
                "depth": "depth",
                "time": "time",
            }

            varnames = {
                "temp": "thetao",
                "salt": "so",
                "u": "uo",
                "v": "vo",
                "ssh": "zos",
            }

        data = Dataset(
            filename=self.filename,
            start_time=self.ini_time,
            var_names=varnames.values(),
            dim_names=dims,
        )

        lon = self.grid.ds.lon_rho
        lat = self.grid.ds.lat_rho
        angle = self.grid.ds.angle

        # operate on longitudes between -180 and 180 unless ROMS domain lies at least 5 degrees in longitude away from Greenwich meridian
        lon = xr.where(lon > 180, lon - 360, lon)
        straddle = True
        if not self.grid.straddle and abs(lon).min() > 5:
            lon = xr.where(lon < 0, lon + 360, lon)
            straddle = False

        # The following consists of two steps:
        # Step 1: Choose subdomain of forcing data including safety margin for interpolation, and Step 2: Convert to the proper longitude range.
        # We perform these two steps for two reasons:
        # A) Since the horizontal dimensions consist of a single chunk, selecting a subdomain before interpolation is a lot more performant.
        # B) Step 1 is necessary to avoid discontinuous longitudes that could be introduced by Step 2. Specifically, discontinuous longitudes
        # can lead to artifacts in the interpolation process. Specifically, if there is a data gap if data is not global,
        # discontinuous longitudes could result in values that appear to come from a distant location instead of producing NaNs.
        # These NaNs are important as they can be identified and handled appropriately by the nan_check function.
        data.choose_subdomain(
            latitude_range=[lat.min().values, lat.max().values],
            longitude_range=[lon.min().values, lon.max().values],
            margin=2,
            straddle=straddle,
        )

        # interpolate onto desired grid
        fill_dims = [dims["latitude"], dims["longitude"]]

        # 2d interpolation: land/sea mask taken from the first time slice of SSH
        mask = xr.where(data.ds[varnames["ssh"]].isel(time=0).isnull(), 0, 1)
        coords = {dims["latitude"]: lat, dims["longitude"]: lon}

        ssh = fill_and_interpolate(
            data.ds[varnames["ssh"]].astype(np.float64),
            mask,
            fill_dims=fill_dims,
            coords=coords,
            method="linear",
        )

        # 3d interpolation

        # extrapolate deepest value all the way to bottom ("flooding")
        for var in ["temp", "salt", "u", "v"]:
            data.ds[varnames[var]] = extrapolate_deepest_to_bottom(
                data.ds[varnames[var]], dims["depth"]
            )

        mask = xr.where(data.ds[varnames["temp"]].isel(time=0).isnull(), 0, 1)
        coords = {
            dims["latitude"]: lat,
            dims["longitude"]: lon,
            dims["depth"]: self.vertical_coordinate.ds["layer_depth_rho"],
        }

        # setting fillvalue_interp to None means that we allow extrapolation in the
        # interpolation step to avoid NaNs at the surface if the lowest depth in original
        # data is greater than zero
        data_vars = {}
        for var in ["temp", "salt", "u", "v"]:
            data_vars[var] = fill_and_interpolate(
                data.ds[varnames[var]].astype(np.float64),
                mask,
                fill_dims=fill_dims,
                coords=coords,
                method="linear",
                fillvalue_interp=None,
            )

        # rotate velocities from east/north to grid orientation
        u_rot = data_vars["u"] * np.cos(angle) + data_vars["v"] * np.sin(angle)
        v_rot = data_vars["v"] * np.cos(angle) - data_vars["u"] * np.sin(angle)

        # interpolate to u- and v-points
        u = interpolate_from_rho_to_u(u_rot)
        v = interpolate_from_rho_to_v(v_rot)

        # 3d masks for ROMS domain
        umask = self.grid.ds.mask_u.expand_dims({"s_rho": u.s_rho})
        vmask = self.grid.ds.mask_v.expand_dims({"s_rho": v.s_rho})

        u = u * umask
        v = v * vmask

        # Compute barotropic velocity
        # thicknesses (interface depths increase downward, hence the minus sign)
        dz = -self.vertical_coordinate.ds["interface_depth_rho"].diff(dim="s_w")
        dz = dz.rename({"s_w": "s_rho"})
        # thicknesses at u- and v-points
        dzu = interpolate_from_rho_to_u(dz)
        dzv = interpolate_from_rho_to_v(dz)

        # thickness-weighted vertical average
        ubar = (dzu * u).sum(dim="s_rho") / dzu.sum(dim="s_rho")
        vbar = (dzv * v).sum(dim="s_rho") / dzv.sum(dim="s_rho")

        # save in new dataset
        ds = xr.Dataset()

        ds["temp"] = data_vars["temp"].astype(np.float32)
        ds["temp"].attrs["long_name"] = "Potential temperature"
        ds["temp"].attrs["units"] = "Celsius"

        ds["salt"] = data_vars["salt"].astype(np.float32)
        ds["salt"].attrs["long_name"] = "Salinity"
        ds["salt"].attrs["units"] = "PSU"

        ds["zeta"] = ssh.astype(np.float32)
        ds["zeta"].attrs["long_name"] = "Free surface"
        ds["zeta"].attrs["units"] = "m"

        ds["u"] = u.astype(np.float32)
        ds["u"].attrs["long_name"] = "u-flux component"
        ds["u"].attrs["units"] = "m/s"

        ds["v"] = v.astype(np.float32)
        ds["v"].attrs["long_name"] = "v-flux component"
        ds["v"].attrs["units"] = "m/s"

        # initialize vertical velocity to zero
        ds["w"] = xr.zeros_like(
            self.vertical_coordinate.ds["interface_depth_rho"].expand_dims(
                time=ds[dims["time"]]
            )
        ).astype(np.float32)
        ds["w"].attrs["long_name"] = "w-flux component"
        ds["w"].attrs["units"] = "m/s"

        ds["ubar"] = ubar.transpose(dims["time"], "eta_rho", "xi_u").astype(np.float32)
        ds["ubar"].attrs["long_name"] = "vertically integrated u-flux component"
        ds["ubar"].attrs["units"] = "m/s"

        ds["vbar"] = vbar.transpose(dims["time"], "eta_v", "xi_rho").astype(np.float32)
        ds["vbar"].attrs["long_name"] = "vertically integrated v-flux component"
        ds["vbar"].attrs["units"] = "m/s"

        ds = ds.assign_coords(
            {
                "layer_depth_u": self.vertical_coordinate.ds["layer_depth_u"],
                "layer_depth_v": self.vertical_coordinate.ds["layer_depth_v"],
                "interface_depth_u": self.vertical_coordinate.ds["interface_depth_u"],
                "interface_depth_v": self.vertical_coordinate.ds["interface_depth_v"],
            }
        )

        ds.attrs["title"] = "ROMS initial conditions file created by ROMS-Tools"
        # Include the version of roms-tools
        try:
            roms_tools_version = importlib.metadata.version("roms-tools")
        except importlib.metadata.PackageNotFoundError:
            roms_tools_version = "unknown"
        ds.attrs["roms_tools_version"] = roms_tools_version
        ds.attrs["ini_time"] = str(self.ini_time)
        ds.attrs["model_reference_date"] = str(self.model_reference_date)
        ds.attrs["source"] = self.source

        if dims["time"] != "time":
            ds = ds.rename({dims["time"]: "time"})

        # Translate the time coordinate relative to the model reference date
        model_reference_date = np.datetime64(self.model_reference_date)

        # Convert the time coordinate to the format expected by ROMS:
        # nanoseconds * 1e-9 = seconds since the model reference date.
        ocean_time = (ds["time"] - model_reference_date).astype("float64") * 1e-9
        ds = ds.assign_coords(ocean_time=("time", np.float32(ocean_time)))
        ds["ocean_time"].attrs[
            "long_name"
        ] = f"time since {np.datetime_as_string(model_reference_date, unit='D')}"
        ds["ocean_time"].attrs["units"] = "seconds"

        # carry the vertical-coordinate parameters along in the output file
        ds["theta_s"] = self.vertical_coordinate.ds["theta_s"]
        ds["theta_b"] = self.vertical_coordinate.ds["theta_b"]
        ds["Tcline"] = self.vertical_coordinate.ds["Tcline"]
        ds["hc"] = self.vertical_coordinate.ds["hc"]
        ds["sc_r"] = self.vertical_coordinate.ds["sc_r"]
        ds["Cs_r"] = self.vertical_coordinate.ds["Cs_r"]

        ds = ds.drop_vars(["s_rho"])

        # frozen dataclass: bypass the immutability check to set ds
        object.__setattr__(self, "ds", ds)

        # force computation of zeta so NaNs inside the ocean mask are detected
        ds["zeta"].load()
        nan_check(ds["zeta"].squeeze(), self.grid.ds.mask_rho)
267
+
268
    def plot(
        self,
        varname,
        s=None,
        eta=None,
        xi=None,
        depth_contours=False,
        layer_contours=False,
    ) -> None:
        """
        Plot the initial conditions field for a given eta-, xi-, or s_rho-slice.

        Parameters
        ----------
        varname : str
            The name of the initial conditions field to plot. Options include:
            - "temp": Potential temperature.
            - "salt": Salinity.
            - "zeta": Free surface.
            - "u": u-flux component.
            - "v": v-flux component.
            - "w": w-flux component.
            - "ubar": Vertically integrated u-flux component.
            - "vbar": Vertically integrated v-flux component.
        s : int, optional
            The index of the vertical layer to plot. Default is None.
        eta : int, optional
            The eta-index to plot. Default is None.
        xi : int, optional
            The xi-index to plot. Default is None.
        depth_contours : bool, optional
            Whether to include depth contours in the plot. Default is False.
        layer_contours : bool, optional
            Whether to overlay layer interface depths on section plots
            (at most 10 interfaces are drawn). Default is False.

        Returns
        -------
        None
            This method does not return any value. It generates and displays a plot.

        Raises
        ------
        ValueError
            If the specified varname is not one of the valid options.
            If field is 3D and none of s_rho, eta, xi are specified.
            If field is 2D and both eta and xi are specified.
        """

        if len(self.ds[varname].squeeze().dims) == 3 and not any(
            [s is not None, eta is not None, xi is not None]
        ):
            raise ValueError(
                "For 3D fields, at least one of s, eta, or xi must be specified."
            )

        if len(self.ds[varname].squeeze().dims) == 2 and all(
            [eta is not None, xi is not None]
        ):
            raise ValueError("For 2D fields, specify either eta or xi, not both.")

        self.ds[varname].load()
        field = self.ds[varname].squeeze()

        # pick the interface depth matching the field's staggering
        # NOTE(review): if the field matches none of these dim combinations,
        # interface_depth stays undefined and a later use would raise
        # NameError; all fields written by __post_init__ match one branch.
        if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
            interface_depth = self.ds.interface_depth_rho
        elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
            interface_depth = self.ds.interface_depth_u
        elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
            interface_depth = self.ds.interface_depth_v

        # slice the field as desired
        title = field.long_name
        if s is not None:
            title = title + f", s_rho = {field.s_rho[s].item()}"
            field = field.isel(s_rho=s)
        else:
            # depth contours only make sense on a single vertical layer
            depth_contours = False

        if eta is not None:
            if "eta_rho" in field.dims:
                title = title + f", eta_rho = {field.eta_rho[eta].item()}"
                field = field.isel(eta_rho=eta)
                interface_depth = interface_depth.isel(eta_rho=eta)
            elif "eta_v" in field.dims:
                title = title + f", eta_v = {field.eta_v[eta].item()}"
                field = field.isel(eta_v=eta)
                interface_depth = interface_depth.isel(eta_v=eta)
            else:
                raise ValueError(
                    f"None of the expected dimensions (eta_rho, eta_v) found in ds[{varname}]."
                )
        if xi is not None:
            if "xi_rho" in field.dims:
                title = title + f", xi_rho = {field.xi_rho[xi].item()}"
                field = field.isel(xi_rho=xi)
                interface_depth = interface_depth.isel(xi_rho=xi)
            elif "xi_u" in field.dims:
                title = title + f", xi_u = {field.xi_u[xi].item()}"
                field = field.isel(xi_u=xi)
                interface_depth = interface_depth.isel(xi_u=xi)
            else:
                raise ValueError(
                    f"None of the expected dimensions (xi_rho, xi_u) found in ds[{varname}]."
                )

        # choose colormap: diverging (centered at zero) for signed fields,
        # sequential otherwise
        if varname in ["u", "v", "w", "ubar", "vbar", "zeta"]:
            vmax = max(field.max().values, -field.min().values)
            vmin = -vmax
            cmap = plt.colormaps.get_cmap("RdBu_r")
        else:
            vmax = field.max().values
            vmin = field.min().values
            cmap = plt.colormaps.get_cmap("YlOrRd")
        cmap.set_bad(color="gray")
        kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}

        if eta is None and xi is None:
            # horizontal (map) view
            _plot(
                self.grid.ds,
                field=field,
                straddle=self.grid.straddle,
                depth_contours=depth_contours,
                title=title,
                kwargs=kwargs,
                c="g",
            )
        else:
            if not layer_contours:
                interface_depth = None
            else:
                # restrict number of layer_contours to 10 for the sake of plot clarity
                nr_layers = len(interface_depth["s_w"])
                selected_layers = np.linspace(
                    0, nr_layers - 1, min(nr_layers, 10), dtype=int
                )
                interface_depth = interface_depth.isel(s_w=selected_layers)

            if len(field.dims) == 2:
                # vertical section
                _section_plot(
                    field, interface_depth=interface_depth, title=title, kwargs=kwargs
                )
            else:
                if "s_rho" in field.dims:
                    # single vertical profile
                    _profile_plot(field, title=title)
                else:
                    # single horizontal line
                    _line_plot(field, title=title)
413
+
414
    def save(self, filepath: str) -> None:
        """
        Save the initial conditions information to a netCDF4 file.

        Parameters
        ----------
        filepath : str
            The path to the file where the dataset will be written.
        """
        self.ds.to_netcdf(filepath)
423
+
424
+ def to_yaml(self, filepath: str) -> None:
425
+ """
426
+ Export the parameters of the class to a YAML file, including the version of roms-tools.
427
+
428
+ Parameters
429
+ ----------
430
+ filepath : str
431
+ The path to the YAML file where the parameters will be saved.
432
+ """
433
+ # Serialize Grid data
434
+ grid_data = asdict(self.grid)
435
+ grid_data.pop("ds", None) # Exclude non-serializable fields
436
+ grid_data.pop("straddle", None)
437
+
438
+ # Serialize VerticalCoordinate data
439
+ vertical_coordinate_data = asdict(self.vertical_coordinate)
440
+ vertical_coordinate_data.pop("ds", None) # Exclude non-serializable fields
441
+ vertical_coordinate_data.pop("grid", None) # Exclude non-serializable fields
442
+
443
+ # Include the version of roms-tools
444
+ try:
445
+ roms_tools_version = importlib.metadata.version("roms-tools")
446
+ except importlib.metadata.PackageNotFoundError:
447
+ roms_tools_version = "unknown"
448
+
449
+ # Create header
450
+ header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
451
+
452
+ grid_yaml_data = {"Grid": grid_data}
453
+ vertical_coordinate_yaml_data = {"VerticalCoordinate": vertical_coordinate_data}
454
+
455
+ initial_conditions_data = {
456
+ "InitialConditions": {
457
+ "filename": self.filename,
458
+ "ini_time": self.ini_time.isoformat(),
459
+ "model_reference_date": self.model_reference_date.isoformat(),
460
+ "source": self.source,
461
+ }
462
+ }
463
+
464
+ yaml_data = {
465
+ **grid_yaml_data,
466
+ **vertical_coordinate_yaml_data,
467
+ **initial_conditions_data,
468
+ }
469
+
470
+ with open(filepath, "w") as file:
471
+ # Write header
472
+ file.write(header)
473
+ # Write YAML data
474
+ yaml.dump(yaml_data, file, default_flow_style=False)
475
+
476
+ @classmethod
477
+ def from_yaml(cls, filepath: str) -> "InitialConditions":
478
+ """
479
+ Create an instance of the InitialConditions class from a YAML file.
480
+
481
+ Parameters
482
+ ----------
483
+ filepath : str
484
+ The path to the YAML file from which the parameters will be read.
485
+
486
+ Returns
487
+ -------
488
+ InitialConditions
489
+ An instance of the InitialConditions class.
490
+ """
491
+ # Read the entire file content
492
+ with open(filepath, "r") as file:
493
+ file_content = file.read()
494
+
495
+ # Split the content into YAML documents
496
+ documents = list(yaml.safe_load_all(file_content))
497
+
498
+ initial_conditions_data = None
499
+
500
+ # Process the YAML documents
501
+ for doc in documents:
502
+ if doc is None:
503
+ continue
504
+ if "InitialConditions" in doc:
505
+ initial_conditions_data = doc["InitialConditions"]
506
+ break
507
+
508
+ if initial_conditions_data is None:
509
+ raise ValueError(
510
+ "No InitialConditions configuration found in the YAML file."
511
+ )
512
+
513
+ # Convert from string to datetime
514
+ for date_string in ["model_reference_date", "ini_time"]:
515
+ initial_conditions_data[date_string] = datetime.fromisoformat(
516
+ initial_conditions_data[date_string]
517
+ )
518
+
519
+ # Create VerticalCoordinate instance from the YAML file
520
+ vertical_coordinate = VerticalCoordinate.from_yaml(filepath)
521
+ grid = vertical_coordinate.grid
522
+
523
+ # Create and return an instance of InitialConditions
524
+ return cls(
525
+ grid=grid,
526
+ vertical_coordinate=vertical_coordinate,
527
+ **initial_conditions_data,
528
+ )
roms_tools/setup/plot.py CHANGED
@@ -3,11 +3,48 @@ import matplotlib.pyplot as plt
3
3
  import xarray as xr
4
4
 
5
5
 
6
- def _plot(grid_ds, field=None, straddle=False, c="red", kwargs={}):
7
- lon_deg = grid_ds["lon_rho"]
8
- lat_deg = grid_ds["lat_rho"]
6
+ def _plot(
7
+ grid_ds,
8
+ field=None,
9
+ depth_contours=False,
10
+ straddle=False,
11
+ coarse_grid=False,
12
+ c="red",
13
+ title="",
14
+ kwargs={},
15
+ ):
16
+
17
+ if field is None:
18
+ lon_deg = grid_ds["lon_rho"]
19
+ lat_deg = grid_ds["lat_rho"]
20
+
21
+ else:
22
+
23
+ field = field.squeeze()
24
+
25
+ if coarse_grid:
26
+
27
+ field = field.rename({"eta_rho": "eta_coarse", "xi_rho": "xi_coarse"})
28
+ field = field.where(grid_ds.mask_coarse)
29
+ lon_deg = field.lon
30
+ lat_deg = field.lat
31
+
32
+ else:
33
+ if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
34
+ field = field.where(grid_ds.mask_rho)
35
+ lon_deg = grid_ds["lon_rho"]
36
+ lat_deg = grid_ds["lat_rho"]
37
+ elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
38
+ field = field.where(grid_ds.mask_u)
39
+ lon_deg = grid_ds["lon_u"]
40
+ lat_deg = grid_ds["lat_u"]
41
+ elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
42
+ field = field.where(grid_ds.mask_v)
43
+ lon_deg = grid_ds["lon_v"]
44
+ lat_deg = grid_ds["lat_v"]
45
+ else:
46
+ ValueError("provided field does not have two horizontal dimension")
9
47
 
10
- if field is not None:
11
48
  # check if North or South pole are in domain
12
49
  if lat_deg.max().values > 89 or lat_deg.min().values < -89:
13
50
  raise NotImplementedError(
@@ -52,7 +89,115 @@ def _plot(grid_ds, field=None, straddle=False, c="red", kwargs={}):
52
89
  resolution="50m", linewidth=0.5, color="black"
53
90
  ) # add map of coastlines
54
91
  ax.gridlines()
92
+ ax.set_title(title)
55
93
 
56
94
  if field is not None:
57
95
  p = ax.pcolormesh(lon_deg, lat_deg, field, transform=proj, **kwargs)
58
96
  plt.colorbar(p, label=f"{field.long_name} [{field.units}]")
97
+
98
+ if depth_contours:
99
+ if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
100
+ if "layer_depth_rho" in field.coords:
101
+ depth = field.layer_depth_rho
102
+ else:
103
+ depth = field.interface_depth_rho
104
+ elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
105
+ if "layer_depth_u" in field.coords:
106
+ depth = field.layer_depth_u
107
+ else:
108
+ depth = field.interface_depth_u
109
+ elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
110
+ if "layer_depth_v" in field.coords:
111
+ depth = field.layer_depth_v
112
+ else:
113
+ depth = field.interface_depth_v
114
+
115
+ cs = ax.contour(lon_deg, lat_deg, depth, transform=proj, colors="k")
116
+ ax.clabel(cs, inline=True, fontsize=10)
117
+
118
+ return fig
119
+
120
+
121
def _section_plot(field, interface_depth=None, title="", kwargs=None):
    """
    Plot a vertical section of a 2D field against its depth coordinate.

    Parameters
    ----------
    field : xr.DataArray
        Field with one horizontal dimension (starting with eta_* or xi_*)
        and a depth coordinate (layer_depth_* or interface_depth_*).
    interface_depth : xr.DataArray, optional
        Layer interface depths to overlay as black lines. Default is None.
    title : str, optional
        Plot title. Default is "".
    kwargs : dict, optional
        Extra keyword arguments forwarded to ``field.plot``. Default is None.

    Raises
    ------
    ValueError
        If no horizontal dimension or no depth coordinate can be found.
    """
    # Avoid the mutable-default-argument pitfall: use None as sentinel.
    if kwargs is None:
        kwargs = {}

    fig, ax = plt.subplots(1, 1, figsize=(9, 5))

    # find the horizontal dimension to use as the x-axis
    dims_to_check = ["eta_rho", "eta_u", "eta_v", "xi_rho", "xi_u", "xi_v"]
    try:
        xdim = next(
            dim
            for dim in field.dims
            if any(dim.startswith(prefix) for prefix in dims_to_check)
        )
    except StopIteration:
        raise ValueError(
            "None of the dimensions found in field.dims starts with (eta_rho, eta_u, eta_v, xi_rho, xi_u, xi_v)"
        )

    # find the depth coordinate to use as the y-axis
    depths_to_check = [
        "layer_depth_rho",
        "layer_depth_u",
        "layer_depth_v",
        "interface_depth_rho",
        "interface_depth_u",
        "interface_depth_v",
    ]
    try:
        depth_label = next(
            depth_label
            for depth_label in field.coords
            if any(depth_label.startswith(prefix) for prefix in depths_to_check)
        )
    except StopIteration:
        raise ValueError(
            "None of the coordinates found in field.coords starts with (layer_depth_rho, layer_depth_u, layer_depth_v, interface_depth_rho, interface_depth_u, interface_depth_v)"
        )

    # depth increases downward, hence yincrease=False
    more_kwargs = {"x": xdim, "y": depth_label, "yincrease": False}
    field.plot(**kwargs, **more_kwargs, ax=ax)

    if interface_depth is not None:
        # overlay one black line per layer interface
        layer_key = "s_rho" if "s_rho" in interface_depth else "s_w"

        for i in range(len(interface_depth[layer_key])):
            ax.plot(
                interface_depth[xdim], interface_depth.isel({layer_key: i}), color="k"
            )

    ax.set_title(title)
168
+
169
+
170
def _profile_plot(field, title=""):
    """
    Plot a single vertical profile of a field against its depth coordinate.

    Parameters
    ----------
    field : xr.DataArray
        1D field with a depth coordinate (layer_depth_* or
        interface_depth_*).
    title : str, optional
        Plot title. Default is "".

    Raises
    ------
    ValueError
        If no depth coordinate can be found in ``field.coords``.
    """
    # find the depth coordinate to use as the y-axis
    depths_to_check = [
        "layer_depth_rho",
        "layer_depth_u",
        "layer_depth_v",
        "interface_depth_rho",
        "interface_depth_u",
        "interface_depth_v",
    ]
    try:
        depth_label = next(
            depth_label
            for depth_label in depths_to_check
            if depth_label in field.coords
        )
    except StopIteration:
        raise ValueError(
            "None of the expected coordinates (layer_depth_rho, layer_depth_u, layer_depth_v, interface_depth_rho, interface_depth_u, interface_depth_v) found in field.coords"
        )

    fig, ax = plt.subplots(1, 1, figsize=(4, 7))
    # depth increases downward, hence yincrease=False; pass ax explicitly
    # instead of relying on pyplot's implicit current-axes state
    kwargs = {"y": depth_label, "yincrease": False}
    field.plot(**kwargs, ax=ax)
    ax.set_title(title)
    ax.grid()
196
+
197
+
198
def _line_plot(field, title=""):
    """Plot a 1D field as a simple line with a grid and title."""
    fig, axis = plt.subplots(1, 1, figsize=(7, 4))
    field.plot()
    axis.set_title(title)
    axis.grid()