roms-tools 1.0.1__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,7 +5,6 @@ import importlib.metadata
5
5
  from dataclasses import dataclass, field, asdict
6
6
  from typing import Optional, Dict, Union
7
7
  from roms_tools.setup.grid import Grid
8
- from roms_tools.setup.vertical_coordinate import VerticalCoordinate
9
8
  from datetime import datetime
10
9
  from roms_tools.setup.datasets import GLORYSDataset, CESMBGCDataset
11
10
  from roms_tools.setup.utils import (
@@ -25,8 +24,6 @@ class InitialConditions(ROMSToolsMixins):
25
24
  ----------
26
25
  grid : Grid
27
26
  Object representing the grid information used for the model.
28
- vertical_coordinate : VerticalCoordinate
29
- Object representing the vertical coordinate system.
30
27
  ini_time : datetime
31
28
  The date and time at which the initial conditions are set.
32
29
  physics_source : Dict[str, Union[str, None]]
@@ -51,7 +48,6 @@ class InitialConditions(ROMSToolsMixins):
51
48
  --------
52
49
  >>> initial_conditions = InitialConditions(
53
50
  ... grid=grid,
54
- ... vertical_coordinate=vertical_coordinate,
55
51
  ... ini_time=datetime(2022, 1, 1),
56
52
  ... physics_source={"name": "GLORYS", "path": "physics_data.nc"},
57
53
  ... bgc_source={
@@ -63,7 +59,6 @@ class InitialConditions(ROMSToolsMixins):
63
59
  """
64
60
 
65
61
  grid: Grid
66
- vertical_coordinate: VerticalCoordinate
67
62
  ini_time: datetime
68
63
  physics_source: Dict[str, Union[str, None]]
69
64
  bgc_source: Optional[Dict[str, Union[str, None]]] = None
@@ -87,7 +82,7 @@ class InitialConditions(ROMSToolsMixins):
87
82
  vars_2d = ["zeta"]
88
83
  vars_3d = ["temp", "salt", "u", "v"]
89
84
  data_vars = super().regrid_data(data, vars_2d, vars_3d, lon, lat)
90
- data_vars = super().process_velocities(data_vars, angle)
85
+ data_vars = super().process_velocities(data_vars, angle, "u", "v")
91
86
 
92
87
  if self.bgc_source is not None:
93
88
  bgc_data = self._get_bgc_data()
@@ -177,18 +172,18 @@ class InitialConditions(ROMSToolsMixins):
177
172
 
178
173
  if self.bgc_source["name"] == "CESM_REGRIDDED":
179
174
 
180
- bgc_data = CESMBGCDataset(
175
+ data = CESMBGCDataset(
181
176
  filename=self.bgc_source["path"],
182
177
  start_time=self.ini_time,
183
178
  climatology=self.bgc_source["climatology"],
184
179
  )
185
- bgc_data.post_process()
180
+ data.post_process()
186
181
  else:
187
182
  raise ValueError(
188
183
  'Only "CESM_REGRIDDED" is a valid option for bgc_source["name"].'
189
184
  )
190
185
 
191
- return bgc_data
186
+ return data
192
187
 
193
188
  def _write_into_dataset(self, data_vars, d_meta):
194
189
 
@@ -202,26 +197,48 @@ class InitialConditions(ROMSToolsMixins):
202
197
 
203
198
  # initialize vertical velocity to zero
204
199
  ds["w"] = xr.zeros_like(
205
- self.vertical_coordinate.ds["interface_depth_rho"].expand_dims(
206
- time=data_vars["u"].time
207
- )
200
+ self.grid.ds["interface_depth_rho"].expand_dims(time=data_vars["u"].time)
208
201
  ).astype(np.float32)
209
202
  ds["w"].attrs["long_name"] = d_meta["w"]["long_name"]
210
203
  ds["w"].attrs["units"] = d_meta["w"]["units"]
211
204
 
205
+ variables_to_drop = [
206
+ "s_rho",
207
+ "lat_rho",
208
+ "lon_rho",
209
+ "layer_depth_rho",
210
+ "interface_depth_rho",
211
+ "lat_u",
212
+ "lon_u",
213
+ "lat_v",
214
+ "lon_v",
215
+ ]
216
+ existing_vars = [var for var in variables_to_drop if var in ds]
217
+ ds = ds.drop_vars(existing_vars)
218
+
219
+ ds["sc_r"] = self.grid.ds["sc_r"]
220
+ ds["Cs_r"] = self.grid.ds["Cs_r"]
221
+
222
+ # Preserve absolute time coordinate for readability
223
+ ds = ds.assign_coords({"abs_time": ds["time"]})
224
+
225
+ # Translate the time coordinate to seconds since the model reference date
226
+ model_reference_date = np.datetime64(self.model_reference_date)
227
+
228
+ # Convert the time coordinate to the format expected by ROMS (seconds since model reference date)
229
+ ocean_time = (ds["time"] - model_reference_date).astype("float64") * 1e-9
230
+ ds = ds.assign_coords(ocean_time=("time", np.float32(ocean_time)))
231
+ ds["ocean_time"].attrs[
232
+ "long_name"
233
+ ] = f"seconds since {np.datetime_as_string(model_reference_date, unit='s')}"
234
+ ds["ocean_time"].attrs["units"] = "seconds"
235
+ ds = ds.swap_dims({"time": "ocean_time"})
236
+ ds = ds.drop_vars("time")
237
+
212
238
  return ds
213
239
 
214
240
  def _add_global_metadata(self, ds):
215
241
 
216
- ds = ds.assign_coords(
217
- {
218
- "layer_depth_u": self.vertical_coordinate.ds["layer_depth_u"],
219
- "layer_depth_v": self.vertical_coordinate.ds["layer_depth_v"],
220
- "interface_depth_u": self.vertical_coordinate.ds["interface_depth_u"],
221
- "interface_depth_v": self.vertical_coordinate.ds["interface_depth_v"],
222
- }
223
- )
224
-
225
242
  ds.attrs["title"] = "ROMS initial conditions file created by ROMS-Tools"
226
243
  # Include the version of roms-tools
227
244
  try:
@@ -235,25 +252,9 @@ class InitialConditions(ROMSToolsMixins):
235
252
  if self.bgc_source is not None:
236
253
  ds.attrs["bgc_source"] = self.bgc_source["name"]
237
254
 
238
- # Translate the time coordinate to days since the model reference date
239
- model_reference_date = np.datetime64(self.model_reference_date)
240
-
241
- # Convert the time coordinate to the format expected by ROMS (days since model reference date)
242
- ocean_time = (ds["time"] - model_reference_date).astype("float64") * 1e-9
243
- ds = ds.assign_coords(ocean_time=("time", np.float32(ocean_time)))
244
- ds["ocean_time"].attrs[
245
- "long_name"
246
- ] = f"seconds since {np.datetime_as_string(model_reference_date, unit='s')}"
247
- ds["ocean_time"].attrs["units"] = "seconds"
248
-
249
- ds.attrs["theta_s"] = self.vertical_coordinate.ds["theta_s"].item()
250
- ds.attrs["theta_b"] = self.vertical_coordinate.ds["theta_b"].item()
251
- ds.attrs["Tcline"] = self.vertical_coordinate.ds["Tcline"].item()
252
- ds.attrs["hc"] = self.vertical_coordinate.ds["hc"].item()
253
- ds["sc_r"] = self.vertical_coordinate.ds["sc_r"]
254
- ds["Cs_r"] = self.vertical_coordinate.ds["Cs_r"]
255
-
256
- ds = ds.drop_vars(["s_rho"])
255
+ ds.attrs["theta_s"] = self.grid.ds.attrs["theta_s"]
256
+ ds.attrs["theta_b"] = self.grid.ds.attrs["theta_b"]
257
+ ds.attrs["hc"] = self.grid.ds.attrs["hc"]
257
258
 
258
259
  return ds
259
260
 
@@ -314,13 +315,23 @@ class InitialConditions(ROMSToolsMixins):
314
315
  - "diazP": Diazotroph Phosphorus (mmol/m³).
315
316
  - "diazFe": Diazotroph Iron (mmol/m³).
316
317
  s : int, optional
317
- The index of the vertical layer to plot. Default is None.
318
+ The index of the vertical layer (`s_rho`) to plot. If not specified, no layer is
319
+ selected and a 3D field is instead sliced along `eta` and/or `xi`. Default is None.
318
320
  eta : int, optional
319
- The eta-index to plot. Default is None.
321
+ The eta-index to plot. Used for vertical sections or horizontal slices.
322
+ Default is None.
320
323
  xi : int, optional
321
- The xi-index to plot. Default is None.
324
+ The xi-index to plot. Used for vertical sections or horizontal slices.
325
+ Default is None.
322
326
  depth_contours : bool, optional
323
- Whether to include depth contours in the plot. Default is False.
327
+ If True, depth contours will be overlaid on the plot, showing lines of constant
328
+ depth. This is typically used for plots that show a single vertical layer.
329
+ Default is False.
330
+ layer_contours : bool, optional
331
+ If True, contour lines representing the boundaries between vertical layers will
332
+ be added to the plot. This is particularly useful in vertical sections to
333
+ visualize the layering of the water column. For clarity, the number of layer
334
+ contours displayed is limited to a maximum of 10. Default is False.
324
335
 
325
336
  Returns
326
337
  -------
@@ -333,6 +344,7 @@ class InitialConditions(ROMSToolsMixins):
333
344
  If the specified `varname` is not one of the valid options.
334
345
  If the field specified by `varname` is 3D and none of `s`, `eta`, or `xi` are specified.
335
346
  If the field specified by `varname` is 2D and both `eta` and `xi` are specified.
347
+
336
348
  """
337
349
 
338
350
  if len(self.ds[varname].squeeze().dims) == 3 and not any(
@@ -351,17 +363,38 @@ class InitialConditions(ROMSToolsMixins):
351
363
  field = self.ds[varname].squeeze()
352
364
 
353
365
  if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
354
- interface_depth = self.ds.interface_depth_rho
366
+ interface_depth = self.grid.ds.interface_depth_rho
367
+ layer_depth = self.grid.ds.layer_depth_rho
368
+ field = field.where(self.grid.ds.mask_rho)
369
+ field = field.assign_coords(
370
+ {"lon": self.grid.ds.lon_rho, "lat": self.grid.ds.lat_rho}
371
+ )
372
+
355
373
  elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
356
- interface_depth = self.ds.interface_depth_u
374
+ interface_depth = self.grid.ds.interface_depth_u
375
+ layer_depth = self.grid.ds.layer_depth_u
376
+ field = field.where(self.grid.ds.mask_u)
377
+ field = field.assign_coords(
378
+ {"lon": self.grid.ds.lon_u, "lat": self.grid.ds.lat_u}
379
+ )
380
+
357
381
  elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
358
- interface_depth = self.ds.interface_depth_v
382
+ interface_depth = self.grid.ds.interface_depth_v
383
+ layer_depth = self.grid.ds.layer_depth_v
384
+ field = field.where(self.grid.ds.mask_v)
385
+ field = field.assign_coords(
386
+ {"lon": self.grid.ds.lon_v, "lat": self.grid.ds.lat_v}
387
+ )
388
+ else:
389
+ ValueError("provided field does not have two horizontal dimension")
359
390
 
360
391
  # slice the field as desired
361
392
  title = field.long_name
362
393
  if s is not None:
363
394
  title = title + f", s_rho = {field.s_rho[s].item()}"
364
395
  field = field.isel(s_rho=s)
396
+ layer_depth = layer_depth.isel(s_rho=s)
397
+ field = field.assign_coords({"layer_depth": layer_depth})
365
398
  else:
366
399
  depth_contours = False
367
400
 
@@ -369,10 +402,14 @@ class InitialConditions(ROMSToolsMixins):
369
402
  if "eta_rho" in field.dims:
370
403
  title = title + f", eta_rho = {field.eta_rho[eta].item()}"
371
404
  field = field.isel(eta_rho=eta)
405
+ layer_depth = layer_depth.isel(eta_rho=eta)
406
+ field = field.assign_coords({"layer_depth": layer_depth})
372
407
  interface_depth = interface_depth.isel(eta_rho=eta)
373
408
  elif "eta_v" in field.dims:
374
409
  title = title + f", eta_v = {field.eta_v[eta].item()}"
375
410
  field = field.isel(eta_v=eta)
411
+ layer_depth = layer_depth.isel(eta_v=eta)
412
+ field = field.assign_coords({"layer_depth": layer_depth})
376
413
  interface_depth = interface_depth.isel(eta_v=eta)
377
414
  else:
378
415
  raise ValueError(
@@ -382,10 +419,14 @@ class InitialConditions(ROMSToolsMixins):
382
419
  if "xi_rho" in field.dims:
383
420
  title = title + f", xi_rho = {field.xi_rho[xi].item()}"
384
421
  field = field.isel(xi_rho=xi)
422
+ layer_depth = layer_depth.isel(xi_rho=xi)
423
+ field = field.assign_coords({"layer_depth": layer_depth})
385
424
  interface_depth = interface_depth.isel(xi_rho=xi)
386
425
  elif "xi_u" in field.dims:
387
426
  title = title + f", xi_u = {field.xi_u[xi].item()}"
388
427
  field = field.isel(xi_u=xi)
428
+ layer_depth = layer_depth.isel(xi_u=xi)
429
+ field = field.assign_coords({"layer_depth": layer_depth})
389
430
  interface_depth = interface_depth.isel(xi_u=xi)
390
431
  else:
391
432
  raise ValueError(
@@ -462,11 +503,6 @@ class InitialConditions(ROMSToolsMixins):
462
503
  grid_data.pop("ds", None) # Exclude non-serializable fields
463
504
  grid_data.pop("straddle", None)
464
505
 
465
- # Serialize VerticalCoordinate data
466
- vertical_coordinate_data = asdict(self.vertical_coordinate)
467
- vertical_coordinate_data.pop("ds", None) # Exclude non-serializable fields
468
- vertical_coordinate_data.pop("grid", None) # Exclude non-serializable fields
469
-
470
506
  # Include the version of roms-tools
471
507
  try:
472
508
  roms_tools_version = importlib.metadata.version("roms-tools")
@@ -477,7 +513,6 @@ class InitialConditions(ROMSToolsMixins):
477
513
  header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
478
514
 
479
515
  grid_yaml_data = {"Grid": grid_data}
480
- vertical_coordinate_yaml_data = {"VerticalCoordinate": vertical_coordinate_data}
481
516
 
482
517
  initial_conditions_data = {
483
518
  "InitialConditions": {
@@ -492,7 +527,6 @@ class InitialConditions(ROMSToolsMixins):
492
527
 
493
528
  yaml_data = {
494
529
  **grid_yaml_data,
495
- **vertical_coordinate_yaml_data,
496
530
  **initial_conditions_data,
497
531
  }
498
532
 
@@ -545,13 +579,10 @@ class InitialConditions(ROMSToolsMixins):
545
579
  initial_conditions_data[date_string]
546
580
  )
547
581
 
548
- # Create VerticalCoordinate instance from the YAML file
549
- vertical_coordinate = VerticalCoordinate.from_yaml(filepath)
550
- grid = vertical_coordinate.grid
582
+ grid = Grid.from_yaml(filepath)
551
583
 
552
584
  # Create and return an instance of InitialConditions
553
585
  return cls(
554
586
  grid=grid,
555
- vertical_coordinate=vertical_coordinate,
556
587
  **initial_conditions_data,
557
588
  )
@@ -1,6 +1,5 @@
1
1
  from dataclasses import dataclass
2
2
  from roms_tools.setup.grid import Grid
3
- from roms_tools.setup.vertical_coordinate import VerticalCoordinate
4
3
  from roms_tools.setup.fill import fill_and_interpolate
5
4
  from roms_tools.setup.utils import (
6
5
  extrapolate_deepest_to_bottom,
@@ -21,13 +20,10 @@ class ROMSToolsMixins:
21
20
  ----------
22
21
  grid : Grid
23
22
  Object representing the grid information used for the model.
24
- vertical_coordinate : VerticalCoordinate
25
- Object representing the vertical coordinate system. Defaults to None.
26
23
 
27
24
  """
28
25
 
29
26
  grid: Grid
30
- vertical_coordinate: VerticalCoordinate = None
31
27
 
32
28
  def get_target_lon_lat(self, use_coarse_grid=False):
33
29
  """
@@ -54,11 +50,6 @@ class ROMSToolsMixins:
54
50
  """
55
51
 
56
52
  if use_coarse_grid:
57
- if "lon_coarse" not in self.grid.ds:
58
- raise ValueError(
59
- "Grid has not been coarsened yet. Execute grid.coarsen() first."
60
- )
61
-
62
53
  lon = self.grid.ds.lon_coarse
63
54
  lat = self.grid.ds.lat_coarse
64
55
  angle = self.grid.ds.angle_coarse
@@ -132,9 +123,9 @@ class ROMSToolsMixins:
132
123
  if vars_3d:
133
124
  # 3d interpolation
134
125
  coords = {
126
+ data.dim_names["depth"]: self.grid.ds["layer_depth_rho"],
135
127
  data.dim_names["latitude"]: lat,
136
128
  data.dim_names["longitude"]: lon,
137
- data.dim_names["depth"]: self.vertical_coordinate.ds["layer_depth_rho"],
138
129
  }
139
130
  # extrapolate deepest value all the way to bottom ("flooding")
140
131
  for var in vars_3d:
@@ -158,37 +149,50 @@ class ROMSToolsMixins:
158
149
  if data.dim_names["time"] != "time":
159
150
  data_vars[var] = data_vars[var].rename({data.dim_names["time"]: "time"})
160
151
 
152
+ # transpose to correct order (time, s_rho, eta_rho, xi_rho)
153
+ data_vars[var] = data_vars[var].transpose(
154
+ "time", "s_rho", "eta_rho", "xi_rho"
155
+ )
156
+
161
157
  return data_vars
162
158
 
163
- def process_velocities(self, data_vars, angle, interpolate=True):
159
+ def process_velocities(self, data_vars, angle, uname, vname, interpolate=True):
164
160
  """
165
- Processes and rotates velocity components, and interpolates them to the appropriate grid points.
161
+ Process and rotate velocity components to align with the grid orientation and optionally interpolate
162
+ them to the appropriate grid points.
166
163
 
167
164
  This method performs the following steps:
168
- 1. Rotates the velocity components to align with the grid orientation using the provided angle.
169
- 2. Optionally interpolates the rotated velocities to the u- and v-points of the grid.
170
- 3. If a vertical coordinate is provided, computes the barotropic velocities by integrating
171
- over the vertical dimension.
165
+
166
+ 1. **Rotation**: Rotates the velocity components (e.g., `u`, `v`) to align with the grid orientation
167
+ using the provided angle data.
168
+ 2. **Interpolation**: Optionally interpolates the rotated velocities from rho-points to u- and v-points
169
+ of the grid.
170
+ 3. **Barotropic Velocity Calculation**: If the velocity components are 3D (with vertical coordinates),
171
+ computes the barotropic (depth-averaged) velocities.
172
172
 
173
173
  Parameters
174
174
  ----------
175
175
  data_vars : dict of str: xarray.DataArray
176
- Dictionary containing the velocity components to be processed. Must include keys "u" and "v"
177
- or "uwnd" and "vwnd".
176
+ Dictionary containing the velocity components to be processed. The dictionary should include keys
177
+ corresponding to the velocity component names (e.g., `uname`, `vname`).
178
178
  angle : xarray.DataArray
179
- DataArray containing the angle used for rotating the velocity components to the grid orientation.
179
+ DataArray containing the grid angle values used to rotate the velocity components to the correct
180
+ orientation on the grid.
181
+ uname : str
182
+ The key corresponding to the zonal (east-west) velocity component in `data_vars`.
183
+ vname : str
184
+ The key corresponding to the meridional (north-south) velocity component in `data_vars`.
180
185
  interpolate : bool, optional
181
- If True, interpolates the velocities to the u- and v-points. Defaults to True.
186
+ If True, interpolates the rotated velocity components to the u- and v-points of the grid.
187
+ Defaults to True.
182
188
 
183
189
  Returns
184
190
  -------
185
191
  dict of str: xarray.DataArray
186
- Dictionary of processed velocity components. Includes "ubar" and "vbar" if a vertical coordinate
187
- is provided.
192
+ A dictionary of the processed velocity components. The returned dictionary includes the rotated and,
193
+ if applicable, interpolated velocity components. If the input velocities are 3D (having a vertical
194
+ dimension), the dictionary also includes the barotropic (depth-averaged) velocities (`ubar` and `vbar`).
188
195
  """
189
- # Determine the correct variable names based on the keys in data_vars
190
- uname = "u" if "u" in data_vars else "uwnd"
191
- vname = "v" if "v" in data_vars else "vwnd"
192
196
 
193
197
  # Rotate velocities to grid orientation
194
198
  u_rot = data_vars[uname] * np.cos(angle) + data_vars[vname] * np.sin(angle)
@@ -202,7 +206,7 @@ class ROMSToolsMixins:
202
206
  data_vars[uname] = u_rot
203
207
  data_vars[vname] = v_rot
204
208
 
205
- if self.vertical_coordinate is not None:
209
+ if "s_rho" in data_vars[uname].dims and "s_rho" in data_vars[vname].dims:
206
210
  # 3D masks for ROMS domain
207
211
  umask = self.grid.ds.mask_u.expand_dims({"s_rho": data_vars[uname].s_rho})
208
212
  vmask = self.grid.ds.mask_v.expand_dims({"s_rho": data_vars[vname].s_rho})
@@ -211,7 +215,7 @@ class ROMSToolsMixins:
211
215
  data_vars[vname] = data_vars[vname] * vmask
212
216
 
213
217
  # Compute barotropic velocity
214
- dz = -self.vertical_coordinate.ds["interface_depth_rho"].diff(dim="s_w")
218
+ dz = -self.grid.ds["interface_depth_rho"].diff(dim="s_w")
215
219
  dz = dz.rename({"s_w": "s_rho"})
216
220
  dzu = interpolate_from_rho_to_u(dz)
217
221
  dzv = interpolate_from_rho_to_v(dz)
@@ -240,6 +244,26 @@ class ROMSToolsMixins:
240
244
  """
241
245
 
242
246
  d = {
247
+ "ssh_Re": {"long_name": "Tidal elevation, real part", "units": "m"},
248
+ "ssh_Im": {"long_name": "Tidal elevation, complex part", "units": "m"},
249
+ "pot_Re": {"long_name": "Tidal potential, real part", "units": "m"},
250
+ "pot_Im": {"long_name": "Tidal potential, complex part", "units": "m"},
251
+ "u_Re": {
252
+ "long_name": "Tidal velocity in x-direction, real part",
253
+ "units": "m/s",
254
+ },
255
+ "u_Im": {
256
+ "long_name": "Tidal velocity in x-direction, complex part",
257
+ "units": "m/s",
258
+ },
259
+ "v_Re": {
260
+ "long_name": "Tidal velocity in y-direction, real part",
261
+ "units": "m/s",
262
+ },
263
+ "v_Im": {
264
+ "long_name": "Tidal velocity in y-direction, complex part",
265
+ "units": "m/s",
266
+ },
243
267
  "uwnd": {"long_name": "10 meter wind in x-direction", "units": "m/s"},
244
268
  "vwnd": {"long_name": "10 meter wind in y-direction", "units": "m/s"},
245
269
  "swrad": {
@@ -333,18 +357,18 @@ class ROMSToolsMixins:
333
357
  return d
334
358
 
335
359
  def get_boundary_info(self):
336
- """
337
- Provides boundary coordinate information and renaming conventions for grid boundaries.
338
360
 
339
- This method returns two dictionaries: one specifying the boundary coordinates for different types of
340
- grid variables (e.g., "rho", "u", "v"), and another specifying how to rename dimensions for these boundaries.
361
+ """
362
+ This method provides information about the boundary points for the rho, u, and v
363
+ variables on the grid, specifying the indices for the south, east, north, and west
364
+ boundaries.
341
365
 
342
366
  Returns
343
367
  -------
344
- tuple of (dict, dict)
345
- - A dictionary mapping variable types and directions to boundary coordinates.
346
- - A dictionary mapping variable types and directions to new dimension names.
347
-
368
+ dict
369
+ A dictionary where keys are variable types ("rho", "u", "v"), and values
370
+ are nested dictionaries mapping directions ("south", "east", "north", "west")
371
+ to the corresponding boundary coordinates.
348
372
  """
349
373
 
350
374
  # Boundary coordinates
@@ -369,27 +393,4 @@ class ROMSToolsMixins:
369
393
  },
370
394
  }
371
395
 
372
- # How to rename the dimensions
373
-
374
- rename = {
375
- "rho": {
376
- "south": {"xi_rho": "xi_rho_south"},
377
- "east": {"eta_rho": "eta_rho_east"},
378
- "north": {"xi_rho": "xi_rho_north"},
379
- "west": {"eta_rho": "eta_rho_west"},
380
- },
381
- "u": {
382
- "south": {"xi_u": "xi_u_south"},
383
- "east": {"eta_rho": "eta_u_east"},
384
- "north": {"xi_u": "xi_u_north"},
385
- "west": {"eta_rho": "eta_u_west"},
386
- },
387
- "v": {
388
- "south": {"xi_rho": "xi_v_south"},
389
- "east": {"eta_v": "eta_v_east"},
390
- "north": {"xi_rho": "xi_v_north"},
391
- "west": {"eta_v": "eta_v_west"},
392
- },
393
- }
394
-
395
- return bdry_coords, rename
396
+ return bdry_coords
roms_tools/setup/plot.py CHANGED
@@ -8,7 +8,6 @@ def _plot(
8
8
  field=None,
9
9
  depth_contours=False,
10
10
  straddle=False,
11
- coarse_grid=False,
12
11
  c="red",
13
12
  title="",
14
13
  kwargs={},
@@ -21,29 +20,8 @@ def _plot(
21
20
  else:
22
21
 
23
22
  field = field.squeeze()
24
-
25
- if coarse_grid:
26
-
27
- field = field.rename({"eta_rho": "eta_coarse", "xi_rho": "xi_coarse"})
28
- field = field.where(grid_ds.mask_coarse)
29
- lon_deg = field.lon
30
- lat_deg = field.lat
31
-
32
- else:
33
- if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
34
- field = field.where(grid_ds.mask_rho)
35
- lon_deg = grid_ds["lon_rho"]
36
- lat_deg = grid_ds["lat_rho"]
37
- elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
38
- field = field.where(grid_ds.mask_u)
39
- lon_deg = grid_ds["lon_u"]
40
- lat_deg = grid_ds["lat_u"]
41
- elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
42
- field = field.where(grid_ds.mask_v)
43
- lon_deg = grid_ds["lon_v"]
44
- lat_deg = grid_ds["lat_v"]
45
- else:
46
- ValueError("provided field does not have two horizontal dimension")
23
+ lon_deg = field.lon
24
+ lat_deg = field.lat
47
25
 
48
26
  # check if North or South pole are in domain
49
27
  if lat_deg.max().values > 89 or lat_deg.min().values < -89:
@@ -96,23 +74,7 @@ def _plot(
96
74
  plt.colorbar(p, label=f"{field.long_name} [{field.units}]")
97
75
 
98
76
  if depth_contours:
99
- if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
100
- if "layer_depth_rho" in field.coords:
101
- depth = field.layer_depth_rho
102
- else:
103
- depth = field.interface_depth_rho
104
- elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
105
- if "layer_depth_u" in field.coords:
106
- depth = field.layer_depth_u
107
- else:
108
- depth = field.interface_depth_u
109
- elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
110
- if "layer_depth_v" in field.coords:
111
- depth = field.layer_depth_v
112
- else:
113
- depth = field.interface_depth_v
114
-
115
- cs = ax.contour(lon_deg, lat_deg, depth, transform=proj, colors="k")
77
+ cs = ax.contour(lon_deg, lat_deg, field.layer_depth, transform=proj, colors="k")
116
78
  ax.clabel(cs, inline=True, fontsize=10)
117
79
 
118
80
  return fig
@@ -135,12 +97,8 @@ def _section_plot(field, interface_depth=None, title="", kwargs={}):
135
97
  )
136
98
 
137
99
  depths_to_check = [
138
- "layer_depth_rho",
139
- "layer_depth_u",
140
- "layer_depth_v",
141
- "interface_depth_rho",
142
- "interface_depth_u",
143
- "interface_depth_v",
100
+ "layer_depth",
101
+ "interface_depth",
144
102
  ]
145
103
  try:
146
104
  depth_label = next(
@@ -95,7 +95,9 @@ class SurfaceForcing(ROMSToolsMixins):
95
95
  vars_2d = ["uwnd", "vwnd", "swrad", "lwrad", "Tair", "qair", "rain"]
96
96
  vars_3d = []
97
97
  data_vars = super().regrid_data(data, vars_2d, vars_3d, lon, lat)
98
- data_vars = super().process_velocities(data_vars, angle, interpolate=False)
98
+ data_vars = super().process_velocities(
99
+ data_vars, angle, "uwnd", "vwnd", interpolate=False
100
+ )
99
101
 
100
102
  if self.correct_radiation:
101
103
  correction_data = self._get_correction_data()
@@ -235,18 +237,19 @@ class SurfaceForcing(ROMSToolsMixins):
235
237
 
236
238
  if self.bgc_source["name"] == "CESM_REGRIDDED":
237
239
 
238
- bgc_data = CESMBGCSurfaceForcingDataset(
240
+ data = CESMBGCSurfaceForcingDataset(
239
241
  filename=self.bgc_source["path"],
240
242
  start_time=self.start_time,
241
243
  end_time=self.end_time,
242
244
  climatology=self.bgc_source["climatology"],
243
245
  )
246
+ data.post_process()
244
247
  else:
245
248
  raise ValueError(
246
249
  'Only "CESM_REGRIDDED" is a valid option for bgc_source["name"].'
247
250
  )
248
251
 
249
- return bgc_data
252
+ return data
250
253
 
251
254
  def _write_into_dataset(self, data, d_meta):
252
255
 
@@ -259,7 +262,6 @@ class SurfaceForcing(ROMSToolsMixins):
259
262
  ds[var].attrs["units"] = d_meta[var]["units"]
260
263
 
261
264
  if self.use_coarse_grid:
262
- ds = ds.assign_coords({"lon": self.target_lon, "lat": self.target_lat})
263
265
  ds = ds.rename({"eta_coarse": "eta_rho", "xi_coarse": "xi_rho"})
264
266
 
265
267
  # Preserve absolute time coordinate for readability
@@ -295,6 +297,10 @@ class SurfaceForcing(ROMSToolsMixins):
295
297
  if data.climatology:
296
298
  ds["time"].attrs["cycle_length"] = 365.25
297
299
 
300
+ variables_to_drop = ["lat_rho", "lon_rho", "lat_coarse", "lon_coarse"]
301
+ existing_vars = [var for var in variables_to_drop if var in ds]
302
+ ds = ds.drop_vars(existing_vars)
303
+
298
304
  return ds
299
305
 
300
306
  def _write_into_datatree(self, data, bgc_data, d_meta):
@@ -392,6 +398,15 @@ class SurfaceForcing(ROMSToolsMixins):
392
398
  field = ds[varname].isel(time=time).load()
393
399
  title = field.long_name
394
400
 
401
+ # assign lat / lon
402
+ if self.use_coarse_grid:
403
+ field = field.rename({"eta_rho": "eta_coarse", "xi_rho": "xi_coarse"})
404
+ field = field.where(self.grid.ds.mask_coarse)
405
+ else:
406
+ field = field.where(self.grid.ds.mask_rho)
407
+
408
+ field = field.assign_coords({"lon": self.target_lon, "lat": self.target_lat})
409
+
395
410
  # choose colorbar
396
411
  if varname in ["uwnd", "vwnd"]:
397
412
  vmax = max(field.max().values, -field.min().values)
@@ -412,7 +427,6 @@ class SurfaceForcing(ROMSToolsMixins):
412
427
  self.grid.ds,
413
428
  field=field,
414
429
  straddle=self.grid.straddle,
415
- coarse_grid=self.use_coarse_grid,
416
430
  title=title,
417
431
  kwargs=kwargs,
418
432
  c="g",