roms-tools 1.0.1__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
roms_tools/setup/grid.py CHANGED
@@ -3,21 +3,20 @@ from dataclasses import dataclass, field, asdict
3
3
 
4
4
  import numpy as np
5
5
  import xarray as xr
6
+ import matplotlib.pyplot as plt
6
7
  import yaml
7
8
  import importlib.metadata
8
9
 
9
10
  from roms_tools.setup.topography import _add_topography_and_mask, _add_velocity_masks
10
- from roms_tools.setup.plot import _plot
11
+ from roms_tools.setup.plot import _plot, _section_plot, _profile_plot, _line_plot
11
12
  from roms_tools.setup.utils import interpolate_from_rho_to_u, interpolate_from_rho_to_v
13
+ from roms_tools.setup.vertical_coordinate import sigma_stretch, compute_depth
12
14
 
13
15
  import warnings
14
16
 
15
17
  RADIUS_OF_EARTH = 6371315.0 # in m
16
18
 
17
19
 
18
- # TODO should we store an xgcm.Grid object instead of an xarray Dataset? Or even subclass xgcm.Grid?
19
-
20
-
21
20
  @dataclass(frozen=True, kw_only=True)
22
21
  class Grid:
23
22
  """
@@ -43,20 +42,19 @@ class Grid:
43
42
  Rotation of grid x-direction from lines of constant latitude, measured in degrees.
44
43
  Positive values represent a counterclockwise rotation.
45
44
  The default is 0, which means that the x-direction of the grid is aligned with lines of constant latitude.
45
+ N : int, optional
46
+ The number of vertical levels. The default is 100.
47
+ theta_s : float, optional
48
+ The surface control parameter. Must satisfy 0 < theta_s <= 10. The default is 5.0.
49
+ theta_b : float, optional
50
+ The bottom control parameter. Must satisfy 0 < theta_b <= 4. The default is 2.0.
51
+ hc : float, optional
52
+ The critical depth (in meters). The default is 300.0.
46
53
  topography_source : str, optional
47
54
  Specifies the data source to use for the topography. Options are
48
55
  "ETOPO5". The default is "ETOPO5".
49
- smooth_factor : float, optional
50
- The smoothing factor used in the domain-wide Gaussian smoothing of the
51
- topography. Smaller values result in less smoothing, while larger
52
- values produce more smoothing. The default is 8.
53
56
  hmin : float, optional
54
- The minimum ocean depth (in meters). The default is 5.
55
- rmax : float, optional
56
- The maximum slope parameter (in meters). This parameter controls
57
- the local smoothing of the topography. Smaller values result in
58
- smoother topography, while larger values preserve more detail.
59
- The default is 0.2.
57
+ The minimum ocean depth (in meters). The default is 5.0.
60
58
 
61
59
  Attributes
62
60
  ----------
@@ -74,14 +72,18 @@ class Grid:
74
72
  Latitude of grid center.
75
73
  rot : float
76
74
  Rotation of grid x-direction from lines of constant latitude.
75
+ N : int
76
+ The number of vertical levels.
77
+ theta_s : float
78
+ The surface control parameter.
79
+ theta_b : float
80
+ The bottom control parameter.
81
+ hc : float
82
+ The critical depth (in meters).
77
83
  topography_source : str
78
84
  Data source used for the topography.
79
- smooth_factor : int
80
- Smoothing factor used in the domain-wide Gaussian smoothing of the topography.
81
85
  hmin : float
82
86
  Minimum ocean depth (in meters).
83
- rmax : float
84
- Maximum slope parameter (in meters).
85
87
  ds : xr.Dataset
86
88
  The xarray Dataset containing the grid data.
87
89
  straddle : bool
@@ -101,10 +103,12 @@ class Grid:
101
103
  center_lon: float
102
104
  center_lat: float
103
105
  rot: float = 0
106
+ N: int = 100
107
+ theta_s: float = 5.0
108
+ theta_b: float = 2.0
109
+ hc: float = 300.0
104
110
  topography_source: str = "ETOPO5"
105
- smooth_factor: int = 8
106
111
  hmin: float = 5.0
107
- rmax: float = 0.2
108
112
  ds: xr.Dataset = field(init=False, repr=False)
109
113
  straddle: bool = field(init=False, repr=False)
110
114
 
@@ -123,21 +127,27 @@ class Grid:
123
127
  object.__setattr__(self, "ds", ds)
124
128
 
125
129
  # Update self.ds with topography and mask information
126
- self.add_topography_and_mask(
130
+ self.update_topography_and_mask(
127
131
  topography_source=self.topography_source,
128
- smooth_factor=self.smooth_factor,
129
132
  hmin=self.hmin,
130
- rmax=self.rmax,
131
133
  )
132
134
 
133
135
  # Check if the Greenwich meridian goes through the domain.
134
136
  self._straddle()
135
137
 
136
- def add_topography_and_mask(
137
- self, topography_source="ETOPO5", smooth_factor=8, hmin=5.0, rmax=0.2
138
- ) -> None:
138
+ ds = _add_lat_lon_at_velocity_points(self.ds, self.straddle)
139
+ object.__setattr__(self, "ds", ds)
140
+
141
+ # Update the grid by adding grid variables that are coarsened versions of the original grid variables
142
+ self._coarsen()
143
+
144
+ self.update_vertical_coordinate(
145
+ N=self.N, theta_s=self.theta_s, theta_b=self.theta_b, hc=self.hc
146
+ )
147
+
148
+ def update_topography_and_mask(self, hmin, topography_source="ETOPO5") -> None:
139
149
  """
140
- Add topography and mask to the grid dataset.
150
+ Update the grid dataset by adding or overwriting the topography and land/sea mask.
141
151
 
142
152
  This method processes the topography data and generates a land/sea mask.
143
153
  It applies several steps, including interpolating topography, smoothing
@@ -146,20 +156,11 @@ class Grid:
146
156
 
147
157
  Parameters
148
158
  ----------
149
- topography_source : str, optional
159
+ hmin : float
160
+ The minimum ocean depth (in meters).
161
+ topography_source : str
150
162
  Specifies the data source to use for the topography. Options are
151
- "ETOPO5". The default is "ETOPO5".
152
- smooth_factor : float, optional
153
- The smoothing factor used in the domain-wide Gaussian smoothing of the
154
- topography. Smaller values result in less smoothing, while larger
155
- values produce more smoothing. The default is 8.
156
- hmin : float, optional
157
- The minimum ocean depth (in meters). The default is 5.
158
- rmax : float, optional
159
- The maximum slope parameter (in meters). This parameter controls
160
- the local smoothing of the topography. Smaller values result in
161
- smoother topography, while larger values preserve more detail.
162
- The default is 0.2.
163
+ "ETOPO5". Default is "ETOPO5".
163
164
 
164
165
  Returns
165
166
  -------
@@ -167,77 +168,362 @@ class Grid:
167
168
  This method modifies the dataset in place and does not return a value.
168
169
  """
169
170
 
170
- ds = _add_topography_and_mask(
171
- self.ds, topography_source, smooth_factor, hmin, rmax
172
- )
171
+ ds = _add_topography_and_mask(self.ds, topography_source, hmin)
173
172
  # Assign the updated dataset back to the frozen dataclass
174
173
  object.__setattr__(self, "ds", ds)
174
+ object.__setattr__(self, "topography_source", topography_source)
175
+ object.__setattr__(self, "hmin", hmin)
175
176
 
176
- def compute_bathymetry_laplacian(self):
177
+ def _straddle(self) -> None:
177
178
  """
178
- Compute the Laplacian of the 'h' field in the provided grid dataset.
179
+ Check if the Greenwich meridian goes through the domain.
180
+
181
+ This method sets the `straddle` attribute to `True` if the Greenwich meridian
182
+ (0° longitude) intersects the domain defined by `lon_rho`. Otherwise, it sets
183
+ the `straddle` attribute to `False`.
179
184
 
180
- Adds:
181
- xarray.DataArray: The Laplacian of the 'h' field as a new variable in the dataset self.ds.
185
+ The check is based on whether the longitudinal differences between adjacent
186
+ points exceed 300 degrees, indicating a potential wraparound of longitude.
182
187
  """
183
188
 
184
- # Extract the 'h' field and grid spacing variables
185
- h = self.ds.h
186
- pm = self.ds.pm # Reciprocal of grid spacing in x-direction
187
- pn = self.ds.pn # Reciprocal of grid spacing in y-direction
189
+ if (
190
+ np.abs(self.ds.lon_rho.diff("xi_rho")).max() > 300
191
+ or np.abs(self.ds.lon_rho.diff("eta_rho")).max() > 300
192
+ ):
193
+ object.__setattr__(self, "straddle", True)
194
+ else:
195
+ object.__setattr__(self, "straddle", False)
188
196
 
189
- # Compute second derivatives using finite differences
190
- d2h_dx2 = (h.shift(xi_rho=-1) - 2 * h + h.shift(xi_rho=1)) * pm**2
191
- d2h_dy2 = (h.shift(eta_rho=-1) - 2 * h + h.shift(eta_rho=1)) * pn**2
197
+ def _coarsen(self):
198
+ """
199
+ Update the grid by adding grid variables that are coarsened versions of the original
200
+ fine-resolution grid variables. The coarsening is by a factor of two.
192
201
 
193
- # Compute the Laplacian by summing second derivatives
194
- laplacian_h = d2h_dx2 + d2h_dy2
202
+ The specific variables being coarsened are:
203
+ - `lon_rho` -> `lon_coarse`: Longitude at rho points.
204
+ - `lat_rho` -> `lat_coarse`: Latitude at rho points.
205
+ - `angle` -> `angle_coarse`: Angle between the xi axis and true east.
206
+ - `mask_rho` -> `mask_coarse`: Land/sea mask at rho points.
195
207
 
196
- # Add the Laplacian as a new variable in the dataset
197
- self.ds["h_laplacian"] = laplacian_h
198
- self.ds["h_laplacian"].attrs["long_name"] = "Laplacian of final bathymetry"
199
- self.ds["h_laplacian"].attrs["units"] = "1/m"
208
+ Returns
209
+ -------
210
+ None
200
211
 
201
- def save(self, filepath: str) -> None:
212
+ Modifies
213
+ --------
214
+ self.ds : xr.Dataset
215
+ The dataset attribute of the Grid instance is updated with the new coarser variables.
202
216
  """
203
- Save the grid information to a netCDF4 file.
217
+ d = {
218
+ "angle": "angle_coarse",
219
+ "mask_rho": "mask_coarse",
220
+ "lat_rho": "lat_coarse",
221
+ "lon_rho": "lon_coarse",
222
+ }
223
+
224
+ for fine_var, coarse_var in d.items():
225
+ fine_field = self.ds[fine_var]
226
+ if self.straddle and fine_var == "lon_rho":
227
+ fine_field = xr.where(fine_field > 180, fine_field - 360, fine_field)
228
+
229
+ coarse_field = _f2c(fine_field)
230
+ if fine_var == "lon_rho":
231
+ coarse_field = xr.where(
232
+ coarse_field < 0, coarse_field + 360, coarse_field
233
+ )
234
+ if coarse_var in ["lon_coarse", "lat_coarse"]:
235
+ ds = self.ds.assign_coords({coarse_var: coarse_field})
236
+ object.__setattr__(self, "ds", ds)
237
+ else:
238
+ self.ds[coarse_var] = coarse_field
239
+ self.ds["mask_coarse"] = xr.where(self.ds["mask_coarse"] > 0.5, 1, 0).astype(
240
+ np.int32
241
+ )
242
+
243
+ def update_vertical_coordinate(self, N, theta_s, theta_b, hc) -> None:
244
+ """
245
+ Create vertical coordinate variables for the ROMS grid.
246
+
247
+ This method computes the S-coordinate stretching curves and depths
248
+ at various grid points (rho, u, v) using the specified parameters.
249
+ The computed depths and stretching curves are added to the dataset
250
+ as new coordinates, along with their corresponding attributes.
204
251
 
205
252
  Parameters
206
253
  ----------
207
- filepath
254
+ N : int
255
+ Number of vertical levels.
256
+ theta_s : float
257
+ S-coordinate surface control parameter.
258
+ theta_b : float
259
+ S-coordinate bottom control parameter.
260
+ hc : float
261
+ Critical depth (m) used in ROMS vertical coordinate stretching.
262
+
263
+ Returns
264
+ -------
265
+ None
266
+ This method modifies the dataset in place by adding vertical coordinate variables.
208
267
  """
209
- self.ds.to_netcdf(filepath)
210
268
 
211
- def to_yaml(self, filepath: str) -> None:
269
+ ds = self.ds
270
+ # need to drop vertical coordinates because they could cause conflict if N changed
271
+ vars_to_drop = [
272
+ "layer_depth_rho",
273
+ "layer_depth_u",
274
+ "layer_depth_v",
275
+ "interface_depth_rho",
276
+ "interface_depth_u",
277
+ "interface_depth_v",
278
+ "sc_r",
279
+ "Cs_r",
280
+ ]
281
+
282
+ for var in vars_to_drop:
283
+ if var in ds.variables:
284
+ ds = ds.drop_vars(var)
285
+
286
+ h = ds.h
287
+
288
+ cs_r, sigma_r = sigma_stretch(theta_s, theta_b, N, "r")
289
+ zr = compute_depth(h * 0, h, hc, cs_r, sigma_r)
290
+ cs_w, sigma_w = sigma_stretch(theta_s, theta_b, N, "w")
291
+ zw = compute_depth(h * 0, h, hc, cs_w, sigma_w)
292
+
293
+ ds["sc_r"] = sigma_r.astype(np.float32)
294
+ ds["sc_r"].attrs["long_name"] = "S-coordinate at rho-points"
295
+ ds["sc_r"].attrs["units"] = "nondimensional"
296
+
297
+ ds["Cs_r"] = cs_r.astype(np.float32)
298
+ ds["Cs_r"].attrs["long_name"] = "S-coordinate stretching curves at rho-points"
299
+ ds["Cs_r"].attrs["units"] = "nondimensional"
300
+
301
+ ds.attrs["theta_s"] = np.float32(theta_s)
302
+ ds.attrs["theta_b"] = np.float32(theta_b)
303
+ ds.attrs["hc"] = np.float32(hc)
304
+
305
+ depth = -zr
306
+ depth.attrs["long_name"] = "Layer depth at rho-points"
307
+ depth.attrs["units"] = "m"
308
+
309
+ depth_u = interpolate_from_rho_to_u(depth)
310
+ depth_u.attrs["long_name"] = "Layer depth at u-points"
311
+ depth_u.attrs["units"] = "m"
312
+
313
+ depth_v = interpolate_from_rho_to_v(depth)
314
+ depth_v.attrs["long_name"] = "Layer depth at v-points"
315
+ depth_v.attrs["units"] = "m"
316
+
317
+ interface_depth = -zw
318
+ interface_depth.attrs["long_name"] = "Interface depth at rho-points"
319
+ interface_depth.attrs["units"] = "m"
320
+
321
+ interface_depth_u = interpolate_from_rho_to_u(interface_depth)
322
+ interface_depth_u.attrs["long_name"] = "Interface depth at u-points"
323
+ interface_depth_u.attrs["units"] = "m"
324
+
325
+ interface_depth_v = interpolate_from_rho_to_v(interface_depth)
326
+ interface_depth_v.attrs["long_name"] = "Interface depth at v-points"
327
+ interface_depth_v.attrs["units"] = "m"
328
+
329
+ ds = ds.assign_coords(
330
+ {
331
+ "layer_depth_rho": depth.astype(np.float32),
332
+ "layer_depth_u": depth_u.astype(np.float32),
333
+ "layer_depth_v": depth_v.astype(np.float32),
334
+ "interface_depth_rho": interface_depth.astype(np.float32),
335
+ "interface_depth_u": interface_depth_u.astype(np.float32),
336
+ "interface_depth_v": interface_depth_v.astype(np.float32),
337
+ }
338
+ )
339
+ ds = ds.drop_vars(["eta_rho", "xi_rho"])
340
+
341
+ object.__setattr__(self, "ds", ds)
342
+ object.__setattr__(self, "theta_s", theta_s)
343
+ object.__setattr__(self, "theta_b", theta_b)
344
+ object.__setattr__(self, "hc", hc)
345
+ object.__setattr__(self, "N", N)
346
+
347
+ def plot(self, bathymetry: bool = False) -> None:
212
348
  """
213
- Export the parameters of the class to a YAML file, including the version of roms-tools.
349
+ Plot the grid.
214
350
 
215
351
  Parameters
216
352
  ----------
217
- filepath : str
218
- The path to the YAML file where the parameters will be saved.
353
+ bathymetry : bool
354
+ Whether or not to plot the bathymetry. Default is False.
355
+
356
+ Returns
357
+ -------
358
+ None
359
+ This method does not return any value. It generates and displays a plot.
360
+
219
361
  """
220
- data = asdict(self)
221
- data.pop("ds", None)
222
- data.pop("straddle", None)
223
362
 
224
- # Include the version of roms-tools
225
- try:
226
- roms_tools_version = importlib.metadata.version("roms-tools")
227
- except importlib.metadata.PackageNotFoundError:
228
- roms_tools_version = "unknown"
363
+ if bathymetry:
364
+ field = self.ds.h.where(self.ds.mask_rho)
365
+ field = field.assign_coords(
366
+ {"lon": self.ds.lon_rho, "lat": self.ds.lat_rho}
367
+ )
229
368
 
230
- # Create header
231
- header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
369
+ vmax = field.max().values
370
+ vmin = field.min().values
371
+ cmap = plt.colormaps.get_cmap("YlGnBu")
372
+ cmap.set_bad(color="gray")
373
+ kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
232
374
 
233
- # Use the class name as the top-level key
234
- yaml_data = {self.__class__.__name__: data}
375
+ _plot(
376
+ self.ds,
377
+ field=field,
378
+ straddle=self.straddle,
379
+ kwargs=kwargs,
380
+ )
381
+ else:
382
+ _plot(self.ds, straddle=self.straddle)
235
383
 
236
- with open(filepath, "w") as file:
237
- # Write header
238
- file.write(header)
239
- # Write YAML data
240
- yaml.dump(yaml_data, file, default_flow_style=False)
384
+ def plot_vertical_coordinate(
385
+ self,
386
+ varname="layer_depth_rho",
387
+ s=None,
388
+ eta=None,
389
+ xi=None,
390
+ ) -> None:
391
+ """
392
+ Plot the vertical coordinate system for a given eta-, xi-, or s-slice.
393
+
394
+ Parameters
395
+ ----------
396
+ varname : str, optional
397
+ The vertical coordinate field to plot. Options include:
398
+ - "layer_depth_rho": Layer depth at rho-points.
399
+ - "layer_depth_u": Layer depth at u-points.
400
+ - "layer_depth_v": Layer depth at v-points.
401
+ - "interface_depth_rho": Interface depth at rho-points.
402
+ - "interface_depth_u": Interface depth at u-points.
403
+ - "interface_depth_v": Interface depth at v-points.
404
+ s : int, optional
405
+ The s-index to plot. Default is None.
406
+ eta : int, optional
407
+ The eta-index to plot. Default is None.
408
+ xi : int, optional
409
+ The xi-index to plot. Default is None.
410
+
411
+ Returns
412
+ -------
413
+ None
414
+ This method does not return any value. It generates and displays a plot.
415
+
416
+ Raises
417
+ ------
418
+ ValueError
419
+ If the specified varname is not one of the valid options.
420
+ If none of s, eta, xi are specified.
421
+ """
422
+
423
+ if not any([s is not None, eta is not None, xi is not None]):
424
+ raise ValueError("At least one of s, eta, or xi must be specified.")
425
+
426
+ self.ds[varname].load()
427
+ field = self.ds[varname].squeeze()
428
+
429
+ if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
430
+ interface_depth = self.ds.interface_depth_rho
431
+ field = field.where(self.ds.mask_rho)
432
+ field = field.assign_coords(
433
+ {"lon": self.ds.lon_rho, "lat": self.ds.lat_rho}
434
+ )
435
+ elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
436
+ interface_depth = self.ds.interface_depth_u
437
+ field = field.where(self.ds.mask_u)
438
+ field = field.assign_coords({"lon": self.ds.lon_u, "lat": self.ds.lat_u})
439
+ elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
440
+ interface_depth = self.ds.interface_depth_v
441
+ field = field.where(self.ds.mask_v)
442
+ field = field.assign_coords({"lon": self.ds.lon_v, "lat": self.ds.lat_v})
443
+
444
+ # slice the field as desired
445
+ title = field.long_name
446
+ if s is not None:
447
+ if "s_rho" in field.dims:
448
+ title = title + f", s_rho = {field.s_rho[s].item()}"
449
+ field = field.isel(s_rho=s)
450
+ elif "s_w" in field.dims:
451
+ title = title + f", s_w = {field.s_w[s].item()}"
452
+ field = field.isel(s_w=s)
453
+ else:
454
+ raise ValueError(
455
+ f"None of the expected dimensions (s_rho, s_w) found in ds[{varname}]."
456
+ )
457
+
458
+ if eta is not None:
459
+ if "eta_rho" in field.dims:
460
+ title = title + f", eta_rho = {field.eta_rho[eta].item()}"
461
+ field = field.isel(eta_rho=eta)
462
+ interface_depth = interface_depth.isel(eta_rho=eta)
463
+ elif "eta_v" in field.dims:
464
+ title = title + f", eta_v = {field.eta_v[eta].item()}"
465
+ field = field.isel(eta_v=eta)
466
+ interface_depth = interface_depth.isel(eta_v=eta)
467
+ else:
468
+ raise ValueError(
469
+ f"None of the expected dimensions (eta_rho, eta_v) found in ds[{varname}]."
470
+ )
471
+ if xi is not None:
472
+ if "xi_rho" in field.dims:
473
+ title = title + f", xi_rho = {field.xi_rho[xi].item()}"
474
+ field = field.isel(xi_rho=xi)
475
+ interface_depth = interface_depth.isel(xi_rho=xi)
476
+ elif "xi_u" in field.dims:
477
+ title = title + f", xi_u = {field.xi_u[xi].item()}"
478
+ field = field.isel(xi_u=xi)
479
+ interface_depth = interface_depth.isel(xi_u=xi)
480
+ else:
481
+ raise ValueError(
482
+ f"None of the expected dimensions (xi_rho, xi_u) found in ds[{varname}]."
483
+ )
484
+
485
+ if eta is None and xi is None:
486
+ vmax = field.max().values
487
+ vmin = field.min().values
488
+ cmap = plt.colormaps.get_cmap("YlGnBu")
489
+ cmap.set_bad(color="gray")
490
+ kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
491
+
492
+ _plot(
493
+ self.ds,
494
+ field=field,
495
+ straddle=self.straddle,
496
+ depth_contours=False,
497
+ title=title,
498
+ kwargs=kwargs,
499
+ )
500
+ else:
501
+ if len(field.dims) == 2:
502
+ cmap = plt.colormaps.get_cmap("YlGnBu")
503
+ cmap.set_bad(color="gray")
504
+ kwargs = {"vmax": 0.0, "vmin": 0.0, "cmap": cmap, "add_colorbar": False}
505
+
506
+ _section_plot(
507
+ xr.zeros_like(field),
508
+ interface_depth=interface_depth,
509
+ title=title,
510
+ kwargs=kwargs,
511
+ )
512
+ else:
513
+ if "s_rho" in field.dims or "s_w" in field.dims:
514
+ _profile_plot(field, title=title)
515
+ else:
516
+ _line_plot(field, title=title)
517
+
518
+ def save(self, filepath: str) -> None:
519
+ """
520
+ Save the grid information to a netCDF4 file.
521
+
522
+ Parameters
523
+ ----------
524
+ filepath
525
+ """
526
+ self.ds.to_netcdf(filepath)
241
527
 
242
528
  @classmethod
243
529
  def from_file(cls, filepath: str) -> "Grid":
@@ -259,8 +545,6 @@ class Grid:
259
545
 
260
546
  if not all(mask in ds for mask in ["mask_u", "mask_v"]):
261
547
  ds = _add_velocity_masks(ds)
262
- if not all(coord in ds for coord in ["lat_u", "lon_u", "lat_v", "lon_v"]):
263
- ds = _add_lat_lon_at_velocity_points(ds)
264
548
 
265
549
  # Create a new Grid instance without calling __init__ and __post_init__
266
550
  grid = cls.__new__(cls)
@@ -271,26 +555,82 @@ class Grid:
271
555
  # Check if the Greenwich meridian goes through the domain.
272
556
  grid._straddle()
273
557
 
558
+ if not all(coord in grid.ds for coord in ["lat_u", "lon_u", "lat_v", "lon_v"]):
559
+ ds = _add_lat_lon_at_velocity_points(grid.ds, grid.straddle)
560
+ object.__setattr__(grid, "ds", ds)
561
+
562
+ # Coarsen the grid if necessary
563
+ if not all(
564
+ var in grid.ds
565
+ for var in [
566
+ "lon_coarse",
567
+ "lat_coarse",
568
+ "angle_coarse",
569
+ "mask_coarse",
570
+ ]
571
+ ):
572
+ grid._coarsen()
573
+
574
+ # Update vertical coordinate if necessary
575
+ if not all(var in grid.ds for var in ["sc_r", "Cs_r"]):
576
+ N = 100
577
+ theta_s = 5.0
578
+ theta_b = 2.0
579
+ hc = 300.0
580
+
581
+ grid.update_vertical_coordinate(
582
+ N=N, theta_s=theta_s, theta_b=theta_b, hc=hc
583
+ )
584
+
274
585
  # Manually set the remaining attributes by extracting parameters from dataset
275
586
  object.__setattr__(grid, "nx", ds.sizes["xi_rho"] - 2)
276
587
  object.__setattr__(grid, "ny", ds.sizes["eta_rho"] - 2)
277
- object.__setattr__(grid, "center_lon", ds["tra_lon"].values.item())
278
- object.__setattr__(grid, "center_lat", ds["tra_lat"].values.item())
279
- object.__setattr__(grid, "rot", ds["rotate"].values.item())
588
+ object.__setattr__(grid, "center_lon", ds.attrs["center_lon"])
589
+ object.__setattr__(grid, "center_lat", ds.attrs["center_lat"])
590
+ object.__setattr__(grid, "rot", ds.attrs["rot"])
280
591
 
281
592
  for attr in [
282
593
  "size_x",
283
594
  "size_y",
284
595
  "topography_source",
285
- "smooth_factor",
286
596
  "hmin",
287
- "rmax",
288
597
  ]:
289
598
  if attr in ds.attrs:
290
599
  object.__setattr__(grid, attr, ds.attrs[attr])
291
600
 
292
601
  return grid
293
602
 
603
+ def to_yaml(self, filepath: str) -> None:
604
+ """
605
+ Export the parameters of the class to a YAML file, including the version of roms-tools.
606
+
607
+ Parameters
608
+ ----------
609
+ filepath : str
610
+ The path to the YAML file where the parameters will be saved.
611
+ """
612
+ data = asdict(self)
613
+ data.pop("ds", None)
614
+ data.pop("straddle", None)
615
+
616
+ # Include the version of roms-tools
617
+ try:
618
+ roms_tools_version = importlib.metadata.version("roms-tools")
619
+ except importlib.metadata.PackageNotFoundError:
620
+ roms_tools_version = "unknown"
621
+
622
+ # Create header
623
+ header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
624
+
625
+ # Use the class name as the top-level key
626
+ yaml_data = {self.__class__.__name__: data}
627
+
628
+ with open(filepath, "w") as file:
629
+ # Write header
630
+ file.write(header)
631
+ # Write YAML data
632
+ yaml.dump(yaml_data, file, default_flow_style=False)
633
+
294
634
  @classmethod
295
635
  def from_yaml(cls, filepath: str) -> "Grid":
296
636
  """
@@ -358,102 +698,6 @@ class Grid:
358
698
  attr_str = ", ".join(f"{k}={v!r}" for k, v in attr_dict.items())
359
699
  return f"{cls_name}({attr_str})"
360
700
 
361
- # def to_xgcm() -> Any:
362
- # # TODO we could convert the dataset to an xgcm.Grid object and return here?
363
- # raise NotImplementedError()
364
-
365
- def _straddle(self) -> None:
366
- """
367
- Check if the Greenwich meridian goes through the domain.
368
-
369
- This method sets the `straddle` attribute to `True` if the Greenwich meridian
370
- (0° longitude) intersects the domain defined by `lon_rho`. Otherwise, it sets
371
- the `straddle` attribute to `False`.
372
-
373
- The check is based on whether the longitudinal differences between adjacent
374
- points exceed 300 degrees, indicating a potential wraparound of longitude.
375
- """
376
-
377
- if (
378
- np.abs(self.ds.lon_rho.diff("xi_rho")).max() > 300
379
- or np.abs(self.ds.lon_rho.diff("eta_rho")).max() > 300
380
- ):
381
- object.__setattr__(self, "straddle", True)
382
- else:
383
- object.__setattr__(self, "straddle", False)
384
-
385
- def plot(self, bathymetry: bool = False) -> None:
386
- """
387
- Plot the grid.
388
-
389
- Parameters
390
- ----------
391
- bathymetry : bool
392
- Whether or not to plot the bathymetry. Default is False.
393
-
394
- Returns
395
- -------
396
- None
397
- This method does not return any value. It generates and displays a plot.
398
-
399
- """
400
-
401
- if bathymetry:
402
- kwargs = {"cmap": "YlGnBu"}
403
-
404
- _plot(
405
- self.ds,
406
- field=self.ds.h.where(self.ds.mask_rho),
407
- straddle=self.straddle,
408
- kwargs=kwargs,
409
- )
410
- else:
411
- _plot(self.ds, straddle=self.straddle)
412
-
413
- def coarsen(self):
414
- """
415
- Update the grid by adding grid variables that are coarsened versions of the original
416
- fine-resolution grid variables. The coarsening is by a factor of two.
417
-
418
- The specific variables being coarsened are:
419
- - `lon_rho` -> `lon_coarse`: Longitude at rho points.
420
- - `lat_rho` -> `lat_coarse`: Latitude at rho points.
421
- - `h` -> `h_coarse`: Bathymetry (depth).
422
- - `angle` -> `angle_coarse`: Angle between the xi axis and true east.
423
- - `mask_rho` -> `mask_coarse`: Land/sea mask at rho points.
424
-
425
- Returns
426
- -------
427
- None
428
-
429
- Modifies
430
- --------
431
- self.ds : xr.Dataset
432
- The dataset attribute of the Grid instance is updated with the new coarser variables.
433
- """
434
- d = {
435
- "lon_rho": "lon_coarse",
436
- "lat_rho": "lat_coarse",
437
- "h": "h_coarse",
438
- "angle": "angle_coarse",
439
- "mask_rho": "mask_coarse",
440
- }
441
-
442
- for fine_var, coarse_var in d.items():
443
- fine_field = self.ds[fine_var]
444
- if self.straddle and fine_var == "lon_rho":
445
- fine_field = xr.where(fine_field > 180, fine_field - 360, fine_field)
446
-
447
- coarse_field = _f2c(fine_field)
448
- if fine_var == "lon_rho":
449
- coarse_field = xr.where(
450
- coarse_field < 0, coarse_field + 360, coarse_field
451
- )
452
-
453
- self.ds[coarse_var] = coarse_field
454
-
455
- self.ds["mask_coarse"] = xr.where(self.ds["mask_coarse"] > 0.5, 1, 0)
456
-
457
701
 
458
702
  def _make_grid_ds(
459
703
  nx: int,
@@ -484,7 +728,7 @@ def _make_grid_ds(
484
728
 
485
729
  ds = _create_grid_ds(lon, lat, pm, pn, ang, rot, center_lon, center_lat)
486
730
 
487
- ds = _add_global_metadata(ds, size_x, size_y)
731
+ ds = _add_global_metadata(ds, size_x, size_y, center_lon, center_lat, rot)
488
732
 
489
733
  return ds
490
734
 
@@ -777,14 +1021,21 @@ def _create_grid_ds(
777
1021
  center_lon,
778
1022
  center_lat,
779
1023
  ):
780
- # Create xarray.Dataset object with lat_rho and lon_rho as coordinates
781
- ds = xr.Dataset(
782
- coords={
783
- "lat_rho": (("eta_rho", "xi_rho"), lat * 180 / np.pi),
784
- "lon_rho": (("eta_rho", "xi_rho"), lon * 180 / np.pi),
785
- }
1024
+ ds = xr.Dataset()
1025
+
1026
+ lon_rho = xr.Variable(
1027
+ data=lon * 180 / np.pi,
1028
+ dims=["eta_rho", "xi_rho"],
1029
+ attrs={"long_name": "longitude of rho-points", "units": "degrees East"},
1030
+ )
1031
+ lat_rho = xr.Variable(
1032
+ data=lat * 180 / np.pi,
1033
+ dims=["eta_rho", "xi_rho"],
1034
+ attrs={"long_name": "latitude of rho-points", "units": "degrees North"},
786
1035
  )
787
1036
 
1037
+ ds = ds.assign_coords({"lat_rho": lat_rho, "lon_rho": lon_rho})
1038
+
788
1039
  ds["angle"] = xr.Variable(
789
1040
  data=angle,
790
1041
  dims=["eta_rho", "xi_rho"],
@@ -799,6 +1050,7 @@ def _create_grid_ds(
799
1050
  dims=["eta_rho", "xi_rho"],
800
1051
  attrs={"long_name": "Coriolis parameter at rho-points", "units": "second-1"},
801
1052
  )
1053
+
802
1054
  ds["pm"] = xr.Variable(
803
1055
  data=pm,
804
1056
  dims=["eta_rho", "xi_rho"],
@@ -816,36 +1068,15 @@ def _create_grid_ds(
816
1068
  },
817
1069
  )
818
1070
 
819
- ds["tra_lon"] = center_lon
820
- ds["tra_lon"].attrs["long_name"] = "Longitudinal translation of base grid"
821
- ds["tra_lon"].attrs["units"] = "degrees East"
822
-
823
- ds["tra_lat"] = center_lat
824
- ds["tra_lat"].attrs["long_name"] = "Latitudinal translation of base grid"
825
- ds["tra_lat"].attrs["units"] = "degrees North"
826
-
827
- ds["rotate"] = rot
828
- ds["rotate"].attrs["long_name"] = "Rotation of base grid"
829
- ds["rotate"].attrs["units"] = "degrees"
830
-
831
- ds["lon_rho"] = xr.Variable(
832
- data=lon * 180 / np.pi,
833
- dims=["eta_rho", "xi_rho"],
834
- attrs={"long_name": "longitude of rho-points", "units": "degrees East"},
835
- )
836
-
837
- ds["lat_rho"] = xr.Variable(
838
- data=lat * 180 / np.pi,
839
- dims=["eta_rho", "xi_rho"],
840
- attrs={"long_name": "latitude of rho-points", "units": "degrees North"},
841
- )
1071
+ return ds
842
1072
 
843
- ds = _add_lat_lon_at_velocity_points(ds)
844
1073
 
845
- return ds
1074
+ def _add_global_metadata(ds, size_x, size_y, center_lon, center_lat, rot):
846
1075
 
1076
+ ds["spherical"] = xr.DataArray(np.array("T", dtype="S1"))
1077
+ ds["spherical"].attrs["Long_name"] = "Grid type logical switch"
1078
+ ds["spherical"].attrs["option_T"] = "spherical"
847
1079
 
848
- def _add_global_metadata(ds, size_x, size_y):
849
1080
  ds.attrs["title"] = "ROMS grid created by ROMS-Tools"
850
1081
 
851
1082
  # Include the version of roms-tools
@@ -857,6 +1088,9 @@ def _add_global_metadata(ds, size_x, size_y):
857
1088
  ds.attrs["roms_tools_version"] = roms_tools_version
858
1089
  ds.attrs["size_x"] = size_x
859
1090
  ds.attrs["size_y"] = size_y
1091
+ ds.attrs["center_lon"] = center_lon
1092
+ ds.attrs["center_lat"] = center_lat
1093
+ ds.attrs["rot"] = rot
860
1094
 
861
1095
  return ds
862
1096
 
@@ -916,12 +1150,24 @@ def _f2c_xdir(f):
916
1150
  return fc
917
1151
 
918
1152
 
919
- def _add_lat_lon_at_velocity_points(ds):
1153
+ def _add_lat_lon_at_velocity_points(ds, straddle):
920
1154
 
921
- lat_u = interpolate_from_rho_to_u(ds["lat_rho"])
922
- lon_u = interpolate_from_rho_to_u(ds["lon_rho"])
923
- lat_v = interpolate_from_rho_to_v(ds["lat_rho"])
924
- lon_v = interpolate_from_rho_to_v(ds["lon_rho"])
1155
+ if straddle:
1156
+ # avoid jump from 360 to 0 in interpolation
1157
+ lon_rho = xr.where(ds["lon_rho"] > 180, ds["lon_rho"] - 360, ds["lon_rho"])
1158
+ else:
1159
+ lon_rho = ds["lon_rho"]
1160
+ lat_rho = ds["lat_rho"]
1161
+
1162
+ lat_u = interpolate_from_rho_to_u(lat_rho)
1163
+ lon_u = interpolate_from_rho_to_u(lon_rho)
1164
+ lat_v = interpolate_from_rho_to_v(lat_rho)
1165
+ lon_v = interpolate_from_rho_to_v(lon_rho)
1166
+
1167
+ if straddle:
1168
+ # convert back to range [0, 360]
1169
+ lon_u = xr.where(lon_u < 0, lon_u + 360, lon_u)
1170
+ lon_v = xr.where(lon_v < 0, lon_v + 360, lon_v)
925
1171
 
926
1172
  lat_u.attrs = {"long_name": "latitude of u-points", "units": "degrees North"}
927
1173
  lon_u.attrs = {"long_name": "longitude of u-points", "units": "degrees East"}
@@ -929,7 +1175,12 @@ def _add_lat_lon_at_velocity_points(ds):
929
1175
  lon_v.attrs = {"long_name": "longitude of v-points", "units": "degrees East"}
930
1176
 
931
1177
  ds = ds.assign_coords(
932
- {"lat_u": lat_u, "lon_u": lon_u, "lat_v": lat_v, "lon_v": lon_v}
1178
+ {
1179
+ "lat_u": lat_u,
1180
+ "lon_u": lon_u,
1181
+ "lat_v": lat_v,
1182
+ "lon_v": lon_v,
1183
+ }
933
1184
  )
934
1185
 
935
1186
  return ds