roms-tools 1.7.0__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85) hide show
  1. roms_tools/_version.py +1 -1
  2. roms_tools/setup/boundary_forcing.py +253 -144
  3. roms_tools/setup/datasets.py +216 -48
  4. roms_tools/setup/download.py +13 -17
  5. roms_tools/setup/grid.py +561 -512
  6. roms_tools/setup/initial_conditions.py +148 -30
  7. roms_tools/setup/mask.py +69 -0
  8. roms_tools/setup/plot.py +4 -8
  9. roms_tools/setup/regrid.py +4 -2
  10. roms_tools/setup/surface_forcing.py +11 -18
  11. roms_tools/setup/tides.py +9 -12
  12. roms_tools/setup/topography.py +92 -128
  13. roms_tools/setup/utils.py +49 -25
  14. roms_tools/setup/vertical_coordinate.py +5 -16
  15. roms_tools/tests/test_setup/test_boundary_forcing.py +10 -5
  16. roms_tools/tests/test_setup/test_data/grid.zarr/.zattrs +0 -1
  17. roms_tools/tests/test_setup/test_data/grid.zarr/.zmetadata +56 -201
  18. roms_tools/tests/test_setup/test_data/grid.zarr/Cs_r/.zattrs +1 -1
  19. roms_tools/tests/test_setup/test_data/grid.zarr/Cs_w/.zattrs +1 -1
  20. roms_tools/tests/test_setup/test_data/grid.zarr/{interface_depth_rho → sigma_r}/.zarray +2 -6
  21. roms_tools/tests/test_setup/test_data/grid.zarr/sigma_r/.zattrs +7 -0
  22. roms_tools/tests/test_setup/test_data/grid.zarr/sigma_r/0 +0 -0
  23. roms_tools/tests/test_setup/test_data/grid.zarr/{interface_depth_u → sigma_w}/.zarray +2 -6
  24. roms_tools/tests/test_setup/test_data/grid.zarr/sigma_w/.zattrs +7 -0
  25. roms_tools/tests/test_setup/test_data/grid.zarr/sigma_w/0 +0 -0
  26. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/.zattrs +1 -2
  27. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/.zmetadata +58 -203
  28. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/Cs_r/.zattrs +1 -1
  29. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/Cs_w/.zattrs +1 -1
  30. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/h/.zattrs +1 -1
  31. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/h/0.0 +0 -0
  32. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/mask_coarse/0.0 +0 -0
  33. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/mask_rho/0.0 +0 -0
  34. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/mask_u/0.0 +0 -0
  35. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/mask_v/0.0 +0 -0
  36. roms_tools/tests/test_setup/test_data/{grid.zarr/interface_depth_v → grid_that_straddles_dateline.zarr/sigma_r}/.zarray +2 -6
  37. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/sigma_r/.zattrs +7 -0
  38. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/sigma_r/0 +0 -0
  39. roms_tools/tests/test_setup/test_data/{grid.zarr/layer_depth_rho → grid_that_straddles_dateline.zarr/sigma_w}/.zarray +2 -6
  40. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/sigma_w/.zattrs +7 -0
  41. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/sigma_w/0 +0 -0
  42. roms_tools/tests/test_setup/test_grid.py +110 -12
  43. roms_tools/tests/test_setup/test_initial_conditions.py +2 -1
  44. roms_tools/tests/test_setup/test_river_forcing.py +3 -2
  45. roms_tools/tests/test_setup/test_surface_forcing.py +2 -22
  46. roms_tools/tests/test_setup/test_tides.py +2 -1
  47. roms_tools/tests/test_setup/test_topography.py +106 -1
  48. {roms_tools-1.7.0.dist-info → roms_tools-2.0.0.dist-info}/LICENSE +1 -1
  49. {roms_tools-1.7.0.dist-info → roms_tools-2.0.0.dist-info}/METADATA +2 -1
  50. {roms_tools-1.7.0.dist-info → roms_tools-2.0.0.dist-info}/RECORD +52 -76
  51. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_rho/.zattrs +0 -9
  52. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_rho/0.0.0 +0 -0
  53. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_u/.zattrs +0 -9
  54. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_u/0.0.0 +0 -0
  55. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_v/.zattrs +0 -9
  56. roms_tools/tests/test_setup/test_data/grid.zarr/interface_depth_v/0.0.0 +0 -0
  57. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_rho/.zattrs +0 -9
  58. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_rho/0.0.0 +0 -0
  59. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_u/.zarray +0 -24
  60. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_u/.zattrs +0 -9
  61. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_u/0.0.0 +0 -0
  62. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_v/.zarray +0 -24
  63. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_v/.zattrs +0 -9
  64. roms_tools/tests/test_setup/test_data/grid.zarr/layer_depth_v/0.0.0 +0 -0
  65. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_rho/.zarray +0 -24
  66. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_rho/.zattrs +0 -9
  67. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_rho/0.0.0 +0 -0
  68. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_u/.zarray +0 -24
  69. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_u/.zattrs +0 -9
  70. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_u/0.0.0 +0 -0
  71. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_v/.zarray +0 -24
  72. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_v/.zattrs +0 -9
  73. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/interface_depth_v/0.0.0 +0 -0
  74. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_rho/.zarray +0 -24
  75. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_rho/.zattrs +0 -9
  76. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_rho/0.0.0 +0 -0
  77. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_u/.zarray +0 -24
  78. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_u/.zattrs +0 -9
  79. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_u/0.0.0 +0 -0
  80. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_v/.zarray +0 -24
  81. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_v/.zattrs +0 -9
  82. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/layer_depth_v/0.0.0 +0 -0
  83. roms_tools/tests/test_setup/test_vertical_coordinate.py +0 -91
  84. {roms_tools-1.7.0.dist-info → roms_tools-2.0.0.dist-info}/WHEEL +0 -0
  85. {roms_tools-1.7.0.dist-info → roms_tools-2.0.0.dist-info}/top_level.txt +0 -0
@@ -6,6 +6,7 @@ from typing import Dict, Union, List, Optional
6
6
  from roms_tools.setup.grid import Grid
7
7
  from datetime import datetime
8
8
  from roms_tools.setup.datasets import GLORYSDataset, CESMBGCDataset
9
+ from roms_tools.setup.vertical_coordinate import compute_depth
9
10
  from roms_tools.setup.utils import (
10
11
  nan_check,
11
12
  substitute_nans_by_fillvalue,
@@ -15,6 +16,8 @@ from roms_tools.setup.utils import (
15
16
  rotate_velocities,
16
17
  compute_barotropic_velocity,
17
18
  transpose_dimensions,
19
+ interpolate_from_rho_to_u,
20
+ interpolate_from_rho_to_v,
18
21
  _to_yaml,
19
22
  _from_yaml,
20
23
  )
@@ -92,23 +95,17 @@ class InitialConditions:
92
95
  self._input_checks()
93
96
 
94
97
  processed_fields = {}
95
- processed_fields, variable_info = self._process_data(
96
- processed_fields, type="physics"
97
- )
98
+ processed_fields = self._process_data(processed_fields, type="physics")
98
99
 
99
100
  if self.bgc_source is not None:
100
- processed_fields, bgc_variable_info = self._process_data(
101
- processed_fields, type="bgc"
102
- )
101
+ processed_fields = self._process_data(processed_fields, type="bgc")
103
102
 
104
103
  d_meta = get_variable_metadata()
105
104
  ds = self._write_into_dataset(processed_fields, d_meta)
106
105
 
107
106
  ds = self._add_global_metadata(ds)
108
107
 
109
- if self.bgc_source is not None:
110
- variable_info = {**variable_info, **bgc_variable_info}
111
- self._validate(ds, variable_info)
108
+ self._validate(ds)
112
109
 
113
110
  # substitute NaNs over land by a fill value to avoid blow-up of ROMS
114
111
  for var_name in ds.data_vars:
@@ -133,7 +130,9 @@ class InitialConditions:
133
130
  data.extrapolate_deepest_to_bottom()
134
131
  data.apply_lateral_fill()
135
132
 
136
- variable_info = self._set_variable_info(data, type=type)
133
+ self._set_variable_info(data, type=type)
134
+ attr_name = f"variable_info_{type}"
135
+ variable_info = getattr(self, attr_name)
137
136
  var_names = variable_info.keys()
138
137
 
139
138
  # lateral regridding
@@ -153,19 +152,31 @@ class InitialConditions:
153
152
  interpolate=True,
154
153
  )
155
154
 
156
- # vertical regridding
155
+ var_names_dict = {}
157
156
  for location in ["rho", "u", "v"]:
158
- var_names = [
157
+ var_names_dict[location] = [
159
158
  name
160
159
  for name, info in variable_info.items()
161
160
  if info["location"] == location and info["is_3d"]
162
161
  ]
163
- if len(var_names) > 0:
162
+
163
+ # compute layer depth coordinates
164
+ if len(var_names_dict["u"]) > 0 or len(var_names_dict["v"]) > 0:
165
+ self._get_vertical_coordinates(
166
+ type="layer",
167
+ additional_locations=["u", "v"],
168
+ )
169
+ else:
170
+ if len(var_names_dict["rho"]) > 0:
171
+ self._get_vertical_coordinates(type="layer", additional_locations=[])
172
+ # vertical regridding
173
+ for location in ["rho", "u", "v"]:
174
+ if len(var_names_dict[location]) > 0:
164
175
  vertical_regrid = VerticalRegrid(
165
176
  self.grid.ds[f"layer_depth_{location}"],
166
177
  data.ds[data.dim_names["depth"]],
167
178
  )
168
- for var_name in var_names:
179
+ for var_name in var_names_dict[location]:
169
180
  if var_name in processed_fields:
170
181
  processed_fields[var_name] = vertical_regrid.apply(
171
182
  processed_fields[var_name]
@@ -173,10 +184,14 @@ class InitialConditions:
173
184
 
174
185
  # compute barotropic velocities
175
186
  if "u" in variable_info and "v" in variable_info:
176
- for var_name in ["u", "v"]:
177
- processed_fields[f"{var_name}bar"] = compute_barotropic_velocity(
178
- processed_fields[var_name],
179
- self.grid.ds[f"interface_depth_{var_name}"],
187
+ self._get_vertical_coordinates(
188
+ type="interface",
189
+ additional_locations=["u", "v"],
190
+ )
191
+ for location in ["u", "v"]:
192
+ processed_fields[f"{location}bar"] = compute_barotropic_velocity(
193
+ processed_fields[location],
194
+ self.grid.ds[f"interface_depth_{location}"],
180
195
  )
181
196
 
182
197
  if type == "bgc":
@@ -191,7 +206,7 @@ class InitialConditions:
191
206
  processed_fields[var_name]
192
207
  )
193
208
 
194
- return processed_fields, variable_info
209
+ return processed_fields
195
210
 
196
211
  def _input_checks(self):
197
212
 
@@ -345,7 +360,84 @@ class InitialConditions:
345
360
  else:
346
361
  variable_info[var_name] = {**default_info, "validate": False}
347
362
 
348
- return variable_info
363
+ object.__setattr__(self, f"variable_info_{type}", variable_info)
364
+
365
+ def _get_vertical_coordinates(self, type, additional_locations=["u", "v"]):
366
+ """Retrieve layer and interface depth coordinates.
367
+
368
+ This method computes and updates the layer and interface depth coordinates. It handles depth calculations for rho points and
369
+ additional specified locations (u and v).
370
+
371
+ Parameters
372
+ ----------
373
+ type : str
374
+ The type of depth coordinate to retrieve. Valid options are:
375
+ - "layer": Retrieves layer depth coordinates.
376
+ - "interface": Retrieves interface depth coordinates.
377
+
378
+ additional_locations : list of str, optional
379
+ Specifies additional locations to compute depth coordinates for. Default is ["u", "v"].
380
+ Valid options include:
381
+ - "u": Computes depth coordinates for u points.
382
+ - "v": Computes depth coordinates for v points.
383
+
384
+ Updates
385
+ -------
386
+ self.grid.ds : xarray.Dataset
387
+ The dataset is updated with the following vertical depth coordinates:
388
+ - f"{type}_depth_rho": Depth coordinates at rho points.
389
+ - f"{type}_depth_u": Depth coordinates at u points (if applicable).
390
+ - f"{type}_depth_v": Depth coordinates at v points (if applicable).
391
+ """
392
+
393
+ layer_vars = []
394
+ for location in ["rho"] + additional_locations:
395
+ layer_vars.append(f"{type}_depth_{location}")
396
+
397
+ if all(layer_var in self.grid.ds for layer_var in layer_vars):
398
+ # Vertical coordinate data already exists
399
+ pass
400
+
401
+ elif f"{type}_depth_rho" in self.grid.ds:
402
+ depth = self.grid.ds[f"{type}_depth_rho"]
403
+
404
+ if "u" in additional_locations or "v" in additional_locations:
405
+ # interpolation
406
+ if "u" in additional_locations:
407
+ depth_u = interpolate_from_rho_to_u(depth)
408
+ depth_u.attrs["long_name"] = f"{type} depth at u-points"
409
+ depth_u.attrs["units"] = "m"
410
+ self.grid.ds[f"{type}_depth_u"] = depth_u
411
+ if "v" in additional_locations:
412
+ depth_v = interpolate_from_rho_to_v(depth)
413
+ depth_v.attrs["long_name"] = f"{type} depth at v-points"
414
+ depth_v.attrs["units"] = "m"
415
+ self.grid.ds[f"{type}_depth_v"] = depth_v
416
+ else:
417
+ h = self.grid.ds["h"]
418
+ if type == "layer":
419
+ depth = compute_depth(
420
+ 0, h, self.grid.hc, self.grid.ds.Cs_r, self.grid.ds.sigma_r
421
+ )
422
+ else:
423
+ depth = compute_depth(
424
+ 0, h, self.grid.hc, self.grid.ds.Cs_w, self.grid.ds.sigma_w
425
+ )
426
+
427
+ depth.attrs["long_name"] = f"{type} depth at rho-points"
428
+ depth.attrs["units"] = "m"
429
+ self.grid.ds[f"{type}_depth_rho"] = depth
430
+
431
+ if "u" in additional_locations or "v" in additional_locations:
432
+ # interpolation
433
+ depth_u = interpolate_from_rho_to_u(depth)
434
+ depth_u.attrs["long_name"] = f"{type} depth at u-points"
435
+ depth_u.attrs["units"] = "m"
436
+ depth_v = interpolate_from_rho_to_v(depth)
437
+ depth_v.attrs["long_name"] = f"{type} depth at v-points"
438
+ depth_v.attrs["units"] = "m"
439
+ self.grid.ds[f"{type}_depth_u"] = depth_u
440
+ self.grid.ds[f"{type}_depth_v"] = depth_v
349
441
 
350
442
  def _write_into_dataset(self, processed_fields, d_meta):
351
443
 
@@ -410,7 +502,7 @@ class InitialConditions:
410
502
 
411
503
  return ds
412
504
 
413
- def _validate(self, ds, variable_info):
505
+ def _validate(self, ds):
414
506
  """Validates the dataset by checking for NaN values in SSH at wet points, which
415
507
  would indicate missing raw data coverage over the target domain.
416
508
 
@@ -418,8 +510,6 @@ class InitialConditions:
418
510
  ----------
419
511
  ds : xarray.Dataset
420
512
  The dataset to validate.
421
- variable_info : dict
422
- A dictionary containing metadata about the variables, including whether to validate them.
423
513
 
424
514
  Raises
425
515
  ------
@@ -431,6 +521,10 @@ class InitialConditions:
431
521
  -----
432
522
  This check is only applied to the 2D variable SSH to improve performance.
433
523
  """
524
+ if self.bgc_source is not None:
525
+ variable_info = {**self.variable_info_physics, **self.variable_info_bgc}
526
+ else:
527
+ variable_info = self.variable_info_physics
434
528
 
435
529
  for var_name in variable_info:
436
530
  # Only validate variables based on "validate" flag if use_dask is False
@@ -570,9 +664,17 @@ class InitialConditions:
570
664
  self.ds[var_name].load()
571
665
 
572
666
  field = self.ds[var_name].squeeze()
667
+ if s is not None:
668
+ layer_contours = False
573
669
 
574
670
  if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
575
- interface_depth = self.grid.ds.interface_depth_rho
671
+ if layer_contours:
672
+ if "interface_depth_rho" in self.grid.ds:
673
+ interface_depth = self.grid.ds.interface_depth_rho
674
+ else:
675
+ self.get_vertical_coordinates(
676
+ type="interface", additional_locations=[]
677
+ )
576
678
  layer_depth = self.grid.ds.layer_depth_rho
577
679
  mask = self.grid.ds.mask_rho
578
680
  field = field.assign_coords(
@@ -580,7 +682,13 @@ class InitialConditions:
580
682
  )
581
683
 
582
684
  elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
583
- interface_depth = self.grid.ds.interface_depth_u
685
+ if layer_contours:
686
+ if "interface_depth_u" in self.grid.ds:
687
+ interface_depth = self.grid.ds.interface_depth_u
688
+ else:
689
+ self.get_vertical_coordinates(
690
+ type="interface", additional_locations=["u", "v"]
691
+ )
584
692
  layer_depth = self.grid.ds.layer_depth_u
585
693
  mask = self.grid.ds.mask_u
586
694
  field = field.assign_coords(
@@ -588,7 +696,13 @@ class InitialConditions:
588
696
  )
589
697
 
590
698
  elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
591
- interface_depth = self.grid.ds.interface_depth_v
699
+ if layer_contours:
700
+ if "interface_depth_v" in self.grid.ds:
701
+ interface_depth = self.grid.ds.interface_depth_v
702
+ else:
703
+ self.get_vertical_coordinates(
704
+ type="interface", additional_locations=["u", "v"]
705
+ )
592
706
  layer_depth = self.grid.ds.layer_depth_v
593
707
  mask = self.grid.ds.mask_v
594
708
  field = field.assign_coords(
@@ -612,14 +726,16 @@ class InitialConditions:
612
726
  title = title + f", eta_rho = {field.eta_rho[eta].item()}"
613
727
  field = field.isel(eta_rho=eta)
614
728
  layer_depth = layer_depth.isel(eta_rho=eta)
615
- interface_depth = interface_depth.isel(eta_rho=eta)
729
+ if layer_contours:
730
+ interface_depth = interface_depth.isel(eta_rho=eta)
616
731
  if "s_rho" in field.dims:
617
732
  field = field.assign_coords({"layer_depth": layer_depth})
618
733
  elif "eta_v" in field.dims:
619
734
  title = title + f", eta_v = {field.eta_v[eta].item()}"
620
735
  field = field.isel(eta_v=eta)
621
736
  layer_depth = layer_depth.isel(eta_v=eta)
622
- interface_depth = interface_depth.isel(eta_v=eta)
737
+ if layer_contours:
738
+ interface_depth = interface_depth.isel(eta_v=eta)
623
739
  if "s_rho" in field.dims:
624
740
  field = field.assign_coords({"layer_depth": layer_depth})
625
741
  else:
@@ -631,14 +747,16 @@ class InitialConditions:
631
747
  title = title + f", xi_rho = {field.xi_rho[xi].item()}"
632
748
  field = field.isel(xi_rho=xi)
633
749
  layer_depth = layer_depth.isel(xi_rho=xi)
634
- interface_depth = interface_depth.isel(xi_rho=xi)
750
+ if layer_contours:
751
+ interface_depth = interface_depth.isel(xi_rho=xi)
635
752
  if "s_rho" in field.dims:
636
753
  field = field.assign_coords({"layer_depth": layer_depth})
637
754
  elif "xi_u" in field.dims:
638
755
  title = title + f", xi_u = {field.xi_u[xi].item()}"
639
756
  field = field.isel(xi_u=xi)
640
757
  layer_depth = layer_depth.isel(xi_u=xi)
641
- interface_depth = interface_depth.isel(xi_u=xi)
758
+ if layer_contours:
759
+ interface_depth = interface_depth.isel(xi_u=xi)
642
760
  if "s_rho" in field.dims:
643
761
  field = field.assign_coords({"layer_depth": layer_depth})
644
762
  else:
@@ -0,0 +1,69 @@
1
+ import xarray as xr
2
+ import numpy as np
3
+ import regionmask
4
+ from scipy.ndimage import label
5
+ from roms_tools.setup.utils import (
6
+ interpolate_from_rho_to_u,
7
+ interpolate_from_rho_to_v,
8
+ handle_boundaries,
9
+ )
10
+
11
+
12
+ def _add_mask(ds):
13
+
14
+ land = regionmask.defined_regions.natural_earth_v5_0_0.land_10
15
+
16
+ land_mask = land.mask(ds["lon_rho"], ds["lat_rho"])
17
+ mask = land_mask.isnull()
18
+
19
+ # fill enclosed basins with land
20
+ mask = _fill_enclosed_basins(mask.values)
21
+ # adjust mask boundaries by copying values from adjacent cells
22
+ mask = handle_boundaries(mask)
23
+
24
+ ds["mask_rho"] = xr.DataArray(mask.astype(np.int32), dims=("eta_rho", "xi_rho"))
25
+ ds["mask_rho"].attrs = {
26
+ "long_name": "Mask at rho-points",
27
+ "units": "land/water (0/1)",
28
+ }
29
+ ds = _add_velocity_masks(ds)
30
+
31
+ return ds
32
+
33
+
34
+ def _fill_enclosed_basins(mask) -> np.ndarray:
35
+ """Fills in enclosed basins with land."""
36
+
37
+ # Label connected regions in the mask
38
+ reg, nreg = label(mask)
39
+ # Find the largest region
40
+ lint = 0
41
+ lreg = 0
42
+ for ireg in range(nreg):
43
+ int_ = np.sum(reg == ireg)
44
+ if int_ > lint and mask[reg == ireg].sum() > 0:
45
+ lreg = ireg
46
+ lint = int_
47
+
48
+ # Remove regions other than the largest one
49
+ for ireg in range(nreg):
50
+ if ireg != lreg:
51
+ mask[reg == ireg] = 0
52
+
53
+ return mask
54
+
55
+
56
+ def _add_velocity_masks(ds):
57
+
58
+ # add u- and v-masks
59
+ ds["mask_u"] = interpolate_from_rho_to_u(
60
+ ds["mask_rho"], method="multiplicative"
61
+ ).astype(np.int32)
62
+ ds["mask_v"] = interpolate_from_rho_to_v(
63
+ ds["mask_rho"], method="multiplicative"
64
+ ).astype(np.int32)
65
+
66
+ ds["mask_u"].attrs = {"long_name": "Mask at u-points", "units": "land/water (0/1)"}
67
+ ds["mask_v"].attrs = {"long_name": "Mask at v-points", "units": "land/water (0/1)"}
68
+
69
+ return ds
roms_tools/setup/plot.py CHANGED
@@ -238,18 +238,14 @@ def _profile_plot(field, title="", ax=None):
238
238
  """
239
239
 
240
240
  depths_to_check = [
241
- "layer_depth_rho",
242
- "layer_depth_u",
243
- "layer_depth_v",
244
- "interface_depth_rho",
245
- "interface_depth_u",
246
- "interface_depth_v",
241
+ "layer_depth",
242
+ "interface_depth",
247
243
  ]
248
244
  try:
249
245
  depth_label = next(
250
246
  depth_label
251
- for depth_label in depths_to_check
252
- if depth_label in field.coords
247
+ for depth_label in field.coords
248
+ if any(depth_label.startswith(prefix) for prefix in depths_to_check)
253
249
  )
254
250
  except StopIteration:
255
251
  raise ValueError(
@@ -28,20 +28,22 @@ class LateralRegrid:
28
28
  source_dim_names["longitude"]: target_coords["lon"],
29
29
  }
30
30
 
31
- def apply(self, da):
31
+ def apply(self, da, method="linear"):
32
32
  """Fills missing values and regrids the variable.
33
33
 
34
34
  Parameters
35
35
  ----------
36
36
  da : xarray.DataArray
37
37
  Input data to fill and regrid.
38
+ method : str
39
+ Interpolation method to use.
38
40
 
39
41
  Returns
40
42
  -------
41
43
  xarray.DataArray
42
44
  Regridded data with filled values.
43
45
  """
44
- regridded = da.interp(self.coords, method="linear").drop_vars(
46
+ regridded = da.interp(self.coords, method=method).drop_vars(
45
47
  list(self.coords.keys())
46
48
  )
47
49
  return regridded
@@ -12,13 +12,13 @@ from roms_tools.setup.datasets import (
12
12
  CESMBGCSurfaceForcingDataset,
13
13
  )
14
14
  from roms_tools.setup.utils import (
15
+ get_target_coords,
15
16
  nan_check,
16
17
  substitute_nans_by_fillvalue,
17
18
  interpolate_from_climatology,
18
19
  get_variable_metadata,
19
20
  group_dataset,
20
21
  save_datasets,
21
- get_target_coords,
22
22
  rotate_velocities,
23
23
  convert_to_roms_time,
24
24
  _to_yaml,
@@ -105,8 +105,8 @@ class SurfaceForcing:
105
105
 
106
106
  data.apply_lateral_fill()
107
107
 
108
- variable_info = self._set_variable_info(data)
109
- var_names = variable_info.keys()
108
+ self._set_variable_info(data)
109
+ var_names = self.variable_info.keys()
110
110
 
111
111
  processed_fields = {}
112
112
  # lateral regridding
@@ -118,7 +118,7 @@ class SurfaceForcing:
118
118
  )
119
119
 
120
120
  # rotation of velocities and interpolation to u/v points
121
- if "uwnd" in variable_info and "vwnd" in variable_info:
121
+ if "uwnd" in self.variable_info and "vwnd" in self.variable_info:
122
122
  processed_fields["uwnd"], processed_fields["vwnd"] = rotate_velocities(
123
123
  processed_fields["uwnd"],
124
124
  processed_fields["vwnd"],
@@ -134,7 +134,7 @@ class SurfaceForcing:
134
134
 
135
135
  ds = self._write_into_dataset(processed_fields, data, d_meta)
136
136
 
137
- self._validate(ds, target_coords["mask"], variable_info)
137
+ self._validate(ds)
138
138
 
139
139
  # substitute NaNs over land by a fill value to avoid blow-up of ROMS
140
140
  for var_name in ds.data_vars:
@@ -213,9 +213,8 @@ class SurfaceForcing:
213
213
 
214
214
  Returns
215
215
  -------
216
- dict
217
- A dictionary where the keys are variable names and the values are dictionaries of metadata
218
- about each variable, including 'location', 'is_vector', 'vector_pair', and 'is_3d'.
216
+ None
217
+ This method updates the instance attribute `variable_info` with the metadata dictionary for the variables.
219
218
  """
220
219
  default_info = {
221
220
  "location": "rho",
@@ -256,7 +255,7 @@ class SurfaceForcing:
256
255
  else:
257
256
  variable_info[var_name] = {**default_info, "validate": False}
258
257
 
259
- return variable_info
258
+ object.__setattr__(self, "variable_info", variable_info)
260
259
 
261
260
  def _apply_correction(self, processed_fields, data):
262
261
 
@@ -329,7 +328,7 @@ class SurfaceForcing:
329
328
 
330
329
  return ds
331
330
 
332
- def _validate(self, ds, mask, variable_info):
331
+ def _validate(self, ds):
333
332
  """Validates the dataset by checking for NaN values at wet points, which would
334
333
  indicate missing raw data coverage over the target domain.
335
334
 
@@ -337,12 +336,6 @@ class SurfaceForcing:
337
336
  ----------
338
337
  ds : xarray.Dataset
339
338
  The dataset to validate.
340
- mask : xarray.DataArray
341
- Land mask (1=ocean, 0=land) to determine wet points in the domain.
342
- variable_info : dict
343
- A dictionary containing metadata about each variable (e.g., location,
344
- whether it's a 3D variable, etc.). Used to retrieve information for
345
- validating each variable.
346
339
 
347
340
  Raises
348
341
  ------
@@ -357,8 +350,8 @@ class SurfaceForcing:
357
350
 
358
351
  for var_name in ds.data_vars:
359
352
  # Only validate variables based on "validate" flag if use_dask is False
360
- if not self.use_dask or variable_info[var_name]["validate"]:
361
- nan_check(ds[var_name].isel(time=0), mask)
353
+ if not self.use_dask or self.variable_info[var_name]["validate"]:
354
+ nan_check(ds[var_name].isel(time=0), self.target_coords["mask"])
362
355
 
363
356
  def _add_global_metadata(self, ds=None):
364
357
 
roms_tools/setup/tides.py CHANGED
@@ -85,8 +85,8 @@ class TidalForcing:
85
85
 
86
86
  data.apply_lateral_fill()
87
87
 
88
- variable_info = self._set_variable_info()
89
- var_names = variable_info.keys()
88
+ self._set_variable_info()
89
+ var_names = self.variable_info.keys()
90
90
 
91
91
  processed_fields = {}
92
92
  # lateral regridding
@@ -98,7 +98,7 @@ class TidalForcing:
98
98
  )
99
99
 
100
100
  # rotation of velocities and interpolation to u/v points
101
- vector_pairs = get_vector_pairs(variable_info)
101
+ vector_pairs = get_vector_pairs(self.variable_info)
102
102
  for pair in vector_pairs:
103
103
  u_component = pair[0]
104
104
  v_component = pair[1]
@@ -129,7 +129,7 @@ class TidalForcing:
129
129
 
130
130
  ds = self._add_global_metadata(ds)
131
131
 
132
- self._validate(ds, variable_info)
132
+ self._validate(ds)
133
133
 
134
134
  # substitute NaNs over land by a fill value to avoid blow-up of ROMS
135
135
  for var_name in ds.data_vars:
@@ -163,9 +163,8 @@ class TidalForcing:
163
163
 
164
164
  Returns
165
165
  -------
166
- dict
167
- A dictionary where the keys are variable names and the values are dictionaries of metadata
168
- about each variable, including 'location', 'is_vector', 'vector_pair', and 'is_3d'.
166
+ None
167
+ This method updates the instance attribute `variable_info` with the metadata dictionary for the variables.
169
168
  """
170
169
  default_info = {
171
170
  "location": "rho",
@@ -210,7 +209,7 @@ class TidalForcing:
210
209
  },
211
210
  }
212
211
 
213
- return variable_info
212
+ object.__setattr__(self, "variable_info", variable_info)
214
213
 
215
214
  def _write_into_dataset(self, processed_fields, d_meta):
216
215
 
@@ -243,7 +242,7 @@ class TidalForcing:
243
242
 
244
243
  return ds
245
244
 
246
- def _validate(self, ds, variable_info):
245
+ def _validate(self, ds):
247
246
  """Validates the dataset by checking for NaN values at wet points for specified
248
247
  variables, which would indicate missing raw data coverage over the target
249
248
  domain.
@@ -252,8 +251,6 @@ class TidalForcing:
252
251
  ----------
253
252
  ds : xarray.Dataset
254
253
  The dataset to validate, containing tidal variables and a mask for wet points.
255
- variable_info : dict
256
- A dictionary containing metadata about the variables, including whether to validate them.
257
254
 
258
255
  Raises
259
256
  ------
@@ -268,7 +265,7 @@ class TidalForcing:
268
265
  """
269
266
  for var_name in ds.data_vars:
270
267
  # only validate variables based on "validate" flag if use_dask is false
271
- if not self.use_dask or variable_info[var_name]["validate"]:
268
+ if not self.use_dask or self.variable_info[var_name]["validate"]:
272
269
  nan_check(ds[var_name].isel(ntides=0), self.grid.ds.mask_rho)
273
270
 
274
271
  def plot(self, var_name, ntides=0) -> None: