roms-tools 0.20-py3-none-any.whl → 1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
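
In practical terms, the breaking change below is the `BoundaryForcing` constructor: the scalar `source` and `filename` arguments are replaced by a `physics_source` dictionary plus an optional `bgc_source` dictionary, and the output is reorganized from a flat `xr.Dataset` into a `DataTree` with `physics` and `bgc` groups. A minimal migration sketch, assuming existing `grid` and `vertical_coordinate` objects and placeholder file paths (the import path is assumed):

    from datetime import datetime
    from roms_tools.setup.boundary_forcing import BoundaryForcing  # import path assumed

    # 0.2.x (removed): BoundaryForcing(..., source="glorys", filename="physics_data.nc")
    # 1.0.0: each source is a dict with "name", "path", and optional "climatology"
    boundary_forcing = BoundaryForcing(
        grid=grid,
        vertical_coordinate=vertical_coordinate,
        start_time=datetime(2022, 1, 1),
        end_time=datetime(2022, 1, 2),
        physics_source={"name": "GLORYS", "path": "physics_data.nc"},  # climatology defaults to False
        bgc_source={"name": "CESM_REGRIDDED", "path": "bgc_data.nc"},  # optional; climatology defaults to True
    )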
@@ -1,19 +1,18 @@
  import xarray as xr
  import numpy as np
+ import pandas as pd
  import yaml
+ from datatree import DataTree
  import importlib.metadata
- from typing import Dict
+ from typing import Dict, Union, Optional
  from dataclasses import dataclass, field, asdict
  from roms_tools.setup.grid import Grid
  from roms_tools.setup.vertical_coordinate import VerticalCoordinate
+ from roms_tools.setup.mixins import ROMSToolsMixins
  from datetime import datetime
- from roms_tools.setup.datasets import Dataset
- from roms_tools.setup.fill import fill_and_interpolate
+ from roms_tools.setup.datasets import GLORYSDataset, CESMBGCDataset
  from roms_tools.setup.utils import (
      nan_check,
-     interpolate_from_rho_to_u,
-     interpolate_from_rho_to_v,
-     extrapolate_deepest_to_bottom,
  )
  from roms_tools.setup.plot import _section_plot, _line_plot
  import calendar
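
The new `datatree` import supports the reorganized output: the root node carries the vertical-coordinate variables (`sc_r`, `Cs_r`) while the `physics` and `bgc` children carry the boundary variables, as `_write_into_datatree` further down shows. A standalone sketch of that layout, with placeholder contents:

    import xarray as xr
    from datatree import DataTree

    root = DataTree(name="root", data=xr.Dataset({"sc_r": ("s_rho", [-0.9, -0.5, -0.1])}))
    physics = DataTree(name="physics", parent=root, data=xr.Dataset())  # zeta_*, temp_*, u_*, ...
    bgc = DataTree(name="bgc", parent=root, data=xr.Dataset())  # PO4_*, NO3_*, ... (only if bgc_source is set)
    print(root["physics"])  # child nodes are addressed like dictionary entries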
@@ -22,7 +21,7 @@ import matplotlib.pyplot as plt


  @dataclass(frozen=True, kw_only=True)
- class BoundaryForcing:
+ class BoundaryForcing(ROMSToolsMixins):
      """
      Represents boundary forcing for ROMS.

@@ -38,22 +37,39 @@ class BoundaryForcing:
          End time of the desired boundary forcing data.
      boundaries : Dict[str, bool], optional
          Dictionary specifying which boundaries are forced (south, east, north, west). Default is all True.
+     physics_source : Dict[str, Union[str, None]]
+         Dictionary specifying the source of the physical boundary forcing data:
+         - "name" (str): Name of the data source (e.g., "GLORYS").
+         - "path" (str): Path to the physical data file. Can contain wildcards.
+         - "climatology" (bool): Indicates if the physical data is climatology data. Defaults to False.
+     bgc_source : Optional[Dict[str, Union[str, None]]]
+         Dictionary specifying the source of the biogeochemical (BGC) initial condition data:
+         - "name" (str): Name of the BGC data source (e.g., "CESM_REGRIDDED").
+         - "path" (str): Path to the BGC data file. Can contain wildcards.
+         - "climatology" (bool): Indicates if the BGC data is climatology data. Defaults to True.
      model_reference_date : datetime, optional
          Reference date for the model. Default is January 1, 2000.
-     source : str, optional
-         Source of the boundary forcing data. Default is "glorys".
-     filename: str
-         Path to the source data file. Can contain wildcards.

      Attributes
      ----------
      ds : xr.Dataset
          Xarray Dataset containing the atmospheric forcing data.

-     Notes
-     -----
-     This class represents atmospheric forcing data used in ocean modeling. It provides a convenient
-     interface to work with forcing data including shortwave radiation correction and river forcing.
+     Examples
+     --------
+     >>> boundary_forcing = BoundaryForcing(
+     ...     grid=grid,
+     ...     vertical_coordinate=vertical_coordinate,
+     ...     boundaries={"south": True, "east": True, "north": False, "west": True},
+     ...     start_time=datetime(2022, 1, 1),
+     ...     end_time=datetime(2022, 1, 2),
+     ...     physics_source={"name": "GLORYS", "path": "physics_data.nc"},
+     ...     bgc_source={
+     ...         "name": "CESM_REGRIDDED",
+     ...         "path": "bgc_data.nc",
+     ...         "climatology": True,
+     ...     },
+     ... )
      """

      grid: Grid
@@ -68,50 +84,18 @@ class BoundaryForcing:
              "west": True,
          }
      )
+     physics_source: Dict[str, Union[str, None]]
+     bgc_source: Optional[Dict[str, Union[str, None]]] = None
      model_reference_date: datetime = datetime(2000, 1, 1)
-     source: str = "glorys"
-     filename: str
+
      ds: xr.Dataset = field(init=False, repr=False)

      def __post_init__(self):

-         lon = self.grid.ds.lon_rho
-         lat = self.grid.ds.lat_rho
-         angle = self.grid.ds.angle
-
-         if self.source == "glorys":
-             dims = {
-                 "longitude": "longitude",
-                 "latitude": "latitude",
-                 "depth": "depth",
-                 "time": "time",
-             }
-
-             varnames = {
-                 "temp": "thetao",
-                 "salt": "so",
-                 "u": "uo",
-                 "v": "vo",
-                 "ssh": "zos",
-             }
-             data = Dataset(
-                 filename=self.filename,
-                 start_time=self.start_time,
-                 end_time=self.end_time,
-                 var_names=varnames.values(),
-                 dim_names=dims,
-             )
-
-         # operate on longitudes between -180 and 180 unless ROMS domain lies at least 5 degrees in lontitude away from Greenwich meridian
-         lon = xr.where(lon > 180, lon - 360, lon)
-         straddle = True
-         if not self.grid.straddle and abs(lon).min() > 5:
-             lon = xr.where(lon < 0, lon + 360, lon)
-             straddle = False
+         self._input_checks()
+         lon, lat, angle, straddle = super().get_target_lon_lat()

-         # Restrict data to relevant subdomain to achieve better performance and to avoid discontinuous longitudes introduced by converting
-         # to a different longitude range (+- 360 degrees). Discontinues longitudes can lead to artifacts in the interpolation process that
-         # would not be detected by the nan_check function.
+         data = self._get_data()
          data.choose_subdomain(
              latitude_range=[lat.min().values, lat.max().values],
              longitude_range=[lon.min().values, lon.max().values],
@@ -119,252 +103,278 @@ class BoundaryForcing:
              straddle=straddle,
          )

-         # extrapolate deepest value all the way to bottom ("flooding") to prepare for 3d interpolation
-         for var in ["temp", "salt", "u", "v"]:
-             data.ds[varnames[var]] = extrapolate_deepest_to_bottom(
-                 data.ds[varnames[var]], dims["depth"]
+         vars_2d = ["zeta"]
+         vars_3d = ["temp", "salt", "u", "v"]
+         data_vars = super().regrid_data(data, vars_2d, vars_3d, lon, lat)
+         data_vars = super().process_velocities(data_vars, angle)
+         object.__setattr__(data, "data_vars", data_vars)
+
+         if self.bgc_source is not None:
+             bgc_data = self._get_bgc_data()
+             bgc_data.choose_subdomain(
+                 latitude_range=[lat.min().values, lat.max().values],
+                 longitude_range=[lon.min().values, lon.max().values],
+                 margin=2,
+                 straddle=straddle,
              )

-         # interpolate onto desired grid
-         fill_dims = [dims["latitude"], dims["longitude"]]
+             vars_2d = []
+             vars_3d = bgc_data.var_names.keys()
+             data_vars = super().regrid_data(bgc_data, vars_2d, vars_3d, lon, lat)
+             object.__setattr__(bgc_data, "data_vars", data_vars)
+         else:
+             bgc_data = None

-         # 2d interpolation
-         coords = {dims["latitude"]: lat, dims["longitude"]: lon}
-         mask = xr.where(data.ds[varnames["ssh"]].isel(time=0).isnull(), 0, 1)
+         d_meta = super().get_variable_metadata()
+         bdry_coords, rename = super().get_boundary_info()

-         ssh = fill_and_interpolate(
-             data.ds[varnames["ssh"]].astype(np.float64),
-             mask,
-             fill_dims=fill_dims,
-             coords=coords,
-             method="linear",
-         )
+         ds = self._write_into_datatree(data, bgc_data, d_meta, bdry_coords, rename)

-         # 3d interpolation
-         coords = {
-             dims["latitude"]: lat,
-             dims["longitude"]: lon,
-             dims["depth"]: self.vertical_coordinate.ds["layer_depth_rho"],
-         }
-         mask = xr.where(data.ds[varnames["temp"]].isel(time=0).isnull(), 0, 1)
-
-         data_vars = {}
-         # setting fillvalue_interp to None means that we allow extrapolation in the
-         # interpolation step to avoid NaNs at the surface if the lowest depth in original
-         # data is greater than zero
-
-         for var in ["temp", "salt", "u", "v"]:
-
-             data_vars[var] = fill_and_interpolate(
-                 data.ds[varnames[var]].astype(np.float64),
-                 mask,
-                 fill_dims=fill_dims,
-                 coords=coords,
-                 method="linear",
-                 fillvalue_interp=None,
-             )
+         for direction in ["south", "east", "north", "west"]:
+             if self.boundaries[direction]:
+                 nan_check(
+                     ds["physics"][f"zeta_{direction}"].isel(bry_time=0),
+                     self.grid.ds.mask_rho.isel(**bdry_coords["rho"][direction]),
+                 )

-         # rotate velocities to grid orientation
-         u_rot = data_vars["u"] * np.cos(angle) + data_vars["v"] * np.sin(angle)
-         v_rot = data_vars["v"] * np.cos(angle) - data_vars["u"] * np.sin(angle)
+         object.__setattr__(self, "ds", ds)

-         # interpolate to u- and v-points
-         u = interpolate_from_rho_to_u(u_rot)
-         v = interpolate_from_rho_to_v(v_rot)
+     def _input_checks(self):
+
+         if "name" not in self.physics_source.keys():
+             raise ValueError("`physics_source` must include a 'name'.")
+         if "path" not in self.physics_source.keys():
+             raise ValueError("`physics_source` must include a 'path'.")
+         # set self.physics_source["climatology"] to False if not provided
+         object.__setattr__(
+             self,
+             "physics_source",
+             {
+                 **self.physics_source,
+                 "climatology": self.physics_source.get("climatology", False),
+             },
+         )
+         if self.bgc_source is not None:
+             if "name" not in self.bgc_source.keys():
+                 raise ValueError(
+                     "`bgc_source` must include a 'name' if it is provided."
+                 )
+             if "path" not in self.bgc_source.keys():
+                 raise ValueError(
+                     "`bgc_source` must include a 'path' if it is provided."
+                 )
+             # set self.bgc_source["climatology"] to True if not provided
+             object.__setattr__(
+                 self,
+                 "bgc_source",
+                 {
+                     **self.bgc_source,
+                     "climatology": self.bgc_source.get("climatology", True),
+                 },
+             )

-         # 3d masks for ROMS domain
-         umask = self.grid.ds.mask_u.expand_dims({"s_rho": u.s_rho})
-         vmask = self.grid.ds.mask_v.expand_dims({"s_rho": v.s_rho})
+     def _get_data(self):

-         u = u * umask
-         v = v * vmask
+         if self.physics_source["name"] == "GLORYS":
+             data = GLORYSDataset(
+                 filename=self.physics_source["path"],
+                 start_time=self.start_time,
+                 end_time=self.end_time,
+                 climatology=self.physics_source["climatology"],
+             )
+         else:
+             raise ValueError(
+                 'Only "GLORYS" is a valid option for physics_source["name"].'
+             )

-         # Compute barotropic velocity
+         return data

-         # thicknesses
-         dz = -self.vertical_coordinate.ds["interface_depth_rho"].diff(dim="s_w")
-         dz = dz.rename({"s_w": "s_rho"})
-         # thicknesses at u- and v-points
-         dzu = interpolate_from_rho_to_u(dz)
-         dzv = interpolate_from_rho_to_v(dz)
+     def _get_bgc_data(self):

-         ubar = (dzu * u).sum(dim="s_rho") / dzu.sum(dim="s_rho")
-         vbar = (dzv * v).sum(dim="s_rho") / dzv.sum(dim="s_rho")
+         if self.bgc_source["name"] == "CESM_REGRIDDED":

-         # Boundary coordinates for rho-points
-         bdry_coords_rho = {
-             "south": {"eta_rho": 0},
-             "east": {"xi_rho": -1},
-             "north": {"eta_rho": -1},
-             "west": {"xi_rho": 0},
-         }
-         # How to rename the dimensions at rho-points
-         rename_rho = {
-             "south": {"xi_rho": "xi_rho_south"},
-             "east": {"eta_rho": "eta_rho_east"},
-             "north": {"xi_rho": "xi_rho_north"},
-             "west": {"eta_rho": "eta_rho_west"},
-         }
+             data = CESMBGCDataset(
+                 filename=self.bgc_source["path"],
+                 start_time=self.start_time,
+                 end_time=self.end_time,
+                 climatology=self.bgc_source["climatology"],
+             )
+             data.post_process()
+         else:
+             raise ValueError(
+                 'Only "CESM_REGRIDDED" is a valid option for bgc_source["name"].'
+             )

-         # Boundary coordinates for u-points
-         bdry_coords_u = {
-             "south": {"eta_rho": 0},
-             "east": {"xi_u": -1},
-             "north": {"eta_rho": -1},
-             "west": {"xi_u": 0},
-         }
-         # How to rename the dimensions at u-points
-         rename_u = {
-             "south": {"xi_u": "xi_u_south"},
-             "east": {"eta_rho": "eta_u_east"},
-             "north": {"xi_u": "xi_u_north"},
-             "west": {"eta_rho": "eta_u_west"},
-         }
+         return data

-         # Boundary coordinates for v-points
-         bdry_coords_v = {
-             "south": {"eta_v": 0},
-             "east": {"xi_rho": -1},
-             "north": {"eta_v": -1},
-             "west": {"xi_rho": 0},
-         }
-         # How to rename the dimensions at v-points
-         rename_v = {
-             "south": {"xi_rho": "xi_v_south"},
-             "east": {"eta_v": "eta_v_east"},
-             "north": {"xi_rho": "xi_v_north"},
-             "west": {"eta_v": "eta_v_west"},
-         }
+     def _write_into_dataset(self, data, d_meta, bdry_coords, rename):

+         # save in new dataset
          ds = xr.Dataset()

          for direction in ["south", "east", "north", "west"]:
-
              if self.boundaries[direction]:

-                 ds[f"zeta_{direction}"] = (
-                     ssh.isel(**bdry_coords_rho[direction])
-                     .rename(**rename_rho[direction])
-                     .astype(np.float32)
-                 )
-                 ds[f"zeta_{direction}"].attrs[
-                     "long_name"
-                 ] = f"{direction}ern boundary sea surface height"
-                 ds[f"zeta_{direction}"].attrs["units"] = "m"
-
-                 ds[f"temp_{direction}"] = (
-                     data_vars["temp"]
-                     .isel(**bdry_coords_rho[direction])
-                     .rename(**rename_rho[direction])
-                     .astype(np.float32)
-                 )
-                 ds[f"temp_{direction}"].attrs[
-                     "long_name"
-                 ] = f"{direction}ern boundary potential temperature"
-                 ds[f"temp_{direction}"].attrs["units"] = "Celsius"
-
-                 ds[f"salt_{direction}"] = (
-                     data_vars["salt"]
-                     .isel(**bdry_coords_rho[direction])
-                     .rename(**rename_rho[direction])
-                     .astype(np.float32)
-                 )
-                 ds[f"salt_{direction}"].attrs[
-                     "long_name"
-                 ] = f"{direction}ern boundary salinity"
-                 ds[f"salt_{direction}"].attrs["units"] = "PSU"
-
-                 ds[f"u_{direction}"] = (
-                     u.isel(**bdry_coords_u[direction])
-                     .rename(**rename_u[direction])
-                     .astype(np.float32)
-                 )
-                 ds[f"u_{direction}"].attrs[
-                     "long_name"
-                 ] = f"{direction}ern boundary u-flux component"
-                 ds[f"u_{direction}"].attrs["units"] = "m/s"
-
-                 ds[f"v_{direction}"] = (
-                     v.isel(**bdry_coords_v[direction])
-                     .rename(**rename_v[direction])
-                     .astype(np.float32)
-                 )
-                 ds[f"v_{direction}"].attrs[
-                     "long_name"
-                 ] = f"{direction}ern boundary v-flux component"
-                 ds[f"v_{direction}"].attrs["units"] = "m/s"
-
-                 ds[f"ubar_{direction}"] = (
-                     ubar.isel(**bdry_coords_u[direction])
-                     .rename(**rename_u[direction])
-                     .astype(np.float32)
-                 )
-                 ds[f"ubar_{direction}"].attrs[
-                     "long_name"
-                 ] = f"{direction}ern boundary vertically integrated u-flux component"
-                 ds[f"ubar_{direction}"].attrs["units"] = "m/s"
-
-                 ds[f"vbar_{direction}"] = (
-                     vbar.isel(**bdry_coords_v[direction])
-                     .rename(**rename_v[direction])
-                     .astype(np.float32)
-                 )
-                 ds[f"vbar_{direction}"].attrs[
-                     "long_name"
-                 ] = f"{direction}ern boundary vertically integrated v-flux component"
-                 ds[f"vbar_{direction}"].attrs["units"] = "m/s"
+                 for var in data.data_vars.keys():
+                     if var in ["u", "ubar"]:
+                         ds[f"{var}_{direction}"] = (
+                             data.data_vars[var]
+                             .isel(**bdry_coords["u"][direction])
+                             .rename(**rename["u"][direction])
+                             .astype(np.float32)
+                         )
+                     elif var in ["v", "vbar"]:
+                         ds[f"{var}_{direction}"] = (
+                             data.data_vars[var]
+                             .isel(**bdry_coords["v"][direction])
+                             .rename(**rename["v"][direction])
+                             .astype(np.float32)
+                         )
+                     else:
+                         ds[f"{var}_{direction}"] = (
+                             data.data_vars[var]
+                             .isel(**bdry_coords["rho"][direction])
+                             .rename(**rename["rho"][direction])
+                             .astype(np.float32)
+                         )
+                     ds[f"{var}_{direction}"].attrs[
+                         "long_name"
+                     ] = f"{direction}ern boundary {d_meta[var]['long_name']}"
+                     ds[f"{var}_{direction}"].attrs["units"] = d_meta[var]["units"]
+
+         # Gracefully handle dropping variables that might not be present
+         variables_to_drop = [
+             "s_rho",
+             "layer_depth_rho",
+             "layer_depth_u",
+             "layer_depth_v",
+             "interface_depth_rho",
+             "interface_depth_u",
+             "interface_depth_v",
+             "lat_rho",
+             "lon_rho",
+             "lat_u",
+             "lon_u",
+             "lat_v",
+             "lon_v",
+         ]
+         existing_vars = [var for var in variables_to_drop if var in ds]
+         ds = ds.drop_vars(existing_vars)
+
+         # Preserve absolute time coordinate for readability
+         ds = ds.assign_coords({"abs_time": ds["time"]})
+
+         # Convert the time coordinate to the format expected by ROMS
+         if data.climatology:
+             # Convert to pandas TimedeltaIndex
+             timedelta_index = pd.to_timedelta(ds["time"].values)
+             # Determine the start of the year for the base_datetime
+             start_of_year = datetime(self.model_reference_date.year, 1, 1)
+             # Calculate the offset from midnight of the new year
+             offset = self.model_reference_date - start_of_year
+             bry_time = xr.DataArray(
+                 timedelta_index - offset,
+                 dims="time",
+             )
+         else:
+             # TODO: Check if we need to convert from 12:00:00 to 00:00:00 as in matlab scripts
+             bry_time = ds["time"] - np.datetime64(self.model_reference_date)
+
+         ds = ds.assign_coords({"bry_time": bry_time})
+         ds["bry_time"].attrs[
+             "long_name"
+         ] = f"nanoseconds since {np.datetime_as_string(np.datetime64(self.model_reference_date), unit='ns')}"
+         ds["bry_time"].encoding["units"] = "nanoseconds"
+         ds = ds.swap_dims({"time": "bry_time"})
+         ds = ds.drop_vars("time")
+
+         if data.climatology:
+             ds["bry_time"].attrs["cycle_length"] = 365.25
+
+         return ds
+
+     def _write_into_datatree(self, data, bgc_data, d_meta, bdry_coords, rename):
+
+         ds = self._add_global_metadata()
+         ds["sc_r"] = self.vertical_coordinate.ds["sc_r"]
+         ds["Cs_r"] = self.vertical_coordinate.ds["Cs_r"]
+
+         ds = DataTree(name="root", data=ds)
+
+         ds_physics = self._write_into_dataset(data, d_meta, bdry_coords, rename)
+         ds_physics = self._add_coordinates(bdry_coords, rename, ds_physics)
+         ds_physics = self._add_global_metadata(ds_physics)
+         ds_physics.attrs["physics_source"] = self.physics_source["name"]
+
+         ds_physics = DataTree(name="physics", parent=ds, data=ds_physics)
+
+         if bgc_data:
+             ds_bgc = self._write_into_dataset(bgc_data, d_meta, bdry_coords, rename)
+             ds_bgc = self._add_coordinates(bdry_coords, rename, ds_bgc)
+             ds_bgc = self._add_global_metadata(ds_bgc)
+             ds_bgc.attrs["bgc_source"] = self.bgc_source["name"]
+             ds_bgc = DataTree(name="bgc", parent=ds, data=ds_bgc)

-                 # assign the correct depth coordinates
+         return ds
+
+     def _add_coordinates(self, bdry_coords, rename, ds=None):
+
+         if ds is None:
+             ds = xr.Dataset()
+
+         for direction in ["south", "east", "north", "west"]:
+
+             if self.boundaries[direction]:

                  lat_rho = self.grid.ds.lat_rho.isel(
-                     **bdry_coords_rho[direction]
-                 ).rename(**rename_rho[direction])
+                     **bdry_coords["rho"][direction]
+                 ).rename(**rename["rho"][direction])
                  lon_rho = self.grid.ds.lon_rho.isel(
-                     **bdry_coords_rho[direction]
-                 ).rename(**rename_rho[direction])
+                     **bdry_coords["rho"][direction]
+                 ).rename(**rename["rho"][direction])
                  layer_depth_rho = (
                      self.vertical_coordinate.ds["layer_depth_rho"]
-                     .isel(**bdry_coords_rho[direction])
-                     .rename(**rename_rho[direction])
+                     .isel(**bdry_coords["rho"][direction])
+                     .rename(**rename["rho"][direction])
                  )
                  interface_depth_rho = (
                      self.vertical_coordinate.ds["interface_depth_rho"]
-                     .isel(**bdry_coords_rho[direction])
-                     .rename(**rename_rho[direction])
+                     .isel(**bdry_coords["rho"][direction])
+                     .rename(**rename["rho"][direction])
                  )

-                 lat_u = self.grid.ds.lat_u.isel(**bdry_coords_u[direction]).rename(
-                     **rename_u[direction]
+                 lat_u = self.grid.ds.lat_u.isel(**bdry_coords["u"][direction]).rename(
+                     **rename["u"][direction]
                  )
-                 lon_u = self.grid.ds.lon_u.isel(**bdry_coords_u[direction]).rename(
-                     **rename_u[direction]
+                 lon_u = self.grid.ds.lon_u.isel(**bdry_coords["u"][direction]).rename(
+                     **rename["u"][direction]
                  )
                  layer_depth_u = (
                      self.vertical_coordinate.ds["layer_depth_u"]
-                     .isel(**bdry_coords_u[direction])
-                     .rename(**rename_u[direction])
+                     .isel(**bdry_coords["u"][direction])
+                     .rename(**rename["u"][direction])
                  )
                  interface_depth_u = (
                      self.vertical_coordinate.ds["interface_depth_u"]
-                     .isel(**bdry_coords_u[direction])
-                     .rename(**rename_u[direction])
+                     .isel(**bdry_coords["u"][direction])
+                     .rename(**rename["u"][direction])
                  )

-                 lat_v = self.grid.ds.lat_v.isel(**bdry_coords_v[direction]).rename(
-                     **rename_v[direction]
+                 lat_v = self.grid.ds.lat_v.isel(**bdry_coords["v"][direction]).rename(
+                     **rename["v"][direction]
                  )
-                 lon_v = self.grid.ds.lon_v.isel(**bdry_coords_v[direction]).rename(
-                     **rename_v[direction]
+                 lon_v = self.grid.ds.lon_v.isel(**bdry_coords["v"][direction]).rename(
+                     **rename["v"][direction]
                  )
                  layer_depth_v = (
                      self.vertical_coordinate.ds["layer_depth_v"]
-                     .isel(**bdry_coords_v[direction])
-                     .rename(**rename_v[direction])
+                     .isel(**bdry_coords["v"][direction])
+                     .rename(**rename["v"][direction])
                  )
                  interface_depth_v = (
                      self.vertical_coordinate.ds["interface_depth_v"]
-                     .isel(**bdry_coords_v[direction])
-                     .rename(**rename_v[direction])
+                     .isel(**bdry_coords["v"][direction])
+                     .rename(**rename["v"][direction])
                  )

                  ds = ds.assign_coords(
@@ -384,46 +394,31 @@ class BoundaryForcing:
                      }
                  )

-         ds = ds.drop_vars(
-             [
-                 "layer_depth_rho",
-                 "layer_depth_u",
-                 "layer_depth_v",
-                 "interface_depth_rho",
-                 "interface_depth_u",
-                 "interface_depth_v",
-                 "lat_rho",
-                 "lon_rho",
-                 "lat_u",
-                 "lon_u",
-                 "lat_v",
-                 "lon_v",
-                 "s_rho",
-             ]
-         )
-
-         # deal with time dimension
-         if dims["time"] != "time":
-             ds = ds.rename({dims["time"]: "time"})
-
-         # Translate the time coordinate to days since the model reference date
-         # TODO: Check if we need to convert from 12:00:00 to 00:00:00 as in matlab scripts
-         model_reference_date = np.datetime64(self.model_reference_date)
-
-         # Convert the time coordinate to the format expected by ROMS (days since model reference date)
-         bry_time = ds["time"] - model_reference_date
-         ds = ds.assign_coords(bry_time=("time", bry_time.data))
-         ds["bry_time"].attrs[
-             "long_name"
-         ] = f"time since {np.datetime_as_string(model_reference_date, unit='D')}"
-
-         ds["theta_s"] = self.vertical_coordinate.ds["theta_s"]
-         ds["theta_b"] = self.vertical_coordinate.ds["theta_b"]
-         ds["Tcline"] = self.vertical_coordinate.ds["Tcline"]
-         ds["hc"] = self.vertical_coordinate.ds["hc"]
-         ds["sc_r"] = self.vertical_coordinate.ds["sc_r"]
-         ds["Cs_r"] = self.vertical_coordinate.ds["Cs_r"]
-
+         # Gracefully handle dropping variables that might not be present
+         variables_to_drop = [
+             "s_rho",
+             "layer_depth_rho",
+             "layer_depth_u",
+             "layer_depth_v",
+             "interface_depth_rho",
+             "interface_depth_u",
+             "interface_depth_v",
+             "lat_rho",
+             "lon_rho",
+             "lat_u",
+             "lon_u",
+             "lat_v",
+             "lon_v",
+         ]
+         existing_vars = [var for var in variables_to_drop if var in ds]
+         ds = ds.drop_vars(existing_vars)
+
+         return ds
+
+     def _add_global_metadata(self, ds=None):
+
+         if ds is None:
+             ds = xr.Dataset()
          ds.attrs["title"] = "ROMS boundary forcing file created by ROMS-Tools"
          # Include the version of roms-tools
         try:
@@ -434,15 +429,13 @@ class BoundaryForcing:
          ds.attrs["start_time"] = str(self.start_time)
          ds.attrs["end_time"] = str(self.end_time)
          ds.attrs["model_reference_date"] = str(self.model_reference_date)
-         ds.attrs["source"] = self.source

-         object.__setattr__(self, "ds", ds)
+         ds.attrs["theta_s"] = self.vertical_coordinate.ds["theta_s"].item()
+         ds.attrs["theta_b"] = self.vertical_coordinate.ds["theta_b"].item()
+         ds.attrs["Tcline"] = self.vertical_coordinate.ds["Tcline"].item()
+         ds.attrs["hc"] = self.vertical_coordinate.ds["hc"].item()

-         for direction in ["south", "east", "north", "west"]:
-             nan_check(
-                 ds[f"zeta_{direction}"].isel(time=0),
-                 self.grid.ds.mask_rho.isel(**bdry_coords_rho[direction]),
-             )
+         return ds

      def plot(
          self,
@@ -457,14 +450,45 @@ class BoundaryForcing:
          ----------
          varname : str
              The name of the initial conditions field to plot. Options include:
-             - "temp_{direction}": Potential temperature.
-             - "salt_{direction}": Salinity.
-             - "zeta_{direction}": Sea surface height.
-             - "u_{direction}": u-flux component.
-             - "v_{direction}": v-flux component.
-             - "ubar_{direction}": Vertically integrated u-flux component.
-             - "vbar_{direction}": Vertically integrated v-flux component.
-             where {direction} can be one of ["south", "east", "north", "west"].
+             - "temp_{direction}": Potential temperature, where {direction} can be one of ["south", "east", "north", "west"].
+             - "salt_{direction}": Salinity, where {direction} can be one of ["south", "east", "north", "west"].
+             - "zeta_{direction}": Sea surface height, where {direction} can be one of ["south", "east", "north", "west"].
+             - "u_{direction}": u-flux component, where {direction} can be one of ["south", "east", "north", "west"].
+             - "v_{direction}": v-flux component, where {direction} can be one of ["south", "east", "north", "west"].
+             - "ubar_{direction}": Vertically integrated u-flux component, where {direction} can be one of ["south", "east", "north", "west"].
+             - "vbar_{direction}": Vertically integrated v-flux component, where {direction} can be one of ["south", "east", "north", "west"].
+             - "PO4_{direction}": Dissolved Inorganic Phosphate (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "NO3_{direction}": Dissolved Inorganic Nitrate (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "SiO3_{direction}": Dissolved Inorganic Silicate (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "NH4_{direction}": Dissolved Ammonia (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "Fe_{direction}": Dissolved Inorganic Iron (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "Lig_{direction}": Iron Binding Ligand (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "O2_{direction}": Dissolved Oxygen (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "DIC_{direction}": Dissolved Inorganic Carbon (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "DIC_ALT_CO2_{direction}": Dissolved Inorganic Carbon, Alternative CO2 (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "ALK_{direction}": Alkalinity (meq/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "ALK_ALT_CO2_{direction}": Alkalinity, Alternative CO2 (meq/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "DOC_{direction}": Dissolved Organic Carbon (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "DON_{direction}": Dissolved Organic Nitrogen (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "DOP_{direction}": Dissolved Organic Phosphorus (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "DOPr_{direction}": Refractory Dissolved Organic Phosphorus (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "DONr_{direction}": Refractory Dissolved Organic Nitrogen (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "DOCr_{direction}": Refractory Dissolved Organic Carbon (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "zooC_{direction}": Zooplankton Carbon (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "spChl_{direction}": Small Phytoplankton Chlorophyll (mg/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "spC_{direction}": Small Phytoplankton Carbon (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "spP_{direction}": Small Phytoplankton Phosphorous (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "spFe_{direction}": Small Phytoplankton Iron (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "spCaCO3_{direction}": Small Phytoplankton CaCO3 (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "diatChl_{direction}": Diatom Chlorophyll (mg/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "diatC_{direction}": Diatom Carbon (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "diatP_{direction}": Diatom Phosphorus (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "diatFe_{direction}": Diatom Iron (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "diatSi_{direction}": Diatom Silicate (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "diazChl_{direction}": Diazotroph Chlorophyll (mg/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "diazC_{direction}": Diazotroph Carbon (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "diazP_{direction}": Diazotroph Phosphorus (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
+             - "diazFe_{direction}": Diazotroph Iron (mmol/m³), where {direction} can be one of ["south", "east", "north", "west"].
          time : int, optional
              The time index to plot. Default is 0.
          layer_contours : bool, optional
@@ -482,8 +506,17 @@ class BoundaryForcing:
              If the specified varname is not one of the valid options.
          """

-         field = self.ds[varname].isel(time=time).load()
+         if varname in self.ds["physics"]:
+             ds = self.ds["physics"]
+         else:
+             if "bgc" in self.ds and varname in self.ds["bgc"]:
+                 ds = self.ds["bgc"]
+             else:
+                 raise ValueError(
+                     f"Variable '{varname}' is not found in 'physics' or 'bgc' datasets."
+                 )

+         field = ds[varname].isel(bry_time=time).load()
          title = field.long_name

          # chose colorbar
@@ -494,7 +527,10 @@ class BoundaryForcing:
          else:
              vmax = field.max().values
              vmin = field.min().values
-             cmap = plt.colormaps.get_cmap("YlOrRd")
+             if varname.startswith(("temp", "salt")):
+                 cmap = plt.colormaps.get_cmap("YlOrRd")
+             else:
+                 cmap = plt.colormaps.get_cmap("YlGn")
          cmap.set_bad(color="gray")
          kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}

@@ -507,13 +543,13 @@ class BoundaryForcing:
          ]
          try:
              interface_depth = next(
-                 self.ds[depth_label]
-                 for depth_label in self.ds.coords
+                 ds[depth_label]
+                 for depth_label in ds.coords
                  if any(
                      depth_label.startswith(prefix) for prefix in depths_to_check
                  )
                  and (
-                     set(self.ds[depth_label].dims) - {"s_w"}
+                     set(ds[depth_label].dims) - {"s_w"}
                      == set(field.dims) - {"s_rho"}
                  )
              )
@@ -563,40 +599,50 @@ class BoundaryForcing:
          filenames = []
          writes = []

-         # Group dataset by year
-         gb = self.ds.groupby("time.year")
-
-         for year, group_ds in gb:
-             # Further group each yearly group by month
-             sub_gb = group_ds.groupby("time.month")
-
-             for month, ds in sub_gb:
-                 # Chunk the dataset by the specified time chunk size
-                 ds = ds.chunk({"time": time_chunk_size})
-                 datasets.append(ds)
-
-                 # Determine the number of days in the month
-                 num_days_in_month = calendar.monthrange(year, month)[1]
-                 first_day = ds.time.dt.day.values[0]
-                 last_day = ds.time.dt.day.values[-1]
-
-                 # Create filename based on whether the dataset contains a full month
-                 if first_day == 1 and last_day == num_days_in_month:
-                     # Full month format: "filepath.YYYYMM.nc"
-                     year_month_str = f"{year}{month:02}"
-                     filename = f"{filepath}.{year_month_str}.nc"
+         for node in ["physics", "bgc"]:
+             if node in self.ds:
+                 ds = self.ds[node].to_dataset()
+                 # copy vertical coordinate variables from parent to children because I believe this is info that ROMS needs
+                 for var in self.ds.data_vars:
+                     ds[var] = self.ds[var]
+                 if hasattr(ds["bry_time"], "cycle_length"):
+                     filename = f"{filepath}_{node}_clim.nc"
+                     filenames.append(filename)
+                     datasets.append(ds)
                  else:
-                     # Partial month format: "filepath.YYYYMMDD-DD.nc"
-                     year_month_day_str = f"{year}{month:02}{first_day:02}-{last_day:02}"
-                     filename = f"{filepath}.{year_month_day_str}.nc"
-                 filenames.append(filename)
+                     # Group dataset by year
+                     gb = ds.groupby("abs_time.year")
+
+                     for year, group_ds in gb:
+                         # Further group each yearly group by month
+                         sub_gb = group_ds.groupby("abs_time.month")
+
+                         for month, ds in sub_gb:
+                             # Chunk the dataset by the specified time chunk size
+                             ds = ds.chunk({"bry_time": time_chunk_size})
+                             datasets.append(ds)
+
+                             # Determine the number of days in the month
+                             num_days_in_month = calendar.monthrange(year, month)[1]
+                             first_day = ds.abs_time.dt.day.values[0]
+                             last_day = ds.abs_time.dt.day.values[-1]
+
+                             # Create filename based on whether the dataset contains a full month
+                             if first_day == 1 and last_day == num_days_in_month:
+                                 # Full month format: "filepath_physics_YYYYMM.nc"
+                                 year_month_str = f"{year}{month:02}"
+                                 filename = f"{filepath}_{node}_{year_month_str}.nc"
+                             else:
+                                 # Partial month format: "filepath_physics_YYYYMMDD-DD.nc"
+                                 year_month_day_str = (
+                                     f"{year}{month:02}{first_day:02}-{last_day:02}"
+                                 )
+                                 filename = f"{filepath}_{node}_{year_month_day_str}.nc"
+                             filenames.append(filename)

          print("Saving the following files:")
-         for filename in filenames:
-             print(filename)
-
          for ds, filename in zip(datasets, filenames):
-
+             print(filename)
              # Prepare the dataset for writing to a netCDF file without immediately computing
              write = ds.to_netcdf(filename, compute=False)
              writes.append(write)
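
Note the new naming scheme in `save`: 0.2.x wrote `filepath.YYYYMM.nc`, while 1.0.0 writes one file series per node and a dedicated `_clim` file for climatologies. A quick illustration of the names this logic produces, with made-up values:

    filepath, node = "bry", "physics"
    year, month, first_day, last_day = 2022, 1, 1, 2

    print(f"{filepath}_{node}_clim.nc")              # bry_physics_clim.nc (climatology)
    print(f"{filepath}_{node}_{year}{month:02}.nc")  # bry_physics_202201.nc (full month)
    print(f"{filepath}_{node}_{year}{month:02}{first_day:02}-{last_day:02}.nc")  # bry_physics_20220101-02.nc (partial month)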
@@ -637,12 +683,12 @@ class BoundaryForcing:

          boundary_forcing_data = {
              "BoundaryForcing": {
-                 "filename": self.filename,
                  "start_time": self.start_time.isoformat(),
                  "end_time": self.end_time.isoformat(),
-                 "model_reference_date": self.model_reference_date.isoformat(),
-                 "source": self.source,
                  "boundaries": self.boundaries,
+                 "physics_source": self.physics_source,
+                 "bgc_source": self.bgc_source,
+                 "model_reference_date": self.model_reference_date.isoformat(),
              }
          }
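
A closing note on the reworked time handling in `_write_into_dataset`: for climatology sources, `bry_time` becomes a set of timedeltas measured against the model reference date's offset into its year, which is what lets ROMS cycle over them with `cycle_length = 365.25`. A self-contained sketch of that arithmetic, with made-up dates:

    import pandas as pd
    from datetime import datetime

    model_reference_date = datetime(2000, 1, 15)
    # climatological stamps are stored as offsets from the start of a generic year
    timedelta_index = pd.to_timedelta(["15 days 12:00:00", "45 days 12:00:00"])
    # offset of the reference date from midnight of its new year (14 days here)
    offset = model_reference_date - datetime(model_reference_date.year, 1, 1)
    bry_time = timedelta_index - offset
    print(bry_time)  # ['1 days 12:00:00', '31 days 12:00:00']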