roms-tools 1.7.0__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. roms_tools/_version.py +1 -1
  2. roms_tools/setup/boundary_forcing.py +253 -144
  3. roms_tools/setup/datasets.py +216 -48
  4. roms_tools/setup/download.py +13 -17
  5. roms_tools/setup/grid.py +561 -512
  6. roms_tools/setup/initial_conditions.py +148 -30
  7. roms_tools/setup/mask.py +69 -0
  8. roms_tools/setup/plot.py +4 -8
  9. roms_tools/setup/regrid.py +4 -2
  10. roms_tools/setup/surface_forcing.py +11 -18
  11. roms_tools/setup/tides.py +9 -12
  12. roms_tools/setup/topography.py +92 -128
  13. roms_tools/setup/utils.py +49 -25
  14. roms_tools/setup/vertical_coordinate.py +5 -16
  15. roms_tools/tests/test_setup/test_boundary_forcing.py +10 -5
  16. roms_tools/tests/test_setup/test_data/grid.zarr/.zattrs +0 -1
  17. roms_tools/tests/test_setup/test_data/grid.zarr/.zmetadata +56 -201
  18. roms_tools/tests/test_setup/test_data/grid.zarr/Cs_r/.zattrs +1 -1
  19. roms_tools/tests/test_setup/test_data/grid.zarr/Cs_w/.zattrs +1 -1
  20. roms_tools/tests/test_setup/test_data/grid.zarr/{interface_depth_rho → sigma_r}/.zarray +2 -6
  21. roms_tools/tests/test_setup/test_data/grid.zarr/sigma_r/.zattrs +7 -0
  22. roms_tools/tests/test_setup/test_data/grid.zarr/sigma_r/0 +0 -0
  23. roms_tools/tests/test_setup/test_data/grid.zarr/{interface_depth_u → sigma_w}/.zarray +2 -6
  24. roms_tools/tests/test_setup/test_data/grid.zarr/sigma_w/.zattrs +7 -0
  25. roms_tools/tests/test_setup/test_data/grid.zarr/sigma_w/0 +0 -0
  26. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/.zattrs +1 -2
  27. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/.zmetadata +58 -203
  28. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/Cs_r/.zattrs +1 -1
  29. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/Cs_w/.zattrs +1 -1
  30. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/h/.zattrs +1 -1
  31. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/h/0.0 +0 -0
  32. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/mask_coarse/0.0 +0 -0
  33. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/mask_rho/0.0 +0 -0
  34. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/mask_u/0.0 +0 -0
  35. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/mask_v/0.0 +0 -0
  36. roms_tools/tests/test_setup/test_data/{grid.zarr/interface_depth_v → grid_that_straddles_dateline.zarr/sigma_r}/.zarray +2 -6
  37. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/sigma_r/.zattrs +7 -0
  38. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/sigma_r/0 +0 -0
  39. roms_tools/tests/test_setup/test_data/{grid.zarr/layer_depth_rho → grid_that_straddles_dateline.zarr/sigma_w}/.zarray +2 -6
  40. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/sigma_w/.zattrs +7 -0
  41. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/sigma_w/0 +0 -0
  42. roms_tools/tests/test_setup/test_grid.py +110 -12
  43. roms_tools/tests/test_setup/test_initial_conditions.py +2 -1
  44. roms_tools/tests/test_setup/test_river_forcing.py +3 -2
  45. roms_tools/tests/test_setup/test_surface_forcing.py +2 -22
  46. roms_tools/tests/test_setup/test_tides.py +2 -1
  47. roms_tools/tests/test_setup/test_topography.py +106 -1
  48. {roms_tools-1.7.0.dist-info → roms_tools-2.0.0.dist-info}/LICENSE +1 -1
  49. {roms_tools-1.7.0.dist-info → roms_tools-2.0.0.dist-info}/METADATA +2 -1
  50. {roms_tools-1.7.0.dist-info → roms_tools-2.0.0.dist-info}/RECORD +52 -76
  51. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_rho/.zattrs +0 -9
  52. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_rho/0.0.0 +0 -0
  53. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_u/.zattrs +0 -9
  54. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_u/0.0.0 +0 -0
  55. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_v/.zattrs +0 -9
  56. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_v/0.0.0 +0 -0
  57. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_rho/.zattrs +0 -9
  58. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_rho/0.0.0 +0 -0
  59. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_u/.zarray +0 -24
  60. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_u/.zattrs +0 -9
  61. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_u/0.0.0 +0 -0
  62. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_v/.zarray +0 -24
  63. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_v/.zattrs +0 -9
  64. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_v/0.0.0 +0 -0
  65. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_rho/.zarray +0 -24
  66. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_rho/.zattrs +0 -9
  67. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_rho/0.0.0 +0 -0
  68. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_u/.zarray +0 -24
  69. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_u/.zattrs +0 -9
  70. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_u/0.0.0 +0 -0
  71. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_v/.zarray +0 -24
  72. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_v/.zattrs +0 -9
  73. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_v/0.0.0 +0 -0
  74. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_rho/.zarray +0 -24
  75. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_rho/.zattrs +0 -9
  76. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_rho/0.0.0 +0 -0
  77. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_u/.zarray +0 -24
  78. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_u/.zattrs +0 -9
  79. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_u/0.0.0 +0 -0
  80. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_v/.zarray +0 -24
  81. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_v/.zattrs +0 -9
  82. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_v/0.0.0 +0 -0
  83. roms_tools/tests/test_setup/test_vertical_coordinate.py +0 -91
  84. {roms_tools-1.7.0.dist-info → roms_tools-2.0.0.dist-info}/WHEEL +0 -0
  85. {roms_tools-1.7.0.dist-info → roms_tools-2.0.0.dist-info}/top_level.txt +0 -0
roms_tools/_version.py CHANGED
@@ -1,2 +1,2 @@
1
1
  # Do not change! Do not track in version control!
2
- __version__ = "1.7.0"
2
+ __version__ = "2.0.0"
@@ -9,6 +9,7 @@ from roms_tools.setup.grid import Grid
9
9
  from roms_tools.setup.regrid import LateralRegrid, VerticalRegrid
10
10
  from datetime import datetime
11
11
  from roms_tools.setup.datasets import GLORYSDataset, CESMBGCDataset
12
+ from roms_tools.setup.vertical_coordinate import compute_depth
12
13
  from roms_tools.setup.utils import (
13
14
  get_variable_metadata,
14
15
  group_dataset,
@@ -20,6 +21,8 @@ from roms_tools.setup.utils import (
20
21
  one_dim_fill,
21
22
  nan_check,
22
23
  substitute_nans_by_fillvalue,
24
+ interpolate_from_rho_to_u,
25
+ interpolate_from_rho_to_v,
23
26
  convert_to_roms_time,
24
27
  _to_yaml,
25
28
  _from_yaml,
@@ -115,8 +118,8 @@ class BoundaryForcing:
115
118
  data.extrapolate_deepest_to_bottom()
116
119
  data.apply_lateral_fill()
117
120
 
118
- variable_info = self._set_variable_info(data)
119
- bdry_coords = get_boundary_info()
121
+ self._set_variable_info(data)
122
+ self._set_boundary_info()
120
123
  ds = xr.Dataset()
121
124
 
122
125
  for direction in ["south", "east", "north", "west"]:
@@ -124,10 +127,10 @@ class BoundaryForcing:
124
127
 
125
128
  bdry_target_coords = {
126
129
  "lat": target_coords["lat"].isel(
127
- **bdry_coords["vector"][direction]
130
+ **self.bdry_coords["vector"][direction]
128
131
  ),
129
132
  "lon": target_coords["lon"].isel(
130
- **bdry_coords["vector"][direction]
133
+ **self.bdry_coords["vector"][direction]
131
134
  ),
132
135
  "straddle": target_coords["straddle"],
133
136
  }
@@ -145,11 +148,17 @@ class BoundaryForcing:
145
148
 
146
149
  # lateral regridding of vector fields
147
150
  vector_var_names = [
148
- name for name, info in variable_info.items() if info["is_vector"]
151
+ name
152
+ for name, info in self.variable_info.items()
153
+ if info["is_vector"]
149
154
  ]
150
155
  if len(vector_var_names) > 0:
151
- lon = target_coords["lon"].isel(**bdry_coords["vector"][direction])
152
- lat = target_coords["lat"].isel(**bdry_coords["vector"][direction])
156
+ lon = target_coords["lon"].isel(
157
+ **self.bdry_coords["vector"][direction]
158
+ )
159
+ lat = target_coords["lat"].isel(
160
+ **self.bdry_coords["vector"][direction]
161
+ )
153
162
  lateral_regrid = LateralRegrid(
154
163
  {"lat": lat, "lon": lon}, bdry_data.dim_names
155
164
  )
@@ -162,12 +171,16 @@ class BoundaryForcing:
162
171
  # lateral regridding of tracer fields
163
172
  tracer_var_names = [
164
173
  name
165
- for name, info in variable_info.items()
174
+ for name, info in self.variable_info.items()
166
175
  if not info["is_vector"]
167
176
  ]
168
177
  if len(tracer_var_names) > 0:
169
- lon = target_coords["lon"].isel(**bdry_coords["rho"][direction])
170
- lat = target_coords["lat"].isel(**bdry_coords["rho"][direction])
178
+ lon = target_coords["lon"].isel(
179
+ **self.bdry_coords["rho"][direction]
180
+ )
181
+ lat = target_coords["lat"].isel(
182
+ **self.bdry_coords["rho"][direction]
183
+ )
171
184
  lateral_regrid = LateralRegrid(
172
185
  {"lat": lat, "lon": lon}, bdry_data.dim_names
173
186
  )
@@ -178,9 +191,9 @@ class BoundaryForcing:
178
191
  )
179
192
 
180
193
  # rotation of velocities and interpolation to u/v points
181
- if "u" in variable_info and "v" in variable_info:
194
+ if "u" in self.variable_info and "v" in self.variable_info:
182
195
  angle = target_coords["angle"].isel(
183
- **bdry_coords["vector"][direction]
196
+ **self.bdry_coords["vector"][direction]
184
197
  )
185
198
  (processed_fields["u"], processed_fields["v"],) = rotate_velocities(
186
199
  processed_fields["u"],
@@ -190,54 +203,68 @@ class BoundaryForcing:
190
203
  )
191
204
 
192
205
  # selection of outermost margin for u/v variables
193
- for var_name in variable_info.keys():
206
+ for var_name in self.variable_info.keys():
194
207
  if var_name in processed_fields:
195
- location = variable_info[var_name]["location"]
208
+ location = self.variable_info[var_name]["location"]
196
209
  if location in ["u", "v"]:
197
210
  processed_fields[var_name] = processed_fields[
198
211
  var_name
199
- ].isel(**bdry_coords[location][direction])
212
+ ].isel(**self.bdry_coords[location][direction])
200
213
 
201
214
  if not self.apply_2d_horizontal_fill:
202
215
  self._validate_1d_fill(
203
216
  processed_fields,
204
- variable_info,
205
- bdry_coords,
206
217
  direction,
207
218
  bdry_data.dim_names["depth"],
208
219
  )
209
220
  processed_fields = apply_1d_horizontal_fill(processed_fields)
210
221
 
211
- # vertical regridding
222
+ var_names_dict = {}
212
223
  for location in ["rho", "u", "v"]:
213
- var_names = [
224
+ var_names_dict[location] = [
214
225
  name
215
- for name, info in variable_info.items()
226
+ for name, info in self.variable_info.items()
216
227
  if info["location"] == location and info["is_3d"]
217
228
  ]
218
- if len(var_names) > 0:
229
+ # compute layer depth coordinates
230
+ if len(var_names_dict["u"]) > 0 or len(var_names_dict["v"]) > 0:
231
+ self._get_vertical_coordinates(
232
+ type="layer",
233
+ direction=direction,
234
+ additional_locations=["u", "v"],
235
+ )
236
+ else:
237
+ if len(var_names_dict["rho"]) > 0:
238
+ self._get_vertical_coordinates(
239
+ type="layer", direction=direction, additional_locations=[]
240
+ )
241
+
242
+ # vertical regridding
243
+ for location in ["rho", "u", "v"]:
244
+ if len(var_names_dict[location]) > 0:
219
245
  vertical_regrid = VerticalRegrid(
220
- self.grid.ds[f"layer_depth_{location}"].isel(
221
- **bdry_coords[location][direction]
222
- ),
246
+ self.grid.ds[f"layer_depth_{location}_{direction}"],
223
247
  bdry_data.ds[bdry_data.dim_names["depth"]],
224
248
  )
225
- for var_name in var_names:
249
+ for var_name in var_names_dict[location]:
226
250
  if var_name in processed_fields:
227
251
  processed_fields[var_name] = vertical_regrid.apply(
228
252
  processed_fields[var_name]
229
253
  )
230
254
 
231
255
  # compute barotropic velocities
232
- if "u" in variable_info and "v" in variable_info:
233
- for var_name in ["u", "v"]:
256
+ if "u" in self.variable_info and "v" in self.variable_info:
257
+ self._get_vertical_coordinates(
258
+ type="interface",
259
+ direction=direction,
260
+ additional_locations=["u", "v"],
261
+ )
262
+ for location in ["u", "v"]:
234
263
  processed_fields[
235
- f"{var_name}bar"
264
+ f"{location}bar"
236
265
  ] = compute_barotropic_velocity(
237
- processed_fields[var_name],
238
- self.grid.ds[f"interface_depth_{var_name}"].isel(
239
- **bdry_coords[var_name][direction]
240
- ),
266
+ processed_fields[location],
267
+ self.grid.ds[f"interface_depth_{location}_{direction}"],
241
268
  )
242
269
 
243
270
  # Reorder dimensions
@@ -252,7 +279,7 @@ class BoundaryForcing:
252
279
  # Add global information
253
280
  ds = self._add_global_metadata(data, ds)
254
281
 
255
- self._validate(ds, variable_info, bdry_coords)
282
+ self._validate(ds)
256
283
 
257
284
  # substitute NaNs over land by a fill value to avoid blow-up of ROMS
258
285
  for var_name in ds.data_vars:
@@ -317,11 +344,15 @@ class BoundaryForcing:
317
344
  - `vector_pair`: For vector variables, this indicates the associated variable that forms the vector (e.g., 'u' and 'v').
318
345
  - `is_3d`: Indicates whether the variable is 3D (True for variables like 'temp' and 'salt') or 2D (False for 'zeta').
319
346
 
347
+ Parameters
348
+ ----------
349
+ data : object
350
+ An object that contains variable names for the data being processed. This is used to set variable information for biogeochemical data.
351
+
320
352
  Returns
321
353
  -------
322
- dict
323
- A dictionary where the keys are variable names and the values are dictionaries of metadata
324
- about each variable, including 'location', 'is_vector', 'vector_pair', and 'is_3d'.
354
+ None
355
+ This method updates the instance attribute `variable_info` with the metadata dictionary for the variables.
325
356
  """
326
357
  default_info = {
327
358
  "location": "rho",
@@ -379,7 +410,7 @@ class BoundaryForcing:
379
410
  else:
380
411
  variable_info[var_name] = {**default_info, "validate": False}
381
412
 
382
- return variable_info
413
+ object.__setattr__(self, "variable_info", variable_info)
383
414
 
384
415
  def _write_into_dataset(self, direction, processed_fields, ds=None):
385
416
  if ds is None:
@@ -414,45 +445,178 @@ class BoundaryForcing:
414
445
  "lat_v",
415
446
  "lon_v",
416
447
  ]
417
- existing_vars = [var_name for var_name in variables_to_drop if var_name in ds]
448
+ suffixes = ["", "_south", "_east", "_north", "_west"]
449
+ # Existing variables with suffixes
450
+ existing_vars = []
451
+ for var_name in variables_to_drop:
452
+ for suffix in suffixes:
453
+ full_var_name = f"{var_name}{suffix}"
454
+ if full_var_name in ds:
455
+ existing_vars.append(full_var_name)
456
+
418
457
  ds = ds.drop_vars(existing_vars)
419
458
 
420
459
  return ds
421
460
 
422
- def _get_coordinates(self, direction, point):
461
+ def _set_boundary_info(self):
462
+ """Updates boundary coordinates for rho, u, and v variables on the grid.
463
+
464
+ This method determines the boundary points for the grid variables by specifying the
465
+ indices for the south, east, north, and west boundaries. The resulting boundary
466
+ information is stored in the instance attribute `bdry_coords`.
467
+
468
+ Returns
469
+ -------
470
+ None
471
+ The method does not return a value. Instead, it updates the instance attribute
472
+ `bdry_coords`, which is a dictionary structured as follows:
473
+ - Keys: Variable types ("rho", "u", "v", "vector").
474
+ - Values: Nested dictionaries mapping each direction ("south", "east", "north", "west")
475
+ to their corresponding boundary coordinates. The coordinates are specified in terms of
476
+ grid indices for the respective variable types.
477
+ """
478
+
479
+ bdry_coords = {
480
+ "rho": {
481
+ "south": {"eta_rho": 0},
482
+ "east": {"xi_rho": -1},
483
+ "north": {"eta_rho": -1},
484
+ "west": {"xi_rho": 0},
485
+ },
486
+ "u": {
487
+ "south": {"eta_rho": 0},
488
+ "east": {"xi_u": -1},
489
+ "north": {"eta_rho": -1},
490
+ "west": {"xi_u": 0},
491
+ },
492
+ "v": {
493
+ "south": {"eta_v": 0},
494
+ "east": {"xi_rho": -1},
495
+ "north": {"eta_v": -1},
496
+ "west": {"xi_rho": 0},
497
+ },
498
+ "vector": {
499
+ "south": {"eta_rho": [0, 1]},
500
+ "east": {"xi_rho": [-2, -1]},
501
+ "north": {"eta_rho": [-2, -1]},
502
+ "west": {"xi_rho": [0, 1]},
503
+ },
504
+ }
505
+
506
+ object.__setattr__(self, "bdry_coords", bdry_coords)
507
+
508
+ def _get_vertical_coordinates(
509
+ self, type, direction, additional_locations=["u", "v"]
510
+ ):
423
511
  """Retrieve layer and interface depth coordinates for a specified grid boundary.
424
512
 
425
- This method extracts the layer depth and interface depth coordinates along
426
- a specified boundary (north, south, east, or west) and for a specified point
427
- type (rho, u, or v) from the grid dataset.
513
+ This method computes and updates the layer and interface depth coordinates along a specified
514
+ boundary (north, south, east, or west). It handles depth calculations for rho points and
515
+ additional specified locations (u and v).
428
516
 
429
517
  Parameters
430
518
  ----------
431
- direction : str
432
- The direction of the boundary to retrieve coordinates for. Valid options
433
- are "north", "south", "east", and "west".
434
- point : str
435
- The type of grid point to retrieve coordinates for. Valid options are
436
- "rho" for the grid's central points, "u" for the u-flux points, and "v"
437
- for the v-flux points.
519
+ type : str
520
+ The type of depth coordinate to retrieve. Valid options are:
521
+ - "layer": Retrieves layer depth coordinates.
522
+ - "interface": Retrieves interface depth coordinates.
438
523
 
439
- Returns
524
+ direction : str
525
+ The direction of the boundary to retrieve coordinates for. Valid options are:
526
+ - "north"
527
+ - "south"
528
+ - "east"
529
+ - "west"
530
+
531
+ additional_locations : list of str, optional
532
+ Specifies additional locations to compute depth coordinates for. Default is ["u", "v"].
533
+ Valid options include:
534
+ - "u": Computes depth coordinates for u points.
535
+ - "v": Computes depth coordinates for v points.
536
+
537
+ Updates
440
538
  -------
441
- xarray.DataArray, xarray.DataArray
442
- The layer depth and interface depth coordinates for the specified grid
443
- boundary and point type.
539
+ self.grid.ds : xarray.Dataset
540
+ The dataset is updated with the following vertical depth coordinates:
541
+ - f"{type}_depth_rho_{direction}": Depth coordinates at rho points.
542
+ - f"{type}_depth_u_{direction}": Depth coordinates at u points (if applicable).
543
+ - f"{type}_depth_v_{direction}": Depth coordinates at v points (if applicable).
444
544
  """
445
545
 
446
- bdry_coords = get_boundary_info()
546
+ layer_vars = []
547
+ for location in ["rho"] + additional_locations:
548
+ layer_vars.append(f"{type}_depth_{location}_{direction}")
447
549
 
448
- layer_depth = self.grid.ds[f"layer_depth_{point}"].isel(
449
- **bdry_coords[point][direction]
450
- )
451
- interface_depth = self.grid.ds[f"interface_depth_{point}"].isel(
452
- **bdry_coords[point][direction]
453
- )
550
+ if all(layer_var in self.grid.ds for layer_var in layer_vars):
551
+ # Vertical coordinate data already exists
552
+ pass
454
553
 
455
- return layer_depth, interface_depth
554
+ elif f"{type}_depth_rho" in self.grid.ds:
555
+ depth = self.grid.ds[f"{type}_depth_rho"]
556
+ depth.attrs["long_name"] = f"{type} depth at rho-points"
557
+ depth.attrs["units"] = "m"
558
+ self.grid.ds[f"{type}_depth_rho_{direction}"] = depth.isel(
559
+ **self.bdry_coords["rho"][direction]
560
+ )
561
+
562
+ if "u" in additional_locations or "v" in additional_locations:
563
+ # selection of margin consisting of 2 grid cells
564
+ depth = depth.isel(**self.bdry_coords["vector"][direction])
565
+ # interpolation
566
+ if "u" in additional_locations:
567
+ depth_u = interpolate_from_rho_to_u(depth)
568
+ depth_u.attrs["long_name"] = f"{type} depth at u-points"
569
+ depth_u.attrs["units"] = "m"
570
+ self.grid.ds[f"{type}_depth_u_{direction}"] = depth_u.isel(
571
+ **self.bdry_coords["u"][direction]
572
+ )
573
+ if "v" in additional_locations:
574
+ depth_v = interpolate_from_rho_to_v(depth)
575
+ depth_v.attrs["long_name"] = f"{type} depth at v-points"
576
+ depth_v.attrs["units"] = "m"
577
+ self.grid.ds[f"{type}_depth_v_{direction}"] = depth_v.isel(
578
+ **self.bdry_coords["v"][direction]
579
+ )
580
+ else:
581
+ if "u" in additional_locations or "v" in additional_locations:
582
+ h = self.grid.ds["h"].isel(**self.bdry_coords["vector"][direction])
583
+ else:
584
+ h = self.grid.ds["h"].isel(**self.bdry_coords["rho"][direction])
585
+ if type == "layer":
586
+ depth = compute_depth(
587
+ 0, h, self.grid.hc, self.grid.ds.Cs_r, self.grid.ds.sigma_r
588
+ )
589
+ else:
590
+ depth = compute_depth(
591
+ 0, h, self.grid.hc, self.grid.ds.Cs_w, self.grid.ds.sigma_w
592
+ )
593
+
594
+ if "u" in additional_locations or "v" in additional_locations:
595
+ depth.attrs["long_name"] = f"{type} depth at rho-points"
596
+ depth.attrs["units"] = "m"
597
+ self.grid.ds[f"{type}_depth_rho_{direction}"] = depth.isel(
598
+ **self.bdry_coords["rho"][direction]
599
+ )
600
+ # selection of margin consisting of 2 grid cells
601
+ depth = depth.isel(**self.bdry_coords["vector"][direction])
602
+ # interpolation
603
+ depth_u = interpolate_from_rho_to_u(depth)
604
+ depth_v = interpolate_from_rho_to_v(depth)
605
+ # selection of outermost margin
606
+ depth_u.attrs["long_name"] = f"{type} depth at u-points"
607
+ depth_u.attrs["units"] = "m"
608
+ self.grid.ds[f"{type}_depth_u_{direction}"] = depth_u.isel(
609
+ **self.bdry_coords["u"][direction]
610
+ )
611
+ depth_v.attrs["long_name"] = f"{type} depth at v-points"
612
+ depth_v.attrs["units"] = "m"
613
+ self.grid.ds[f"{type}_depth_v_{direction}"] = depth_v.isel(
614
+ **self.bdry_coords["v"][direction]
615
+ )
616
+ else:
617
+ depth.attrs["long_name"] = f"{type} depth at rho-points"
618
+ depth.attrs["units"] = "m"
619
+ self.grid.ds[f"{type}_depth_rho_{direction}"] = depth
456
620
 
457
621
  def _add_global_metadata(self, data, ds=None):
458
622
 
@@ -485,9 +649,7 @@ class BoundaryForcing:
485
649
 
486
650
  return ds
487
651
 
488
- def _validate_1d_fill(
489
- self, processed_fields, variable_info, bdry_coords, direction, depth_dim
490
- ):
652
+ def _validate_1d_fill(self, processed_fields, direction, depth_dim):
491
653
  """Check if any boundary is divided by land and issue a warning if so,
492
654
  suggesting the use of 2D horizontal fill for safer regridding.
493
655
 
@@ -497,15 +659,6 @@ class BoundaryForcing:
497
659
  A dictionary where keys are variable names and values are `xarray.DataArray`
498
660
  objects representing the processed data for each variable.
499
661
 
500
- variable_info : dict
501
- A dictionary containing metadata about each variable (e.g., location,
502
- whether it's a 3D variable, etc.). Used to retrieve information for
503
- validating each variable.
504
-
505
- bdry_coords : dict
506
- A dictionary containing boundary coordinates for different directions (north, south,
507
- east, west), used to slice the boundary-specific data for each variable.
508
-
509
662
  direction : str
510
663
  The boundary direction being processed (e.g., "north", "south", "east", or "west").
511
664
 
@@ -521,8 +674,8 @@ class BoundaryForcing:
521
674
 
522
675
  for var_name in processed_fields.keys():
523
676
  # Only validate variables based on "validate" flag if use_dask is False
524
- if not self.use_dask or variable_info[var_name]["validate"]:
525
- location = variable_info[var_name]["location"]
677
+ if not self.use_dask or self.variable_info[var_name]["validate"]:
678
+ location = self.variable_info[var_name]["location"]
526
679
 
527
680
  # Select the appropriate mask based on variable location
528
681
  if location == "rho":
@@ -532,9 +685,9 @@ class BoundaryForcing:
532
685
  elif location == "v":
533
686
  mask = self.grid.ds.mask_v
534
687
 
535
- mask = mask.isel(**bdry_coords[location][direction])
688
+ mask = mask.isel(**self.bdry_coords[location][direction])
536
689
 
537
- if variable_info[var_name]["is_3d"]:
690
+ if self.variable_info[var_name]["is_3d"]:
538
691
  da = processed_fields[var_name].isel({depth_dim: 0, "time": 0})
539
692
  else:
540
693
  da = processed_fields[var_name].isel({"time": 0})
@@ -553,7 +706,7 @@ class BoundaryForcing:
553
706
  f"For {var_name}, the {direction}ern boundary is divided by land. It would be safer (but slower) to use `apply_2d_horizontal_fill = True`."
554
707
  )
555
708
 
556
- def _validate(self, ds, variable_info, bdry_coords):
709
+ def _validate(self, ds):
557
710
  """Validate the dataset for NaN values at the first time step (bry_time=0) for
558
711
  specified variables. If NaN values are found at wet points, this function raises
559
712
  an error.
@@ -563,12 +716,6 @@ class BoundaryForcing:
563
716
  ds : xarray.Dataset
564
717
  The dataset to validate.
565
718
 
566
- variable_info : dict
567
- A dictionary containing metadata about the variables, including their locations (e.g., 'rho', 'u', 'v').
568
-
569
- bdry_coords : dict
570
- A dictionary containing the boundary coordinates for each variable location.
571
-
572
719
  Raises
573
720
  ------
574
721
  ValueError
@@ -580,10 +727,10 @@ class BoundaryForcing:
580
727
  Validation is performed on the initial boundary time step (`bry_time=0`) for each
581
728
  variable in the dataset.
582
729
  """
583
- for var_name in variable_info:
730
+ for var_name in self.variable_info:
584
731
  # only validate variables based on "validate" flag if use_dask is false
585
- if not self.use_dask or variable_info[var_name]["validate"]:
586
- location = variable_info[var_name]["location"]
732
+ if not self.use_dask or self.variable_info[var_name]["validate"]:
733
+ location = self.variable_info[var_name]["location"]
587
734
 
588
735
  # Select the appropriate mask based on variable location
589
736
  if location == "rho":
@@ -610,7 +757,7 @@ class BoundaryForcing:
610
757
 
611
758
  nan_check(
612
759
  ds[bdry_var_name].isel(bry_time=0),
613
- mask.isel(**bdry_coords[location][direction]),
760
+ mask.isel(**self.bdry_coords[location][direction]),
614
761
  error_message=error_message,
615
762
  )
616
763
 
@@ -696,20 +843,13 @@ class BoundaryForcing:
696
843
  field = field.load()
697
844
 
698
845
  title = field.long_name
846
+ var_name_wo_direction, direction = var_name.split("_")
847
+ location = self.variable_info[var_name_wo_direction]["location"]
699
848
 
700
849
  if "s_rho" in field.dims:
701
- if var_name.startswith(("u_", "ubar_")):
702
- point = "u"
703
- elif var_name.startswith(("v_", "vbar_")):
704
- point = "v"
705
- else:
706
- point = "rho"
707
- direction = var_name.split("_")[-1]
708
-
709
- layer_depth, interface_depth = self._get_coordinates(direction, point)
710
-
711
- field = field.assign_coords({"layer_depth": layer_depth})
712
-
850
+ field = field.assign_coords(
851
+ {"layer_depth": self.grid.ds[f"layer_depth_{location}_{direction}"]}
852
+ )
713
853
  # chose colorbar
714
854
  if var_name.startswith(("u", "v", "ubar", "vbar", "zeta")):
715
855
  vmax = max(field.max().values, -field.min().values)
@@ -727,6 +867,19 @@ class BoundaryForcing:
727
867
 
728
868
  if len(field.dims) == 2:
729
869
  if layer_contours:
870
+ if location in ["u", "v"]:
871
+ additional_locations = ["u", "v"]
872
+ else:
873
+ additional_locations = []
874
+ self._get_vertical_coordinates(
875
+ type="interface",
876
+ direction=direction,
877
+ additional_locations=additional_locations,
878
+ )
879
+
880
+ interface_depth = self.grid.ds[
881
+ f"interface_depth_{location}_{direction}"
882
+ ]
730
883
  # restrict number of layer_contours to 10 for the sake of plot clearity
731
884
  nr_layers = len(interface_depth["s_w"])
732
885
  selected_layers = np.linspace(
@@ -852,50 +1005,6 @@ class BoundaryForcing:
852
1005
  return cls(grid=grid, **params, use_dask=use_dask)
853
1006
 
854
1007
 
855
- def get_boundary_info():
856
- """This function provides information about the boundary points for the rho, u, and
857
- v variables on the grid, specifying the indices for the south, east, north, and west
858
- boundaries.
859
-
860
- Returns
861
- -------
862
- dict
863
- A dictionary where keys are variable types ("rho", "u", "v"), and values
864
- are nested dictionaries mapping directions ("south", "east", "north", "west")
865
- to the corresponding boundary coordinates.
866
- """
867
-
868
- # Boundary coordinates
869
- bdry_coords = {
870
- "rho": {
871
- "south": {"eta_rho": 0},
872
- "east": {"xi_rho": -1},
873
- "north": {"eta_rho": -1},
874
- "west": {"xi_rho": 0},
875
- },
876
- "u": {
877
- "south": {"eta_rho": 0},
878
- "east": {"xi_u": -1},
879
- "north": {"eta_rho": -1},
880
- "west": {"xi_u": 0},
881
- },
882
- "v": {
883
- "south": {"eta_v": 0},
884
- "east": {"xi_rho": -1},
885
- "north": {"eta_v": -1},
886
- "west": {"xi_rho": 0},
887
- },
888
- "vector": {
889
- "south": {"eta_rho": [0, 1]},
890
- "east": {"xi_rho": [-2, -1]},
891
- "north": {"eta_rho": [-2, -1]},
892
- "west": {"xi_rho": [0, 1]},
893
- },
894
- }
895
-
896
- return bdry_coords
897
-
898
-
899
1008
  def apply_1d_horizontal_fill(processed_fields: dict) -> dict:
900
1009
  """Forward and backward fill NaN values in horizontal direction for open boundaries.
901
1010