roms-tools 2.0.0__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. roms_tools/__init__.py +2 -1
  2. roms_tools/setup/boundary_forcing.py +21 -30
  3. roms_tools/setup/datasets.py +13 -21
  4. roms_tools/setup/grid.py +253 -139
  5. roms_tools/setup/initial_conditions.py +21 -3
  6. roms_tools/setup/mask.py +50 -4
  7. roms_tools/setup/nesting.py +575 -0
  8. roms_tools/setup/plot.py +214 -55
  9. roms_tools/setup/river_forcing.py +125 -29
  10. roms_tools/setup/surface_forcing.py +21 -8
  11. roms_tools/setup/tides.py +21 -3
  12. roms_tools/setup/topography.py +168 -35
  13. roms_tools/setup/utils.py +127 -21
  14. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/.zmetadata +2 -3
  15. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/river_tracer/.zattrs +1 -2
  16. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_name/.zarray +1 -1
  17. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_name/0 +0 -0
  18. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/.zmetadata +5 -6
  19. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_tracer/.zarray +2 -2
  20. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_tracer/.zattrs +1 -2
  21. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/river_tracer/0.0.0 +0 -0
  22. roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/tracer_name/.zarray +2 -2
  23. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/tracer_name/0 +0 -0
  24. roms_tools/tests/test_setup/test_datasets.py +2 -2
  25. roms_tools/tests/test_setup/test_nesting.py +489 -0
  26. roms_tools/tests/test_setup/test_river_forcing.py +50 -13
  27. roms_tools/tests/test_setup/test_surface_forcing.py +1 -0
  28. roms_tools/tests/test_setup/test_validation.py +2 -2
  29. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/METADATA +8 -4
  30. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/RECORD +51 -50
  31. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/WHEEL +1 -1
  32. roms_tools/_version.py +0 -2
  33. roms_tools/tests/test_setup/test_data/river_forcing.zarr/river_tracer/0.0.0 +0 -0
  34. roms_tools/tests/test_setup/test_data/river_forcing.zarr/tracer_name/0 +0 -0
  35. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/.zattrs +0 -0
  36. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/.zgroup +0 -0
  37. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/abs_time/.zarray +0 -0
  38. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/abs_time/.zattrs +0 -0
  39. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/abs_time/0 +0 -0
  40. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/month/.zarray +0 -0
  41. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/month/.zattrs +0 -0
  42. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/month/0 +0 -0
  43. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_name/.zarray +0 -0
  44. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_name/.zattrs +0 -0
  45. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_name/0 +0 -0
  46. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_time/.zarray +0 -0
  47. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_time/.zattrs +0 -0
  48. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_time/0 +0 -0
  49. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_volume/.zarray +0 -0
  50. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_volume/.zattrs +0 -0
  51. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/river_volume/0.0 +0 -0
  52. /roms_tools/tests/test_setup/test_data/{river_forcing.zarr → river_forcing_with_bgc.zarr}/tracer_name/.zattrs +0 -0
  53. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/LICENSE +0 -0
  54. {roms_tools-2.0.0.dist-info → roms_tools-2.1.0.dist-info}/top_level.txt +0 -0
roms_tools/setup/plot.py CHANGED
@@ -10,6 +10,7 @@ def _plot(
10
10
  straddle=False,
11
11
  c="red",
12
12
  title="",
13
+ with_dim_names=False,
13
14
  kwargs={},
14
15
  ):
15
16
  """Plots a grid or field on a map with optional depth contours.
@@ -65,22 +66,136 @@ def _plot(
65
66
 
66
67
  fig, ax = plt.subplots(1, 1, figsize=(13, 7), subplot_kw={"projection": trans})
67
68
 
68
- _add_plot_to_ax(
69
- ax, lon_deg, lat_deg, trans, field, depth_contours, c, title, kwargs=kwargs
69
+ if c is not None:
70
+ _add_boundary_to_ax(
71
+ ax, lon_deg, lat_deg, trans, c, with_dim_names=with_dim_names
72
+ )
73
+
74
+ if field is not None:
75
+ _add_field_to_ax(ax, lon_deg, lat_deg, field, depth_contours, kwargs=kwargs)
76
+
77
+ ax.coastlines(
78
+ resolution="50m", linewidth=0.5, color="black"
79
+ ) # add map of coastlines
80
+
81
+ # Add gridlines with labels for latitude and longitude
82
+ gridlines = ax.gridlines(
83
+ draw_labels=True, linewidth=0.5, color="gray", alpha=0.7, linestyle="--"
70
84
  )
85
+ gridlines.top_labels = False # Hide top labels
86
+ gridlines.right_labels = False # Hide right labels
87
+ gridlines.xlabel_style = {
88
+ "size": 10,
89
+ "color": "black",
90
+ } # Customize longitude label style
91
+ gridlines.ylabel_style = {
92
+ "size": 10,
93
+ "color": "black",
94
+ } # Customize latitude label style
95
+
96
+ ax.set_title(title)
97
+
98
+
99
+ def _add_boundary_to_ax(
100
+ ax, lon_deg, lat_deg, trans, c="red", label="", with_dim_names=False
101
+ ):
102
+ """Plots a grid or field on a map with optional depth contours.
103
+
104
+ Parameters
105
+ ----------
106
+ ax : matplotlib.axes._axes.Axes
107
+ The axes on which to plot the data (Cartopy axis with projection).
108
+
109
+ lon_deg : np.ndarray
110
+ Longitude values in degrees.
71
111
 
112
+ lat_deg : np.ndarray
113
+ Latitude values in degrees.
114
+
115
+ trans : cartopy.crs.Projection
116
+ The projection for transforming coordinates.
72
117
 
73
- def _add_plot_to_ax(
118
+ c : str, optional
119
+ Color of the grid boundary (default is 'red').
120
+ """
121
+ proj = ccrs.PlateCarree()
122
+
123
+ # find corners
124
+ corners = [
125
+ (lon_deg[0, 0], lat_deg[0, 0]),
126
+ (lon_deg[0, -1], lat_deg[0, -1]),
127
+ (lon_deg[-1, -1], lat_deg[-1, -1]),
128
+ (lon_deg[-1, 0], lat_deg[-1, 0]),
129
+ ]
130
+
131
+ # transform coordinates to projected space
132
+ transformed_corners = [trans.transform_point(lo, la, proj) for lo, la in corners]
133
+ transformed_lons, transformed_lats = zip(*transformed_corners)
134
+
135
+ ax.plot(
136
+ list(transformed_lons) + [transformed_lons[0]],
137
+ list(transformed_lats) + [transformed_lats[0]],
138
+ "o-",
139
+ c=c,
140
+ label=label,
141
+ )
142
+
143
+ if with_dim_names:
144
+ for i in range(len(corners)):
145
+ if i in [0, 2]:
146
+ dim_name = r"$\xi$"
147
+ else:
148
+ dim_name = r"$\eta$"
149
+ # Define start and end points for each edge
150
+ start_lon, start_lat = transformed_corners[i]
151
+ end_lon, end_lat = transformed_corners[(i + 1) % len(corners)]
152
+
153
+ # Compute midpoint
154
+ mid_lon = (start_lon + end_lon) / 2
155
+ mid_lat = (start_lat + end_lat) / 2
156
+
157
+ # Compute vector direction for arrow
158
+ arrow_dx = (end_lon - start_lon) * 0.4 # Scale arrow size
159
+ arrow_dy = (end_lat - start_lat) * 0.4
160
+
161
+ # Reverse arrow direction for edges 2 and 3
162
+ if i in [2, 3]:
163
+ arrow_dx *= -1
164
+ arrow_dy *= -1
165
+
166
+ # Add arrow
167
+ ax.annotate(
168
+ "",
169
+ xy=(mid_lon + arrow_dx, mid_lat + arrow_dy),
170
+ xytext=(mid_lon - arrow_dx, mid_lat - arrow_dy),
171
+ arrowprops=dict(arrowstyle="->", color=c, lw=1.5),
172
+ )
173
+
174
+ ax.text(
175
+ mid_lon,
176
+ mid_lat,
177
+ dim_name,
178
+ color=c,
179
+ fontsize=10,
180
+ ha="center",
181
+ va="center",
182
+ bbox=dict(
183
+ facecolor="white",
184
+ edgecolor="none",
185
+ alpha=0.7,
186
+ boxstyle="round,pad=0.2",
187
+ ),
188
+ )
189
+
190
+
191
+ def _add_field_to_ax(
74
192
  ax,
75
193
  lon_deg,
76
194
  lat_deg,
77
- trans,
78
- field=None,
195
+ field,
79
196
  depth_contours=False,
80
- c="red",
81
- title="",
82
197
  add_colorbar=True,
83
- kwargs=None,
198
+ kwargs={},
84
199
  ):
85
200
  """Plots a grid or field on a map with optional depth contours.
86
201
 
@@ -95,21 +210,12 @@ def _add_plot_to_ax(
95
210
  lat_deg : np.ndarray
96
211
  Latitude values in degrees.
97
212
 
98
- trans : cartopy.crs.Projection
99
- The projection for transforming coordinates.
100
-
101
213
  field : xarray.DataArray, optional
102
214
  Field data to plot (e.g., temperature, salinity). If None, only the grid is plotted.
103
215
 
104
216
  depth_contours : bool, optional
105
217
  If True, adds depth contours to the plot.
106
218
 
107
- c : str, optional
108
- Color of the grid boundary (default is 'red').
109
-
110
- title : str, optional
111
- Title of the plot.
112
-
113
219
  add_colorbar : bool, optional
114
220
  If True, add colobar.
115
221
 
@@ -118,48 +224,20 @@ def _add_plot_to_ax(
118
224
 
119
225
  Notes
120
226
  -----
121
- - If `field` is provided, a colorbar is added.
122
227
  - If `depth_contours` is True, the field’s `layer_depth` is used to add contours.
123
228
  """
124
229
  proj = ccrs.PlateCarree()
125
230
 
126
- # find corners
127
- corners = [
128
- (lon_deg[0, 0], lat_deg[0, 0]),
129
- (lon_deg[0, -1], lat_deg[0, -1]),
130
- (lon_deg[-1, -1], lat_deg[-1, -1]),
131
- (lon_deg[-1, 0], lat_deg[-1, 0]),
132
- ]
133
-
134
- # transform coordinates to projected space
135
- transformed_corners = [trans.transform_point(lo, la, proj) for lo, la in corners]
136
- transformed_lons, transformed_lats = zip(*transformed_corners)
137
-
138
- if c is not None:
139
- ax.plot(
140
- list(transformed_lons) + [transformed_lons[0]],
141
- list(transformed_lats) + [transformed_lats[0]],
142
- "o-",
143
- c=c,
144
- )
145
-
146
- ax.coastlines(
147
- resolution="50m", linewidth=0.5, color="black"
148
- ) # add map of coastlines
149
- ax.gridlines()
150
- ax.set_title(title)
151
-
152
- if field is not None:
153
- p = ax.pcolormesh(lon_deg, lat_deg, field, transform=proj, **kwargs)
154
- if hasattr(field, "long_name"):
155
- label = f"{field.long_name} [{field.units}]"
156
- elif hasattr(field, "Long_name"):
157
- # this is the case for matlab generated grids
158
- label = f"{field.Long_name} [{field.units}]"
159
- else:
160
- label = ""
161
- if add_colorbar:
162
- plt.colorbar(p, label=label)
231
+ p = ax.pcolormesh(lon_deg, lat_deg, field, transform=proj, **kwargs)
232
+ if hasattr(field, "long_name"):
233
+ label = f"{field.long_name} [{field.units}]"
234
+ elif hasattr(field, "Long_name"):
235
+ # this is the case for matlab generated grids
236
+ label = f"{field.Long_name} [{field.units}]"
237
+ else:
238
+ label = ""
239
+ if add_colorbar:
240
+ plt.colorbar(p, label=label)
163
241
 
164
242
  if depth_contours:
165
243
  cs = ax.contour(lon_deg, lat_deg, field.layer_depth, transform=proj, colors="k")
@@ -282,3 +360,84 @@ def _line_plot(field, title="", ax=None):
282
360
  field.plot(ax=ax)
283
361
  ax.set_title(title)
284
362
  ax.grid()
363
+
364
+
365
+ def _plot_nesting(parent_grid_ds, child_grid_ds, parent_straddle, with_dim_names=False):
366
+
367
+ parent_lon_deg = parent_grid_ds["lon_rho"]
368
+ parent_lat_deg = parent_grid_ds["lat_rho"]
369
+
370
+ child_lon_deg = child_grid_ds["lon_rho"]
371
+ child_lat_deg = child_grid_ds["lat_rho"]
372
+
373
+ if parent_straddle:
374
+ parent_lon_deg = xr.where(
375
+ parent_lon_deg > 180, parent_lon_deg - 360, parent_lon_deg
376
+ )
377
+ child_lon_deg = xr.where(
378
+ child_lon_deg > 180, child_lon_deg - 360, child_lon_deg
379
+ )
380
+
381
+ trans = _get_projection(parent_lon_deg, parent_lat_deg)
382
+
383
+ parent_lon_deg = parent_lon_deg.values
384
+ parent_lat_deg = parent_lat_deg.values
385
+ child_lon_deg = child_lon_deg.values
386
+ child_lat_deg = child_lat_deg.values
387
+
388
+ fig, ax = plt.subplots(1, 1, figsize=(13, 7), subplot_kw={"projection": trans})
389
+
390
+ _add_boundary_to_ax(
391
+ ax,
392
+ parent_lon_deg,
393
+ parent_lat_deg,
394
+ trans,
395
+ c="r",
396
+ label="parent grid",
397
+ with_dim_names=with_dim_names,
398
+ )
399
+
400
+ _add_boundary_to_ax(
401
+ ax,
402
+ child_lon_deg,
403
+ child_lat_deg,
404
+ trans,
405
+ c="g",
406
+ label="child grid",
407
+ with_dim_names=with_dim_names,
408
+ )
409
+
410
+ vmax = 3
411
+ vmin = 0
412
+ cmap = plt.colormaps.get_cmap("Blues")
413
+ kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}
414
+
415
+ _add_field_to_ax(
416
+ ax,
417
+ parent_lon_deg,
418
+ parent_lat_deg,
419
+ parent_grid_ds.mask_rho,
420
+ add_colorbar=False,
421
+ kwargs=kwargs,
422
+ )
423
+
424
+ ax.coastlines(
425
+ resolution="50m", linewidth=0.5, color="black"
426
+ ) # add map of coastlines
427
+
428
+ # Add gridlines with labels for latitude and longitude
429
+ gridlines = ax.gridlines(
430
+ draw_labels=True, linewidth=0.5, color="gray", alpha=0.7, linestyle="--"
431
+ )
432
+ gridlines.top_labels = False # Hide top labels
433
+ gridlines.right_labels = False # Hide right labels
434
+ gridlines.xlabel_style = {
435
+ "size": 10,
436
+ "color": "black",
437
+ } # Customize longitude label style
438
+ gridlines.ylabel_style = {
439
+ "size": 10,
440
+ "color": "black",
441
+ } # Customize latitude label style
442
+
443
+ ax.legend(loc="best")
roms_tools/setup/river_forcing.py CHANGED
@@ -16,8 +16,9 @@ from roms_tools.setup.utils import (
16
16
  save_datasets,
17
17
  _to_yaml,
18
18
  _from_yaml,
19
+ get_variable_metadata,
19
20
  )
20
- from roms_tools.setup.plot import _get_projection, _add_plot_to_ax
21
+ from roms_tools.setup.plot import _get_projection, _add_field_to_ax
21
22
  import cartopy.crs as ccrs
22
23
 
23
24
 
@@ -52,6 +53,8 @@ class RiverForcing:
52
53
  - "never": Do not compute climatology.
53
54
  - "always": Compute climatology for all rivers, regardless of missing data.
54
55
 
56
+ include_bgc : bool, optional
57
+ Whether to include BGC tracers. Defaults to `False`.
55
58
  model_reference_date : datetime, optional
56
59
  Reference date for the model. Default is January 1, 2000.
57
60
 
@@ -68,6 +71,7 @@ class RiverForcing:
68
71
  end_time: datetime
69
72
  source: Dict[str, Union[str, Path, List[Union[str, Path]]]] = None
70
73
  convert_to_climatology: str = "if_any_missing"
74
+ include_bgc: bool = False
71
75
  model_reference_date: datetime = datetime(2000, 1, 1)
72
76
 
73
77
  ds: xr.Dataset = field(init=False, repr=False)
@@ -87,7 +91,7 @@ class RiverForcing:
87
91
  object.__setattr__(self, "original_indices", original_indices)
88
92
 
89
93
  if len(original_indices["station"]) > 0:
90
- self.move_rivers_to_closest_coast(target_coords, data)
94
+ self._move_rivers_to_closest_coast(target_coords, data)
91
95
  ds = self._create_river_forcing(data)
92
96
  self._validate(ds)
93
97
 
@@ -138,13 +142,13 @@ class RiverForcing:
138
142
  return data
139
143
 
140
144
  def _create_river_forcing(self, data):
141
- """Create river forcing data for volume flux and tracers (temperature and
142
- salinity).
145
+ """Create river forcing data for volume flux and tracers (temperature, salinity,
146
+ BGC tracers).
143
147
 
144
- This method computes the river volume flux and associated tracers (temperature and salinity)
148
+ This method computes the river volume flux and associated tracers (temperature, salinity, BGC tracers)
145
149
  based on the provided input data. It generates a new `xarray.Dataset` that contains:
146
150
  - `river_volume`: The river volume flux, calculated as the product of river flux and a specified ratio, with units of m³/s.
147
- - `river_tracer`: A tracer array containing temperature and salinity values at each river station over time.
151
+ - `river_tracer`: A tracer array containing temperature, salinity, and BGC tracer values for each river over time.
148
152
 
149
153
  The method also handles climatological adjustments for missing or incomplete data, depending on the `convert_to_climatology` setting.
150
154
 
@@ -161,7 +165,7 @@ class RiverForcing:
161
165
  xr.Dataset
162
166
  A new `xarray.Dataset` containing the computed river forcing data. The dataset includes:
163
167
  - `river_volume`: A `DataArray` representing the river volume flux (m³/s).
164
- - `river_tracer`: A `DataArray` representing tracer data for temperature and salinity at each river station over time.
168
+ - `river_tracer`: A `DataArray` representing tracer data for temperature, salinity and BGC tracers (if specified) for each river over time.
165
169
  """
166
170
  if self.source["climatology"]:
167
171
  object.__setattr__(self, "climatology", True)
@@ -199,18 +203,64 @@ class RiverForcing:
199
203
  river_volume.coords["river_name"] = name
200
204
  ds["river_volume"] = river_volume
201
205
 
206
+ if self.include_bgc:
207
+ ntracers = 2 + 32
208
+ else:
209
+ ntracers = 2
202
210
  tracer_data = np.zeros(
203
- (len(ds.river_time), 2, len(ds.nriver)), dtype=np.float32
211
+ (len(ds.river_time), ntracers, len(ds.nriver)), dtype=np.float32
204
212
  )
205
213
  tracer_data[:, 0, :] = 17.0
206
214
  tracer_data[:, 1, :] = 1.0
215
+ tracer_data[:, 2:, :] = 0.0
207
216
 
208
217
  river_tracer = xr.DataArray(
209
218
  tracer_data, dims=("river_time", "ntracers", "nriver")
210
219
  )
211
220
  river_tracer.attrs["long_name"] = "River tracer data"
212
- river_tracer.attrs["units"] = "degrees C [temperature]; psu [salinity]"
213
- tracer_names = xr.DataArray(["temperature", "salinity"], dims="ntracers")
221
+
222
+ if self.include_bgc:
223
+ tracer_names = xr.DataArray(
224
+ [
225
+ "temp",
226
+ "salt",
227
+ "PO4",
228
+ "NO3",
229
+ "SiO3",
230
+ "NH4",
231
+ "Fe",
232
+ "Lig",
233
+ "O2",
234
+ "DIC",
235
+ "DIC_ALT_CO2",
236
+ "ALK",
237
+ "ALK_ALT_CO2",
238
+ "DOC",
239
+ "DON",
240
+ "DOP",
241
+ "DOPr",
242
+ "DONr",
243
+ "DOCr",
244
+ "zooC",
245
+ "spChl",
246
+ "spC",
247
+ "spP",
248
+ "spFe",
249
+ "spCaCO3",
250
+ "diatChl",
251
+ "diatC",
252
+ "diatP",
253
+ "diatFe",
254
+ "diatSi",
255
+ "diazChl",
256
+ "diazC",
257
+ "diazP",
258
+ "diazFe",
259
+ ],
260
+ dims="ntracers",
261
+ )
262
+ else:
263
+ tracer_names = xr.DataArray(["temp", "salt"], dims="ntracers")
214
264
  tracer_names.attrs["long_name"] = "Tracer name"
215
265
  river_tracer.coords["tracer_name"] = tracer_names
216
266
  ds["river_tracer"] = river_tracer
@@ -225,7 +275,7 @@ class RiverForcing:
225
275
 
226
276
  return ds
227
277
 
228
- def move_rivers_to_closest_coast(self, target_coords, data):
278
+ def _move_rivers_to_closest_coast(self, target_coords, data):
229
279
  """Move river mouths to the closest coastal grid cell.
230
280
 
231
281
  This method computes the closest coastal grid point to each river mouth
@@ -295,10 +345,10 @@ class RiverForcing:
295
345
  "xi_rho": indices[2],
296
346
  "name": names,
297
347
  }
298
- self.write_indices_into_grid_file(indices)
348
+ self._write_indices_into_grid_file(indices)
299
349
  object.__setattr__(self, "updated_indices", indices)
300
350
 
301
- def write_indices_into_grid_file(self, indices):
351
+ def _write_indices_into_grid_file(self, indices):
302
352
  """Writes river location indices into the grid dataset as the "river_flux"
303
353
  variable.
304
354
 
@@ -396,16 +446,28 @@ class RiverForcing:
396
446
  )
397
447
 
398
448
  for ax in axs:
399
- _add_plot_to_ax(
449
+ _add_field_to_ax(
400
450
  ax,
401
451
  lon_deg,
402
452
  lat_deg,
403
- trans,
404
453
  field,
405
- c=None,
406
454
  add_colorbar=False,
407
455
  kwargs=kwargs,
408
456
  )
457
+ # Add gridlines with labels for latitude and longitude
458
+ gridlines = ax.gridlines(
459
+ draw_labels=True, linewidth=0.5, color="gray", alpha=0.7, linestyle="--"
460
+ )
461
+ gridlines.top_labels = False # Hide top labels
462
+ gridlines.right_labels = False # Hide right labels
463
+ gridlines.xlabel_style = {
464
+ "size": 10,
465
+ "color": "black",
466
+ } # Customize longitude label style
467
+ gridlines.ylabel_style = {
468
+ "size": 10,
469
+ "color": "black",
470
+ } # Customize latitude label style
409
471
 
410
472
  for ax, indices in zip(axs, [self.original_indices, self.updated_indices]):
411
473
  for i in range(len(indices["name"])):
@@ -445,9 +507,43 @@ class RiverForcing:
445
507
  ----------
446
508
  var_name : str, optional
447
509
  The variable to plot. It can be one of the following:
448
- - 'river_volume' : Plot the river volume flux.
449
- - 'river_temperature' : Plot the river temperature (from the river_tracer).
450
- - 'river_salinity' : Plot the river salinity (from the river_tracer).
510
+
511
+ - 'river_volume' : river volume flux.
512
+ - 'river_temp' : river temperature (from river_tracer).
513
+ - 'river_salt' : river salinity (from river_tracer).
514
+ - 'river_PO4' : river PO4 (from river_tracer).
515
+ - 'river_NO3' : river NO3 (from river_tracer).
516
+ - 'river_SiO3' : river SiO3 (from river_tracer).
517
+ - 'river_NH4' : river NH4 (from river_tracer).
518
+ - 'river_Fe' : river Fe (from river_tracer).
519
+ - 'river_Lig' : river Lig (from river_tracer).
520
+ - 'river_O2' : river O2 (from river_tracer).
521
+ - 'river_DIC' : river DIC (from river_tracer).
522
+ - 'river_DIC_ALT_CO2' : river DIC_ALT_CO2 (from river_tracer).
523
+ - 'river_ALK' : river ALK (from river_tracer).
524
+ - 'river_ALK_ALT_CO2' : river ALK_ALT_CO2 (from river_tracer).
525
+ - 'river_DOC' : river DOC (from river_tracer).
526
+ - 'river_DON' : river DON (from river_tracer).
527
+ - 'river_DOP' : river DOP (from river_tracer).
528
+ - 'river_DOPr' : river DOPr (from river_tracer).
529
+ - 'river_DONr' : river DONr (from river_tracer).
530
+ - 'river_DOCr' : river DOCr (from river_tracer).
531
+ - 'river_zooC' : river zooC (from river_tracer).
532
+ - 'river_spChl' : river sphChl (from river_tracer).
533
+ - 'river_spC' : river spC (from river_tracer).
534
+ - 'river_spP' : river spP (from river_tracer).
535
+ - 'river_spFe' : river spFe (from river_tracer).
536
+ - 'river_spCaCO3' : river spCaCO3 (from river_tracer).
537
+ - 'river_diatChl' : river diatChl (from river_tracer).
538
+ - 'river_diatC' : river diatC (from river_tracer).
539
+ - 'river_diatP' : river diatP (from river_tracer).
540
+ - 'river_diatFe' : river diatFe (from river_tracer).
541
+ - 'river_diatSi' : river diatSi (from river_tracer).
542
+ - 'river_diazChl' : river diazChl (from river_tracer).
543
+ - 'river_diazC' : river diazC (from river_tracer).
544
+ - 'river_diazP' : river diazP (from river_tracer).
545
+ - 'river_diazFe' : river diazFe (from river_tracer).
546
+
451
547
  The default is 'river_volume'.
452
548
  """
453
549
  fig, ax = plt.subplots(1, 1, figsize=(9, 5))
@@ -463,14 +559,14 @@ class RiverForcing:
463
559
  field = self.ds[var_name]
464
560
  units = f"${self.ds.river_volume.units}$"
465
561
  long_name = self.ds[var_name].long_name
466
- elif var_name == "river_temperature":
467
- field = self.ds["river_tracer"].isel(ntracers=0)
468
- units = "degrees C"
469
- long_name = "River temperature"
470
- elif var_name == "river_salinity":
471
- field = self.ds["river_tracer"].isel(ntracers=1)
472
- units = "psu"
473
- long_name = "River salinity"
562
+ else:
563
+ d = get_variable_metadata()
564
+ var_name_wo_river = var_name.split("_")[1]
565
+ field = self.ds["river_tracer"].isel(
566
+ ntracers=self.ds.tracer_name == var_name_wo_river
567
+ )
568
+ units = d[var_name_wo_river]["units"]
569
+ long_name = f"River {d[var_name_wo_river]['long_name']}"
474
570
 
475
571
  for i in range(len(self.ds.nriver)):
476
572
 
@@ -519,10 +615,10 @@ class RiverForcing:
519
615
  ----------
520
616
  filepath : Union[str, Path]
521
617
  The base path and filename for the output files. The filenames will include the specified path and the `.nc` extension.
522
- If partitioning is used, additional indices will be appended to the filenames, e.g., `"filepath_YYYYMM.0.nc"`, `"filepath_YYYYMM.1.nc"`, etc.
618
+ If partitioning is used, additional indices will be appended to the filenames, e.g., `"filepath.0.nc"`, `"filepath.1.nc"`, etc.
523
619
 
524
620
  filepath_grid : Union[str, Path]
525
- The base path and filename for saving the grid file. This file is essential for including the `river_flux` field.
621
+ The base path and filename for saving the grid file.
526
622
 
527
623
  np_eta : int, optional
528
624
  The number of partitions along the `eta` direction. If `None`, no spatial partitioning is performed along the `eta` axis.
roms_tools/setup/surface_forcing.py CHANGED
@@ -66,6 +66,10 @@ class SurfaceForcing:
66
66
  Reference date for the model. Default is January 1, 2000.
67
67
  use_dask: bool, optional
68
68
  Indicates whether to use dask for processing. If True, data is processed with dask; if False, data is processed eagerly. Defaults to False.
69
+ bypass_validation: bool, optional
70
+ Indicates whether to skip validation checks in the processed data. When set to True,
71
+ the validation process that ensures no NaN values exist at wet points
72
+ in the processed dataset is bypassed. Defaults to False.
69
73
 
70
74
  Examples
71
75
  --------
@@ -88,6 +92,7 @@ class SurfaceForcing:
88
92
  use_coarse_grid: bool = False
89
93
  model_reference_date: datetime = datetime(2000, 1, 1)
90
94
  use_dask: bool = False
95
+ bypass_validation: bool = False
91
96
 
92
97
  ds: xr.Dataset = field(init=False, repr=False)
93
98
 
@@ -117,7 +122,7 @@ class SurfaceForcing:
117
122
  data.ds[data.var_names[var_name]]
118
123
  )
119
124
 
120
- # rotation of velocities and interpolation to u/v points
125
+ # rotation of velocities
121
126
  if "uwnd" in self.variable_info and "vwnd" in self.variable_info:
122
127
  processed_fields["uwnd"], processed_fields["vwnd"] = rotate_velocities(
123
128
  processed_fields["uwnd"],
@@ -134,7 +139,8 @@ class SurfaceForcing:
134
139
 
135
140
  ds = self._write_into_dataset(processed_fields, data, d_meta)
136
141
 
137
- self._validate(ds)
142
+ if not self.bypass_validation:
143
+ self._validate(ds)
138
144
 
139
145
  # substitute NaNs over land by a fill value to avoid blow-up of ROMS
140
146
  for var_name in ds.data_vars:
@@ -262,10 +268,8 @@ class SurfaceForcing:
262
268
  correction_data = self._get_correction_data()
263
269
  # choose same subdomain as forcing data so that we can use same mask
264
270
  coords_correction = {
265
- correction_data.dim_names["latitude"]: data.ds[data.dim_names["latitude"]],
266
- correction_data.dim_names["longitude"]: data.ds[
267
- data.dim_names["longitude"]
268
- ],
271
+ "lat": data.ds[data.dim_names["latitude"]],
272
+ "lon": data.ds[data.dim_names["longitude"]],
269
273
  }
270
274
  correction_data.choose_subdomain(
271
275
  coords_correction, straddle=self.target_coords["straddle"]
@@ -541,7 +545,10 @@ class SurfaceForcing:
541
545
 
542
546
  @classmethod
543
547
  def from_yaml(
544
- cls, filepath: Union[str, Path], use_dask: bool = False
548
+ cls,
549
+ filepath: Union[str, Path],
550
+ use_dask: bool = False,
551
+ bypass_validation: bool = False,
545
552
  ) -> "SurfaceForcing":
546
553
  """Create an instance of the SurfaceForcing class from a YAML file.
547
554
 
@@ -551,6 +558,10 @@ class SurfaceForcing:
551
558
  The path to the YAML file from which the parameters will be read.
552
559
  use_dask: bool, optional
553
560
  Indicates whether to use dask for processing. If True, data is processed with dask; if False, data is processed eagerly. Defaults to False.
561
+ bypass_validation: bool, optional
562
+ Indicates whether to skip validation checks in the processed data. When set to True,
563
+ the validation process that ensures no NaN values exist at wet points
564
+ in the processed dataset is bypassed. Defaults to False.
554
565
 
555
566
  Returns
556
567
  -------
@@ -562,4 +573,6 @@ class SurfaceForcing:
562
573
  grid = Grid.from_yaml(filepath)
563
574
  params = _from_yaml(cls, filepath)
564
575
 
565
- return cls(grid=grid, **params, use_dask=use_dask)
576
+ return cls(
577
+ grid=grid, **params, use_dask=use_dask, bypass_validation=bypass_validation
578
+ )