roms-tools 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
roms_tools/setup/grid.py CHANGED
@@ -3,21 +3,20 @@ from dataclasses import dataclass, field, asdict
3
3
 
4
4
  import numpy as np
5
5
  import xarray as xr
6
+ import matplotlib.pyplot as plt
6
7
  import yaml
7
8
  import importlib.metadata
8
9
 
9
10
  from roms_tools.setup.topography import _add_topography_and_mask, _add_velocity_masks
10
- from roms_tools.setup.plot import _plot
11
+ from roms_tools.setup.plot import _plot, _section_plot, _profile_plot, _line_plot
11
12
  from roms_tools.setup.utils import interpolate_from_rho_to_u, interpolate_from_rho_to_v
13
+ from roms_tools.setup.vertical_coordinate import sigma_stretch, compute_depth
12
14
 
13
15
  import warnings
14
16
 
15
17
  RADIUS_OF_EARTH = 6371315.0 # in m
16
18
 
17
19
 
18
- # TODO should we store an xgcm.Grid object instead of an xarray Dataset? Or even subclass xgcm.Grid?
19
-
20
-
21
20
  @dataclass(frozen=True, kw_only=True)
22
21
  class Grid:
23
22
  """
@@ -43,20 +42,19 @@ class Grid:
43
42
  Rotation of grid x-direction from lines of constant latitude, measured in degrees.
44
43
  Positive values represent a counterclockwise rotation.
45
44
  The default is 0, which means that the x-direction of the grid is aligned with lines of constant latitude.
45
+ N : int, optional
46
+ The number of vertical levels. The default is 100.
47
+ theta_s : float, optional
48
+ The surface control parameter. Must satisfy 0 < theta_s <= 10. The default is 5.0.
49
+ theta_b : float, optional
50
+ The bottom control parameter. Must satisfy 0 < theta_b <= 4. The default is 2.0.
51
+ hc : float, optional
52
+ The critical depth (in meters). The default is 300.0.
46
53
  topography_source : str, optional
47
54
  Specifies the data source to use for the topography. Options are
48
55
  "ETOPO5". The default is "ETOPO5".
49
- smooth_factor : float, optional
50
- The smoothing factor used in the domain-wide Gaussian smoothing of the
51
- topography. Smaller values result in less smoothing, while larger
52
- values produce more smoothing. The default is 8.
53
56
  hmin : float, optional
54
- The minimum ocean depth (in meters). The default is 5.
55
- rmax : float, optional
56
- The maximum slope parameter (in meters). This parameter controls
57
- the local smoothing of the topography. Smaller values result in
58
- smoother topography, while larger values preserve more detail.
59
- The default is 0.2.
57
+ The minimum ocean depth (in meters). The default is 5.0.
60
58
 
61
59
  Attributes
62
60
  ----------
@@ -74,14 +72,18 @@ class Grid:
74
72
  Latitude of grid center.
75
73
  rot : float
76
74
  Rotation of grid x-direction from lines of constant latitude.
75
+ N : int
76
+ The number of vertical levels.
77
+ theta_s : float
78
+ The surface control parameter.
79
+ theta_b : float
80
+ The bottom control parameter.
81
+ hc : float
82
+ The critical depth (in meters).
77
83
  topography_source : str
78
84
  Data source used for the topography.
79
- smooth_factor : int
80
- Smoothing factor used in the domain-wide Gaussian smoothing of the topography.
81
85
  hmin : float
82
86
  Minimum ocean depth (in meters).
83
- rmax : float
84
- Maximum slope parameter (in meters).
85
87
  ds : xr.Dataset
86
88
  The xarray Dataset containing the grid data.
87
89
  straddle : bool
@@ -101,10 +103,12 @@ class Grid:
101
103
  center_lon: float
102
104
  center_lat: float
103
105
  rot: float = 0
106
+ N: int = 100
107
+ theta_s: float = 5.0
108
+ theta_b: float = 2.0
109
+ hc: float = 300.0
104
110
  topography_source: str = "ETOPO5"
105
- smooth_factor: int = 8
106
111
  hmin: float = 5.0
107
- rmax: float = 0.2
108
112
  ds: xr.Dataset = field(init=False, repr=False)
109
113
  straddle: bool = field(init=False, repr=False)
110
114
 
@@ -123,21 +127,27 @@ class Grid:
123
127
  object.__setattr__(self, "ds", ds)
124
128
 
125
129
  # Update self.ds with topography and mask information
126
- self.add_topography_and_mask(
130
+ self.update_topography_and_mask(
127
131
  topography_source=self.topography_source,
128
- smooth_factor=self.smooth_factor,
129
132
  hmin=self.hmin,
130
- rmax=self.rmax,
131
133
  )
132
134
 
133
135
  # Check if the Greenwich meridian goes through the domain.
134
136
  self._straddle()
135
137
 
136
- def add_topography_and_mask(
137
- self, topography_source="ETOPO5", smooth_factor=8, hmin=5.0, rmax=0.2
138
- ) -> None:
138
+ ds = _add_lat_lon_at_velocity_points(self.ds, self.straddle)
139
+ object.__setattr__(self, "ds", ds)
140
+
141
+ # Update the grid by adding grid variables that are coarsened versions of the original grid variables
142
+ self._coarsen()
143
+
144
+ self.update_vertical_coordinate(
145
+ N=self.N, theta_s=self.theta_s, theta_b=self.theta_b, hc=self.hc
146
+ )
147
+
148
+ def update_topography_and_mask(self, hmin, topography_source="ETOPO5") -> None:
139
149
  """
140
- Add topography and mask to the grid dataset.
150
+ Update the grid dataset by adding or overwriting the topography and land/sea mask.
141
151
 
142
152
  This method processes the topography data and generates a land/sea mask.
143
153
  It applies several steps, including interpolating topography, smoothing
@@ -146,20 +156,11 @@ class Grid:
146
156
 
147
157
  Parameters
148
158
  ----------
149
- topography_source : str, optional
159
+ hmin : float
160
+ The minimum ocean depth (in meters).
161
+ topography_source : str
150
162
  Specifies the data source to use for the topography. Options are
151
- "ETOPO5". The default is "ETOPO5".
152
- smooth_factor : float, optional
153
- The smoothing factor used in the domain-wide Gaussian smoothing of the
154
- topography. Smaller values result in less smoothing, while larger
155
- values produce more smoothing. The default is 8.
156
- hmin : float, optional
157
- The minimum ocean depth (in meters). The default is 5.
158
- rmax : float, optional
159
- The maximum slope parameter (in meters). This parameter controls
160
- the local smoothing of the topography. Smaller values result in
161
- smoother topography, while larger values preserve more detail.
162
- The default is 0.2.
163
+ "ETOPO5". Default is "ETOPO5".
163
164
 
164
165
  Returns
165
166
  -------
@@ -167,77 +168,346 @@ class Grid:
167
168
  This method modifies the dataset in place and does not return a value.
168
169
  """
169
170
 
170
- ds = _add_topography_and_mask(
171
- self.ds, topography_source, smooth_factor, hmin, rmax
172
- )
171
+ ds = _add_topography_and_mask(self.ds, topography_source, hmin)
173
172
  # Assign the updated dataset back to the frozen dataclass
174
173
  object.__setattr__(self, "ds", ds)
174
+ object.__setattr__(self, "topography_source", topography_source)
175
+ object.__setattr__(self, "hmin", hmin)
175
176
 
176
- def compute_bathymetry_laplacian(self):
177
+ def _straddle(self) -> None:
177
178
  """
178
- Compute the Laplacian of the 'h' field in the provided grid dataset.
179
+ Check if the Greenwich meridian goes through the domain.
180
+
181
+ This method sets the `straddle` attribute to `True` if the Greenwich meridian
182
+ (0° longitude) intersects the domain defined by `lon_rho`. Otherwise, it sets
183
+ the `straddle` attribute to `False`.
179
184
 
180
- Adds:
181
- xarray.DataArray: The Laplacian of the 'h' field as a new variable in the dataset self.ds.
185
+ The check is based on whether the longitudinal differences between adjacent
186
+ points exceed 300 degrees, indicating a potential wraparound of longitude.
182
187
  """
183
188
 
184
- # Extract the 'h' field and grid spacing variables
185
- h = self.ds.h
186
- pm = self.ds.pm # Reciprocal of grid spacing in x-direction
187
- pn = self.ds.pn # Reciprocal of grid spacing in y-direction
189
+ if (
190
+ np.abs(self.ds.lon_rho.diff("xi_rho")).max() > 300
191
+ or np.abs(self.ds.lon_rho.diff("eta_rho")).max() > 300
192
+ ):
193
+ object.__setattr__(self, "straddle", True)
194
+ else:
195
+ object.__setattr__(self, "straddle", False)
188
196
 
189
- # Compute second derivatives using finite differences
190
- d2h_dx2 = (h.shift(xi_rho=-1) - 2 * h + h.shift(xi_rho=1)) * pm**2
191
- d2h_dy2 = (h.shift(eta_rho=-1) - 2 * h + h.shift(eta_rho=1)) * pn**2
197
+ def _coarsen(self):
198
+ """
199
+ Update the grid by adding grid variables that are coarsened versions of the original
200
+ fine-resoluion grid variables. The coarsening is by a factor of two.
192
201
 
193
- # Compute the Laplacian by summing second derivatives
194
- laplacian_h = d2h_dx2 + d2h_dy2
202
+ The specific variables being coarsened are:
203
+ - `lon_rho` -> `lon_coarse`: Longitude at rho points.
204
+ - `lat_rho` -> `lat_coarse`: Latitude at rho points.
205
+ - `angle` -> `angle_coarse`: Angle between the xi axis and true east.
206
+ - `mask_rho` -> `mask_coarse`: Land/sea mask at rho points.
195
207
 
196
- # Add the Laplacian as a new variable in the dataset
197
- self.ds["h_laplacian"] = laplacian_h
198
- self.ds["h_laplacian"].attrs["long_name"] = "Laplacian of final bathymetry"
199
- self.ds["h_laplacian"].attrs["units"] = "1/m"
208
+ Returns
209
+ -------
210
+ None
200
211
 
201
- def save(self, filepath: str) -> None:
212
+ Modifies
213
+ --------
214
+ self.ds : xr.Dataset
215
+ The dataset attribute of the Grid instance is updated with the new coarser variables.
202
216
  """
203
- Save the grid information to a netCDF4 file.
217
+ d = {
218
+ "angle": "angle_coarse",
219
+ "mask_rho": "mask_coarse",
220
+ "lat_rho": "lat_coarse",
221
+ "lon_rho": "lon_coarse",
222
+ }
223
+
224
+ for fine_var, coarse_var in d.items():
225
+ fine_field = self.ds[fine_var]
226
+ if self.straddle and fine_var == "lon_rho":
227
+ fine_field = xr.where(fine_field > 180, fine_field - 360, fine_field)
228
+
229
+ coarse_field = _f2c(fine_field)
230
+ if fine_var == "lon_rho":
231
+ coarse_field = xr.where(
232
+ coarse_field < 0, coarse_field + 360, coarse_field
233
+ )
234
+ if coarse_var in ["lon_coarse", "lat_coarse"]:
235
+ ds = self.ds.assign_coords({coarse_var: coarse_field})
236
+ object.__setattr__(self, "ds", ds)
237
+ else:
238
+ self.ds[coarse_var] = coarse_field
239
+ self.ds["mask_coarse"] = xr.where(self.ds["mask_coarse"] > 0.5, 1, 0).astype(
240
+ np.int32
241
+ )
242
+
243
+ def update_vertical_coordinate(self, N, theta_s, theta_b, hc) -> None:
244
+ """
245
+ Create vertical coordinate variables for the ROMS grid.
246
+
247
+ This method computes the S-coordinate stretching curves and depths
248
+ at various grid points (rho, u, v) using the specified parameters.
249
+ The computed depths and stretching curves are added to the dataset
250
+ as new coordinates, along with their corresponding attributes.
204
251
 
205
252
  Parameters
206
253
  ----------
207
- filepath
254
+ N : int
255
+ Number of vertical levels.
256
+ theta_s : float
257
+ S-coordinate surface control parameter.
258
+ theta_b : float
259
+ S-coordinate bottom control parameter.
260
+ hc : float
261
+ Critical depth (m) used in ROMS vertical coordinate stretching.
262
+
263
+ Returns
264
+ -------
265
+ None
266
+ This method modifies the dataset in place by adding vertical coordinate variables.
208
267
  """
209
- self.ds.to_netcdf(filepath)
210
268
 
211
- def to_yaml(self, filepath: str) -> None:
269
+ ds = self.ds
270
+ # need to drop vertical coordinates because they could cause conflict if N changed
271
+ vars_to_drop = [
272
+ "layer_depth_rho",
273
+ "layer_depth_u",
274
+ "layer_depth_v",
275
+ "interface_depth_rho",
276
+ "interface_depth_u",
277
+ "interface_depth_v",
278
+ "sc_r",
279
+ "Cs_r",
280
+ ]
281
+
282
+ for var in vars_to_drop:
283
+ if var in ds.variables:
284
+ ds = ds.drop_vars(var)
285
+
286
+ h = ds.h
287
+
288
+ cs_r, sigma_r = sigma_stretch(theta_s, theta_b, N, "r")
289
+ zr = compute_depth(h * 0, h, hc, cs_r, sigma_r)
290
+ cs_w, sigma_w = sigma_stretch(theta_s, theta_b, N, "w")
291
+ zw = compute_depth(h * 0, h, hc, cs_w, sigma_w)
292
+
293
+ ds["sc_r"] = sigma_r.astype(np.float32)
294
+ ds["sc_r"].attrs["long_name"] = "S-coordinate at rho-points"
295
+ ds["sc_r"].attrs["units"] = "nondimensional"
296
+
297
+ ds["Cs_r"] = cs_r.astype(np.float32)
298
+ ds["Cs_r"].attrs["long_name"] = "S-coordinate stretching curves at rho-points"
299
+ ds["Cs_r"].attrs["units"] = "nondimensional"
300
+
301
+ ds.attrs["theta_s"] = np.float32(theta_s)
302
+ ds.attrs["theta_b"] = np.float32(theta_b)
303
+ ds.attrs["hc"] = np.float32(hc)
304
+
305
+ depth = -zr
306
+ depth.attrs["long_name"] = "Layer depth at rho-points"
307
+ depth.attrs["units"] = "m"
308
+
309
+ depth_u = interpolate_from_rho_to_u(depth)
310
+ depth_u.attrs["long_name"] = "Layer depth at u-points"
311
+ depth_u.attrs["units"] = "m"
312
+
313
+ depth_v = interpolate_from_rho_to_v(depth)
314
+ depth_v.attrs["long_name"] = "Layer depth at v-points"
315
+ depth_v.attrs["units"] = "m"
316
+
317
+ interface_depth = -zw
318
+ interface_depth.attrs["long_name"] = "Interface depth at rho-points"
319
+ interface_depth.attrs["units"] = "m"
320
+
321
+ interface_depth_u = interpolate_from_rho_to_u(interface_depth)
322
+ interface_depth_u.attrs["long_name"] = "Interface depth at u-points"
323
+ interface_depth_u.attrs["units"] = "m"
324
+
325
+ interface_depth_v = interpolate_from_rho_to_v(interface_depth)
326
+ interface_depth_v.attrs["long_name"] = "Interface depth at v-points"
327
+ interface_depth_v.attrs["units"] = "m"
328
+
329
+ ds = ds.assign_coords(
330
+ {
331
+ "layer_depth_rho": depth.astype(np.float32),
332
+ "layer_depth_u": depth_u.astype(np.float32),
333
+ "layer_depth_v": depth_v.astype(np.float32),
334
+ "interface_depth_rho": interface_depth.astype(np.float32),
335
+ "interface_depth_u": interface_depth_u.astype(np.float32),
336
+ "interface_depth_v": interface_depth_v.astype(np.float32),
337
+ }
338
+ )
339
+ ds = ds.drop_vars(["eta_rho", "xi_rho"])
340
+
341
+ object.__setattr__(self, "ds", ds)
342
+ object.__setattr__(self, "theta_s", theta_s)
343
+ object.__setattr__(self, "theta_b", theta_b)
344
+ object.__setattr__(self, "hc", hc)
345
+ object.__setattr__(self, "N", N)
346
+
347
+ def plot(self, bathymetry: bool = False) -> None:
212
348
  """
213
- Export the parameters of the class to a YAML file, including the version of roms-tools.
349
+ Plot the grid.
214
350
 
215
351
  Parameters
216
352
  ----------
217
- filepath : str
218
- The path to the YAML file where the parameters will be saved.
353
+ bathymetry : bool
354
+ Whether or not to plot the bathymetry. Default is False.
355
+
356
+ Returns
357
+ -------
358
+ None
359
+ This method does not return any value. It generates and displays a plot.
360
+
219
361
  """
220
- data = asdict(self)
221
- data.pop("ds", None)
222
- data.pop("straddle", None)
223
362
 
224
- # Include the version of roms-tools
225
- try:
226
- roms_tools_version = importlib.metadata.version("roms-tools")
227
- except importlib.metadata.PackageNotFoundError:
228
- roms_tools_version = "unknown"
363
+ if bathymetry:
364
+ kwargs = {"cmap": "YlGnBu"}
229
365
 
230
- # Create header
231
- header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
366
+ _plot(
367
+ self.ds,
368
+ field=self.ds.h.where(self.ds.mask_rho),
369
+ straddle=self.straddle,
370
+ kwargs=kwargs,
371
+ )
372
+ else:
373
+ _plot(self.ds, straddle=self.straddle)
232
374
 
233
- # Use the class name as the top-level key
234
- yaml_data = {self.__class__.__name__: data}
375
+ def plot_vertical_coordinate(
376
+ self,
377
+ varname="layer_depth_rho",
378
+ s=None,
379
+ eta=None,
380
+ xi=None,
381
+ ) -> None:
382
+ """
383
+ Plot the vertical coordinate system for a given eta-, xi-, or s-slice.
235
384
 
236
- with open(filepath, "w") as file:
237
- # Write header
238
- file.write(header)
239
- # Write YAML data
240
- yaml.dump(yaml_data, file, default_flow_style=False)
385
+ Parameters
386
+ ----------
387
+ varname : str, optional
388
+ The vertical coordinate field to plot. Options include:
389
+ - "layer_depth_rho": Layer depth at rho-points.
390
+ - "layer_depth_u": Layer depth at u-points.
391
+ - "layer_depth_v": Layer depth at v-points.
392
+ - "interface_depth_rho": Interface depth at rho-points.
393
+ - "interface_depth_u": Interface depth at u-points.
394
+ - "interface_depth_v": Interface depth at v-points.
395
+ s: int, optional
396
+ The s-index to plot. Default is None.
397
+ eta : int, optional
398
+ The eta-index to plot. Default is None.
399
+ xi : int, optional
400
+ The xi-index to plot. Default is None.
401
+
402
+ Returns
403
+ -------
404
+ None
405
+ This method does not return any value. It generates and displays a plot.
406
+
407
+ Raises
408
+ ------
409
+ ValueError
410
+ If the specified varname is not one of the valid options.
411
+ If none of s, eta, xi are specified.
412
+ """
413
+
414
+ if not any([s is not None, eta is not None, xi is not None]):
415
+ raise ValueError("At least one of s, eta, or xi must be specified.")
416
+
417
+ self.ds[varname].load()
418
+ field = self.ds[varname].squeeze()
419
+
420
+ if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
421
+ interface_depth = self.ds.interface_depth_rho
422
+ elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
423
+ interface_depth = self.ds.interface_depth_u
424
+ elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
425
+ interface_depth = self.ds.interface_depth_v
426
+
427
+ # slice the field as desired
428
+ title = field.long_name
429
+ if s is not None:
430
+ if "s_rho" in field.dims:
431
+ title = title + f", s_rho = {field.s_rho[s].item()}"
432
+ field = field.isel(s_rho=s)
433
+ elif "s_w" in field.dims:
434
+ title = title + f", s_w = {field.s_w[s].item()}"
435
+ field = field.isel(s_w=s)
436
+ else:
437
+ raise ValueError(
438
+ f"None of the expected dimensions (s_rho, s_w) found in ds[{varname}]."
439
+ )
440
+
441
+ if eta is not None:
442
+ if "eta_rho" in field.dims:
443
+ title = title + f", eta_rho = {field.eta_rho[eta].item()}"
444
+ field = field.isel(eta_rho=eta)
445
+ interface_depth = interface_depth.isel(eta_rho=eta)
446
+ elif "eta_v" in field.dims:
447
+ title = title + f", eta_v = {field.eta_v[eta].item()}"
448
+ field = field.isel(eta_v=eta)
449
+ interface_depth = interface_depth.isel(eta_v=eta)
450
+ else:
451
+ raise ValueError(
452
+ f"None of the expected dimensions (eta_rho, eta_v) found in ds[{varname}]."
453
+ )
454
+ if xi is not None:
455
+ if "xi_rho" in field.dims:
456
+ title = title + f", xi_rho = {field.xi_rho[xi].item()}"
457
+ field = field.isel(xi_rho=xi)
458
+ interface_depth = interface_depth.isel(xi_rho=xi)
459
+ elif "xi_u" in field.dims:
460
+ title = title + f", xi_u = {field.xi_u[xi].item()}"
461
+ field = field.isel(xi_u=xi)
462
+ interface_depth = interface_depth.isel(xi_u=xi)
463
+ else:
464
+ raise ValueError(
465
+ f"None of the expected dimensions (xi_rho, xi_u) found in ds[{varname}]."
466
+ )
467
+
468
+ if eta is None and xi is None:
469
+ vmax = field.max().values
470
+ vmin = field.min().values
471
+ cmap = plt.colormaps.get_cmap("YlGnBu")
472
+ cmap.set_bad(color="gray")
473
+ kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
474
+
475
+ _plot(
476
+ self.ds,
477
+ field=field,
478
+ straddle=self.straddle,
479
+ depth_contours=True,
480
+ title=title,
481
+ kwargs=kwargs,
482
+ c="g",
483
+ )
484
+ else:
485
+ if len(field.dims) == 2:
486
+ cmap = plt.colormaps.get_cmap("YlGnBu")
487
+ cmap.set_bad(color="gray")
488
+ kwargs = {"vmax": 0.0, "vmin": 0.0, "cmap": cmap, "add_colorbar": False}
489
+
490
+ _section_plot(
491
+ xr.zeros_like(field),
492
+ interface_depth=interface_depth,
493
+ title=title,
494
+ kwargs=kwargs,
495
+ )
496
+ else:
497
+ if "s_rho" in field.dims or "s_w" in field.dims:
498
+ _profile_plot(field, title=title)
499
+ else:
500
+ _line_plot(field, title=title)
501
+
502
+ def save(self, filepath: str) -> None:
503
+ """
504
+ Save the grid information to a netCDF4 file.
505
+
506
+ Parameters
507
+ ----------
508
+ filepath
509
+ """
510
+ self.ds.to_netcdf(filepath)
241
511
 
242
512
  @classmethod
243
513
  def from_file(cls, filepath: str) -> "Grid":
@@ -259,8 +529,6 @@ class Grid:
259
529
 
260
530
  if not all(mask in ds for mask in ["mask_u", "mask_v"]):
261
531
  ds = _add_velocity_masks(ds)
262
- if not all(coord in ds for coord in ["lat_u", "lon_u", "lat_v", "lon_v"]):
263
- ds = _add_lat_lon_at_velocity_points(ds)
264
532
 
265
533
  # Create a new Grid instance without calling __init__ and __post_init__
266
534
  grid = cls.__new__(cls)
@@ -271,26 +539,82 @@ class Grid:
271
539
  # Check if the Greenwich meridian goes through the domain.
272
540
  grid._straddle()
273
541
 
542
+ if not all(coord in grid.ds for coord in ["lat_u", "lon_u", "lat_v", "lon_v"]):
543
+ ds = _add_lat_lon_at_velocity_points(grid.ds, grid.straddle)
544
+ object.__setattr__(grid, "ds", ds)
545
+
546
+ # Coarsen the grid if necessary
547
+ if not all(
548
+ var in grid.ds
549
+ for var in [
550
+ "lon_coarse",
551
+ "lat_coarse",
552
+ "angle_coarse",
553
+ "mask_coarse",
554
+ ]
555
+ ):
556
+ grid._coarsen()
557
+
558
+ # Update vertical coordinate if necessary
559
+ if not all(var in grid.ds for var in ["sc_r", "Cs_r"]):
560
+ N = 100
561
+ theta_s = 5.0
562
+ theta_b = 2.0
563
+ hc = 300.0
564
+
565
+ grid.update_vertical_coordinate(
566
+ N=N, theta_s=theta_s, theta_b=theta_b, hc=hc
567
+ )
568
+
274
569
  # Manually set the remaining attributes by extracting parameters from dataset
275
570
  object.__setattr__(grid, "nx", ds.sizes["xi_rho"] - 2)
276
571
  object.__setattr__(grid, "ny", ds.sizes["eta_rho"] - 2)
277
- object.__setattr__(grid, "center_lon", ds["tra_lon"].values.item())
278
- object.__setattr__(grid, "center_lat", ds["tra_lat"].values.item())
279
- object.__setattr__(grid, "rot", ds["rotate"].values.item())
572
+ object.__setattr__(grid, "center_lon", ds.attrs["center_lon"])
573
+ object.__setattr__(grid, "center_lat", ds.attrs["center_lat"])
574
+ object.__setattr__(grid, "rot", ds.attrs["rot"])
280
575
 
281
576
  for attr in [
282
577
  "size_x",
283
578
  "size_y",
284
579
  "topography_source",
285
- "smooth_factor",
286
580
  "hmin",
287
- "rmax",
288
581
  ]:
289
582
  if attr in ds.attrs:
290
583
  object.__setattr__(grid, attr, ds.attrs[attr])
291
584
 
292
585
  return grid
293
586
 
587
+ def to_yaml(self, filepath: str) -> None:
588
+ """
589
+ Export the parameters of the class to a YAML file, including the version of roms-tools.
590
+
591
+ Parameters
592
+ ----------
593
+ filepath : str
594
+ The path to the YAML file where the parameters will be saved.
595
+ """
596
+ data = asdict(self)
597
+ data.pop("ds", None)
598
+ data.pop("straddle", None)
599
+
600
+ # Include the version of roms-tools
601
+ try:
602
+ roms_tools_version = importlib.metadata.version("roms-tools")
603
+ except importlib.metadata.PackageNotFoundError:
604
+ roms_tools_version = "unknown"
605
+
606
+ # Create header
607
+ header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
608
+
609
+ # Use the class name as the top-level key
610
+ yaml_data = {self.__class__.__name__: data}
611
+
612
+ with open(filepath, "w") as file:
613
+ # Write header
614
+ file.write(header)
615
+ # Write YAML data
616
+ yaml.dump(yaml_data, file, default_flow_style=False)
617
+
294
618
  @classmethod
295
619
  def from_yaml(cls, filepath: str) -> "Grid":
296
620
  """
@@ -358,102 +682,6 @@ class Grid:
358
682
  attr_str = ", ".join(f"{k}={v!r}" for k, v in attr_dict.items())
359
683
  return f"{cls_name}({attr_str})"
360
684
 
361
- # def to_xgcm() -> Any:
362
- # # TODO we could convert the dataset to an xgcm.Grid object and return here?
363
- # raise NotImplementedError()
364
-
365
- def _straddle(self) -> None:
366
- """
367
- Check if the Greenwich meridian goes through the domain.
368
-
369
- This method sets the `straddle` attribute to `True` if the Greenwich meridian
370
- (0° longitude) intersects the domain defined by `lon_rho`. Otherwise, it sets
371
- the `straddle` attribute to `False`.
372
-
373
- The check is based on whether the longitudinal differences between adjacent
374
- points exceed 300 degrees, indicating a potential wraparound of longitude.
375
- """
376
-
377
- if (
378
- np.abs(self.ds.lon_rho.diff("xi_rho")).max() > 300
379
- or np.abs(self.ds.lon_rho.diff("eta_rho")).max() > 300
380
- ):
381
- object.__setattr__(self, "straddle", True)
382
- else:
383
- object.__setattr__(self, "straddle", False)
384
-
385
- def plot(self, bathymetry: bool = False) -> None:
386
- """
387
- Plot the grid.
388
-
389
- Parameters
390
- ----------
391
- bathymetry : bool
392
- Whether or not to plot the bathymetry. Default is False.
393
-
394
- Returns
395
- -------
396
- None
397
- This method does not return any value. It generates and displays a plot.
398
-
399
- """
400
-
401
- if bathymetry:
402
- kwargs = {"cmap": "YlGnBu"}
403
-
404
- _plot(
405
- self.ds,
406
- field=self.ds.h.where(self.ds.mask_rho),
407
- straddle=self.straddle,
408
- kwargs=kwargs,
409
- )
410
- else:
411
- _plot(self.ds, straddle=self.straddle)
412
-
413
- def coarsen(self):
414
- """
415
- Update the grid by adding grid variables that are coarsened versions of the original
416
- fine-resoluion grid variables. The coarsening is by a factor of two.
417
-
418
- The specific variables being coarsened are:
419
- - `lon_rho` -> `lon_coarse`: Longitude at rho points.
420
- - `lat_rho` -> `lat_coarse`: Latitude at rho points.
421
- - `h` -> `h_coarse`: Bathymetry (depth).
422
- - `angle` -> `angle_coarse`: Angle between the xi axis and true east.
423
- - `mask_rho` -> `mask_coarse`: Land/sea mask at rho points.
424
-
425
- Returns
426
- -------
427
- None
428
-
429
- Modifies
430
- --------
431
- self.ds : xr.Dataset
432
- The dataset attribute of the Grid instance is updated with the new coarser variables.
433
- """
434
- d = {
435
- "lon_rho": "lon_coarse",
436
- "lat_rho": "lat_coarse",
437
- "h": "h_coarse",
438
- "angle": "angle_coarse",
439
- "mask_rho": "mask_coarse",
440
- }
441
-
442
- for fine_var, coarse_var in d.items():
443
- fine_field = self.ds[fine_var]
444
- if self.straddle and fine_var == "lon_rho":
445
- fine_field = xr.where(fine_field > 180, fine_field - 360, fine_field)
446
-
447
- coarse_field = _f2c(fine_field)
448
- if fine_var == "lon_rho":
449
- coarse_field = xr.where(
450
- coarse_field < 0, coarse_field + 360, coarse_field
451
- )
452
-
453
- self.ds[coarse_var] = coarse_field
454
-
455
- self.ds["mask_coarse"] = xr.where(self.ds["mask_coarse"] > 0.5, 1, 0)
456
-
457
685
 
458
686
  def _make_grid_ds(
459
687
  nx: int,
@@ -484,7 +712,7 @@ def _make_grid_ds(
484
712
 
485
713
  ds = _create_grid_ds(lon, lat, pm, pn, ang, rot, center_lon, center_lat)
486
714
 
487
- ds = _add_global_metadata(ds, size_x, size_y)
715
+ ds = _add_global_metadata(ds, size_x, size_y, center_lon, center_lat, rot)
488
716
 
489
717
  return ds
490
718
 
@@ -777,14 +1005,21 @@ def _create_grid_ds(
777
1005
  center_lon,
778
1006
  center_lat,
779
1007
  ):
780
- # Create xarray.Dataset object with lat_rho and lon_rho as coordinates
781
- ds = xr.Dataset(
782
- coords={
783
- "lat_rho": (("eta_rho", "xi_rho"), lat * 180 / np.pi),
784
- "lon_rho": (("eta_rho", "xi_rho"), lon * 180 / np.pi),
785
- }
1008
+ ds = xr.Dataset()
1009
+
1010
+ lon_rho = xr.Variable(
1011
+ data=lon * 180 / np.pi,
1012
+ dims=["eta_rho", "xi_rho"],
1013
+ attrs={"long_name": "longitude of rho-points", "units": "degrees East"},
1014
+ )
1015
+ lat_rho = xr.Variable(
1016
+ data=lat * 180 / np.pi,
1017
+ dims=["eta_rho", "xi_rho"],
1018
+ attrs={"long_name": "latitude of rho-points", "units": "degrees North"},
786
1019
  )
787
1020
 
1021
+ ds = ds.assign_coords({"lat_rho": lat_rho, "lon_rho": lon_rho})
1022
+
788
1023
  ds["angle"] = xr.Variable(
789
1024
  data=angle,
790
1025
  dims=["eta_rho", "xi_rho"],
@@ -799,6 +1034,7 @@ def _create_grid_ds(
799
1034
  dims=["eta_rho", "xi_rho"],
800
1035
  attrs={"long_name": "Coriolis parameter at rho-points", "units": "second-1"},
801
1036
  )
1037
+
802
1038
  ds["pm"] = xr.Variable(
803
1039
  data=pm,
804
1040
  dims=["eta_rho", "xi_rho"],
@@ -816,36 +1052,15 @@ def _create_grid_ds(
816
1052
  },
817
1053
  )
818
1054
 
819
- ds["tra_lon"] = center_lon
820
- ds["tra_lon"].attrs["long_name"] = "Longitudinal translation of base grid"
821
- ds["tra_lon"].attrs["units"] = "degrees East"
822
-
823
- ds["tra_lat"] = center_lat
824
- ds["tra_lat"].attrs["long_name"] = "Latitudinal translation of base grid"
825
- ds["tra_lat"].attrs["units"] = "degrees North"
826
-
827
- ds["rotate"] = rot
828
- ds["rotate"].attrs["long_name"] = "Rotation of base grid"
829
- ds["rotate"].attrs["units"] = "degrees"
830
-
831
- ds["lon_rho"] = xr.Variable(
832
- data=lon * 180 / np.pi,
833
- dims=["eta_rho", "xi_rho"],
834
- attrs={"long_name": "longitude of rho-points", "units": "degrees East"},
835
- )
836
-
837
- ds["lat_rho"] = xr.Variable(
838
- data=lat * 180 / np.pi,
839
- dims=["eta_rho", "xi_rho"],
840
- attrs={"long_name": "latitude of rho-points", "units": "degrees North"},
841
- )
1055
+ return ds
842
1056
 
843
- ds = _add_lat_lon_at_velocity_points(ds)
844
1057
 
845
- return ds
1058
+ def _add_global_metadata(ds, size_x, size_y, center_lon, center_lat, rot):
846
1059
 
1060
+ ds["spherical"] = xr.DataArray(np.array("T", dtype="S1"))
1061
+ ds["spherical"].attrs["Long_name"] = "Grid type logical switch"
1062
+ ds["spherical"].attrs["option_T"] = "spherical"
847
1063
 
848
- def _add_global_metadata(ds, size_x, size_y):
849
1064
  ds.attrs["title"] = "ROMS grid created by ROMS-Tools"
850
1065
 
851
1066
  # Include the version of roms-tools
@@ -857,6 +1072,9 @@ def _add_global_metadata(ds, size_x, size_y):
857
1072
  ds.attrs["roms_tools_version"] = roms_tools_version
858
1073
  ds.attrs["size_x"] = size_x
859
1074
  ds.attrs["size_y"] = size_y
1075
+ ds.attrs["center_lon"] = center_lon
1076
+ ds.attrs["center_lat"] = center_lat
1077
+ ds.attrs["rot"] = rot
860
1078
 
861
1079
  return ds
862
1080
 
@@ -916,12 +1134,24 @@ def _f2c_xdir(f):
916
1134
  return fc
917
1135
 
918
1136
 
919
- def _add_lat_lon_at_velocity_points(ds):
1137
+ def _add_lat_lon_at_velocity_points(ds, straddle):
920
1138
 
921
- lat_u = interpolate_from_rho_to_u(ds["lat_rho"])
922
- lon_u = interpolate_from_rho_to_u(ds["lon_rho"])
923
- lat_v = interpolate_from_rho_to_v(ds["lat_rho"])
924
- lon_v = interpolate_from_rho_to_v(ds["lon_rho"])
1139
+ if straddle:
1140
+ # avoid jump from 360 to 0 in interpolation
1141
+ lon_rho = xr.where(ds["lon_rho"] > 180, ds["lon_rho"] - 360, ds["lon_rho"])
1142
+ else:
1143
+ lon_rho = ds["lon_rho"]
1144
+ lat_rho = ds["lat_rho"]
1145
+
1146
+ lat_u = interpolate_from_rho_to_u(lat_rho)
1147
+ lon_u = interpolate_from_rho_to_u(lon_rho)
1148
+ lat_v = interpolate_from_rho_to_v(lat_rho)
1149
+ lon_v = interpolate_from_rho_to_v(lon_rho)
1150
+
1151
+ if straddle:
1152
+ # convert back to range [0, 360]
1153
+ lon_u = xr.where(lon_u < 0, lon_u + 360, lon_u)
1154
+ lon_v = xr.where(lon_v < 0, lon_v + 360, lon_v)
925
1155
 
926
1156
  lat_u.attrs = {"long_name": "latitude of u-points", "units": "degrees North"}
927
1157
  lon_u.attrs = {"long_name": "longitude of u-points", "units": "degrees East"}
@@ -929,7 +1159,12 @@ def _add_lat_lon_at_velocity_points(ds):
929
1159
  lon_v.attrs = {"long_name": "longitude of v-points", "units": "degrees East"}
930
1160
 
931
1161
  ds = ds.assign_coords(
932
- {"lat_u": lat_u, "lon_u": lon_u, "lat_v": lat_v, "lon_v": lon_v}
1162
+ {
1163
+ "lat_u": lat_u,
1164
+ "lon_u": lon_u,
1165
+ "lat_v": lat_v,
1166
+ "lon_v": lon_v,
1167
+ }
933
1168
  )
934
1169
 
935
1170
  return ds