roms-tools 2.6.2__py3-none-any.whl → 2.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. roms_tools/__init__.py +1 -0
  2. roms_tools/analysis/roms_output.py +11 -77
  3. roms_tools/analysis/utils.py +0 -66
  4. roms_tools/constants.py +2 -0
  5. roms_tools/download.py +46 -3
  6. roms_tools/plot.py +22 -5
  7. roms_tools/setup/cdr_forcing.py +1126 -0
  8. roms_tools/setup/datasets.py +742 -87
  9. roms_tools/setup/grid.py +42 -4
  10. roms_tools/setup/river_forcing.py +11 -84
  11. roms_tools/setup/tides.py +81 -411
  12. roms_tools/setup/utils.py +241 -37
  13. roms_tools/tests/test_setup/test_cdr_forcing.py +772 -0
  14. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/.zmetadata +53 -1
  15. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/river_tracer/.zattrs +1 -1
  16. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_long_name/.zarray +20 -0
  17. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_long_name/.zattrs +6 -0
  18. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_long_name/0 +0 -0
  19. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_unit/.zarray +20 -0
  20. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_unit/.zattrs +6 -0
  21. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_unit/0 +0 -0
  22. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/.zmetadata +53 -1
  23. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/river_tracer/.zattrs +1 -1
  24. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/tracer_long_name/.zarray +20 -0
  25. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/tracer_long_name/.zattrs +6 -0
  26. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/tracer_long_name/0 +0 -0
  27. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/tracer_unit/.zarray +20 -0
  28. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/tracer_unit/.zattrs +6 -0
  29. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/tracer_unit/0 +0 -0
  30. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/.zattrs +1 -2
  31. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/.zmetadata +27 -5
  32. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/ntides/.zarray +20 -0
  33. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/ntides/.zattrs +5 -0
  34. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/ntides/0 +0 -0
  35. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/omega/.zattrs +1 -3
  36. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/pot_Im/0.0.0 +0 -0
  37. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/pot_Re/0.0.0 +0 -0
  38. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/ssh_Im/0.0.0 +0 -0
  39. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/ssh_Re/0.0.0 +0 -0
  40. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/u_Im/0.0.0 +0 -0
  41. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/u_Re/0.0.0 +0 -0
  42. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/v_Im/0.0.0 +0 -0
  43. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/v_Re/0.0.0 +0 -0
  44. roms_tools/tests/test_setup/test_datasets.py +103 -1
  45. roms_tools/tests/test_setup/test_tides.py +112 -47
  46. roms_tools/utils.py +115 -1
  47. {roms_tools-2.6.2.dist-info → roms_tools-2.7.0.dist-info}/METADATA +1 -1
  48. {roms_tools-2.6.2.dist-info → roms_tools-2.7.0.dist-info}/RECORD +51 -33
  49. {roms_tools-2.6.2.dist-info → roms_tools-2.7.0.dist-info}/WHEEL +1 -1
  50. {roms_tools-2.6.2.dist-info → roms_tools-2.7.0.dist-info}/licenses/LICENSE +0 -0
  51. {roms_tools-2.6.2.dist-info → roms_tools-2.7.0.dist-info}/top_level.txt +0 -0
@@ -6,6 +6,7 @@ import numpy as np
 from typing import Dict, Optional, Union, List
 from pathlib import Path
 import logging
+from roms_tools.constants import R_EARTH
 from roms_tools.utils import _load_data
 from roms_tools.setup.utils import (
     assign_dates_to_climatology,
@@ -20,6 +21,7 @@ from roms_tools.download import (
     download_correction_data,
     download_topo,
     download_river_data,
+    download_sal_data,
 )
 from roms_tools.setup.fill import LateralFill
 
@@ -168,7 +170,7 @@ class Dataset:
 
         return ds
 
-    def clean_up(self, ds: xr.Dataset) -> xr.Dataset:
+    def clean_up(self, ds: xr.Dataset, **kwargs) -> xr.Dataset:
         """Dummy method to be overridden by child classes to clean up the dataset.
 
         This method is intended as a placeholder and should be implemented in subclasses
@@ -182,9 +184,9 @@
         Returns
         -------
         xr.Dataset
-            The xarray Dataset cleaned up (as implemented by child classes).
+            The cleaned-up xarray Dataset (as implemented by child classes).
         """
-        return ds
+        return ds  # Default behavior (no-op, subclasses should override)
 
     def check_dataset(self, ds: xr.Dataset) -> None:
         """Check if the dataset contains the specified variables and dimensions.
@@ -221,6 +223,7 @@
             if (
                 var not in self.var_names.values()
                 and var not in self.opt_var_names.values()
+                and var != "mask"
             ):
                 ds = ds.drop_vars(var)
 
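The new `var != "mask"` guard keeps the grid mask (added by `TPXODataset.clean_up` below) from being dropped along with unrecognized variables. A toy illustration of the resulting filter logic, with hypothetical variable names:

```python
var_names = {"ssh_Re": "hRe"}
opt_var_names = {}
ds_vars = ["hRe", "scratch_var", "mask"]

kept = [
    v for v in ds_vars
    if v in var_names.values() or v in opt_var_names.values() or v == "mask"
]
assert kept == ["hRe", "mask"]  # "scratch_var" is dropped, "mask" survives
```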
@@ -385,7 +388,6 @@
         and longitude differences, in meters.
         """
 
-        r_earth = 6371315.0
         lat_dim = self.dim_names["latitude"]
         lon_dim = self.dim_names["longitude"]
 
@@ -398,11 +400,11 @@
         lon_diff = np.abs(np.diff(longitudes)).min()  # Minimal longitude spacing
 
         # Latitude spacing is constant at all longitudes
-        min_lat_spacing = (2 * np.pi * r_earth * lat_diff) / 360
+        min_lat_spacing = (2 * np.pi * R_EARTH * lat_diff) / 360
 
         # Longitude spacing varies with latitude
         min_lon_spacing = (
-            2 * np.pi * r_earth * lon_diff * np.cos(np.radians(latitudes.min()))
+            2 * np.pi * R_EARTH * lon_diff * np.cos(np.radians(latitudes.min()))
         ) / 360
 
         # The minimal spacing is the smaller of the two
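The hard-coded `r_earth` is replaced by the shared `R_EARTH` constant from `roms_tools.constants` (one of the two lines added to that module). The arc-length formula is easy to sanity-check in isolation; a standalone sketch assuming `R_EARTH = 6371315.0` m (the value of the removed local constant) and a hypothetical 0.25-degree grid:

```python
import numpy as np

R_EARTH = 6371315.0  # meters; assumed to match roms_tools.constants
lat_diff = lon_diff = 0.25  # degrees, hypothetical grid spacing
lat = 60.0  # latitude at which to evaluate the longitude spacing

min_lat_spacing = (2 * np.pi * R_EARTH * lat_diff) / 360
min_lon_spacing = (2 * np.pi * R_EARTH * lon_diff * np.cos(np.radians(lat))) / 360

print(round(min_lat_spacing), round(min_lon_spacing))  # ~27800 m and ~13900 m
```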
@@ -698,6 +700,10 @@
 
         Notes
         -----
+        This method assumes that the variables in the dataset use a dimension
+        ordering where latitude comes before longitude, i.e., ('latitude', 'longitude').
+        Ensure that this convention is followed to avoid unexpected behavior.
+
         Looping over `self.ds.data_vars` instead of `self.var_names` ensures that each
         dataset variable is filled only once, even if multiple entries in `self.var_names`
         point to the same variable in the dataset.
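The ('latitude', 'longitude') ordering assumption from the new note can be enforced cheaply before filling. A small standalone sketch, assuming any `xarray.Dataset` that carries both dimensions:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"t": (("time", "latitude", "longitude"), np.zeros((2, 3, 4)))})
# Put latitude before longitude; Ellipsis leaves any remaining dims in place.
ds = ds.transpose("latitude", "longitude", ...)
assert ds["t"].dims == ("latitude", "longitude", "time")
```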
@@ -801,119 +807,223 @@ class TPXODataset(Dataset):
 
     Parameters
     ----------
-    reference_date : datetime, optional
-        The reference date for the TPXO data. Default is datetime(1992, 1, 1).
+    filename : str
+        The path to the TPXO dataset file.
+    grid_filename : str
+        The path to the TPXO grid file.
+    location : str
+        One of "h", "u", or "v", indicating which TPXO grid location the variables are defined on.
+    var_names : Dict[str, str]
+        Dictionary of variable names required in the dataset.
+    dim_names : Dict[str, str], optional
+        Dictionary specifying the names of dimensions in the dataset. Defaults to:
+        {"longitude": "nx", "latitude": "ny", "ntides": "nc"}.
+    tolerate_coord_mismatch : bool, optional
+        If True, allows mismatched latitude/longitude coordinates between the TPXO dataset and the TPXO grid
+        by selecting the nearest grid-aligned coordinates. Default is False.
+
+    Attributes
+    ----------
+    ds : xr.Dataset
+        The xarray Dataset containing the TPXO tidal model data, loaded from the specified file.
     """
 
-    var_names: Dict[str, str] = field(
-        default_factory=lambda: {
-            "ssh_Re": "h_Re",
-            "ssh_Im": "h_Im",
-            "sal_Re": "sal_Re",
-            "sal_Im": "sal_Im",
-            "u_Re": "u_Re",
-            "u_Im": "u_Im",
-            "v_Re": "v_Re",
-            "v_Im": "v_Im",
-            "depth": "depth",
-        }
-    )
+    filename: str
+    grid_filename: str
+    location: str
+    var_names: Dict[str, str]
     dim_names: Dict[str, str] = field(
-        default_factory=lambda: {"longitude": "ny", "latitude": "nx", "ntides": "nc"}
+        default_factory=lambda: {"longitude": "nx", "latitude": "ny", "ntides": "nc"}
     )
-    reference_date: datetime = datetime(1992, 1, 1)
+    tolerate_coord_mismatch: bool = False
+    ds: xr.Dataset = field(init=False, repr=False)
 
     def clean_up(self, ds: xr.Dataset) -> xr.Dataset:
-        """Clean up and standardize the dimensions and coordinates of the dataset for
-        further processing.
-
-        This method performs the following operations:
-        - Assigns new coordinate variables for 'omega', 'longitude', and 'latitude' based on existing dataset variables.
-            - 'omega' is retained as is.
-            - 'longitude' is derived from 'lon_r', assuming it is constant along the 'ny' dimension.
-            - 'latitude' is derived from 'lat_r', assuming it is constant along the 'nx' dimension.
-        - Renames the dimensions 'nx' and 'ny' to 'longitude' and 'latitude', respectively, for consistency.
-        - Renames the tidal dimension to 'ntides' for standardization.
-        - Updates the `dim_names` attribute of the object to reflect the new dimension names: 'longitude', 'latitude', and 'ntides'.
+        """Standardize the dataset's dimensions and coordinates for further processing.
+
+        This method performs the following operations:
+        - Assigns 'longitude' and 'latitude' coordinates from existing dataset variables.
+        - Adds a `mask` variable indicating valid data points based on grid conditions.
+        - Renames dimensions:
+            - 'nx' -> 'longitude'
+            - 'ny' -> 'latitude'
+            - tidal dimension -> 'ntides'
+        - Updates `dim_names` to reflect the new dimension names.
+        - If coordinates do not match the grid and `self.tolerate_coord_mismatch=True`, the dataset is
+          reindexed to the grid's coordinates using nearest-neighbor selection. Otherwise, a ValueError is raised.
+
+        Parameters
+        ----------
+        ds : xr.Dataset
+            The dataset to be standardized.
+
+        Returns
+        -------
+        xr.Dataset
+            The cleaned dataset with updated coordinates and dimensions.
+
+        Raises
+        ------
+        ValueError
+            If longitude or latitude values do not match the grid.
+        """
 
-        Parameters
-        ----------
-        ds : xr.Dataset
-            The input dataset to be cleaned and standardized. It should contain the coordinates 'omega', 'lon_r', 'lat_r', and the tidal dimension.
+        ds_grid = _load_data(self.grid_filename, self.dim_names, self.use_dask)
+
+        # Define mask and coordinate names based on location
+        if self.location == "h":
+            mask_name = "mz"
+            lon_name = "lon_z"
+            lat_name = "lat_z"
+        elif self.location == "u":
+            mask_name = "mu"
+            lon_name = "lon_u"
+            lat_name = "lat_u"
+        elif self.location == "v":
+            mask_name = "mv"
+            lon_name = "lon_v"
+            lat_name = "lat_v"
+
+        # Assign and rename coordinates
+        ds_grid = ds_grid.assign_coords(
+            {
+                "nx": ds_grid[lon_name].isel(
+                    ny=0
+                ),  # lon is constant along ny, i.e., is only a function of nx
+                "ny": ds_grid[lat_name].isel(
+                    nx=0
+                ),  # lat is constant along nx, i.e., is only a function of ny
+            }
+        )
+        ds_grid = ds_grid.rename({"nx": "longitude", "ny": "latitude"})
 
-        Returns
-        -------
-        ds : xr.Dataset
-            A cleaned and standardized `xarray.Dataset` with updated coordinates and dimensions.
-        """
+        # Drop all dimensions except 'longitude' and 'latitude'
+        dims_to_keep = {"longitude", "latitude"}
+        dims_to_drop = [dim for dim in ds_grid.dims if dim not in dims_to_keep]
+        if dims_to_drop:
+            ds_grid = ds_grid.isel({dim: 0 for dim in dims_to_drop})
+
+        # Ensure correct dimension order
+        ds_grid = ds_grid.transpose("latitude", "longitude")
+
+        ds = ds.rename({"con": "nc"})
         ds = ds.assign_coords(
             {
-                "omega": ds["omega"],
-                "nx": ds["lon_r"].isel(
+                "nc": ("nc", [c.strip() for c in ds["nc"].values]),  # Strip padding
+                "nx": ds[lon_name].isel(
                     ny=0
-                ),  # lon_r is constant along ny, i.e., is only a function of nx
-                "ny": ds["lat_r"].isel(
+                ),  # lon is constant along ny, i.e., is only a function of nx
+                "ny": ds[lat_name].isel(
                     nx=0
-                ),  # lat_r is constant along nx, i.e., is only a function of ny
+                ),  # lat is constant along nx, i.e., is only a function of ny
             }
         )
         ds = ds.rename(
             {"nx": "longitude", "ny": "latitude", self.dim_names["ntides"]: "ntides"}
         )
 
+        ds = ds.transpose("latitude", "longitude", "ntides")
+
         self.dim_names = {
             "latitude": "latitude",
             "longitude": "longitude",
             "ntides": "ntides",
         }
 
+        if self.tolerate_coord_mismatch:
+            # Reindex dataset to match grid lat/lon using nearest-neighbor selection
+            ds = ds.sel(
+                {
+                    "latitude": ds_grid["latitude"],
+                    "longitude": ds_grid["longitude"],
+                },
+                method="nearest",
+            )
+
+        # Validate matching lat/lon shapes and values
+        for coord in [lon_name, lat_name]:
+            if ds[coord].shape != ds_grid[coord].shape:
+                raise ValueError(
+                    f"Mismatch in {coord} array sizes. Dataset: {ds[coord].shape}, Grid: {ds_grid[coord].shape}"
+                )
+            if not np.allclose(ds[coord].values, ds_grid[coord].values):
+                raise ValueError(
+                    f"{coord.capitalize()} values from dataset do not match grid. Dataset: {ds[coord].values}, Grid: {ds_grid[coord].values}"
+                )
+
+        # Add mask
+        mask = ds_grid[mask_name].isnull()
+        # Create a fresh xarray.DataArray for the mask using only the dimension names.
+        # This avoids issues that can arise from small coordinate mismatches when assigning directly.
+        ds["mask"] = xr.DataArray(mask.values, dims=list(mask.dims))
+
         return ds
 
-    def check_number_constituents(self, ntides: int):
-        """Checks if the number of constituents in the dataset is at least `ntides`.
+    def select_constituents(self, ntides: int, omega: Dict[str, float]):
+        """Select the first `ntides` tidal constituents based on the provided omega
+        values.
+
+        This method filters the dataset to retain only the tidal constituents that match
+        the first `ntides` from the provided `omega` dictionary. It ensures that the dataset
+        contains the expected constituents before proceeding with the selection; if it does
+        not, an error is raised.
 
         Parameters
         ----------
         ntides : int
-            The required number of tidal constituents.
+            The number of tidal constituents to retain. The method selects the first
+            `ntides` constituents from the provided omega dictionary and maps their
+            corresponding values to the dataset.
+        omega : dict
+            A dictionary where keys are tidal constituent names and values are their
+            associated angular frequencies. The first `ntides` keys from this dictionary
+            are used to filter the tidal constituents in the dataset.
 
         Raises
         ------
         ValueError
-            If the number of constituents in the dataset is less than `ntides`.
+            If the dataset does not contain all of the first `ntides` tidal constituents
+            selected from the `omega` dictionary.
         """
-        if len(self.ds[self.dim_names["ntides"]]) < ntides:
-            raise ValueError(
-                f"The dataset contains fewer than {ntides} tidal constituents."
-            )
 
-    def post_process(self):
-        """Apply a depth-based mask to the dataset, ensuring only positive depths are
-        retained.
+        # Expected constituents based on the first 'ntides' from the omega dictionary
+        expected_constituents = list(omega.keys())[:ntides]
 
-        This method checks if the 'depth' variable is present in the dataset. If found, a mask is created where
-        depths greater than 0 are considered valid (mask value of 1). This mask is applied to all data variables
-        in the dataset, replacing values at invalid depths (depth <= 0) with NaN. The mask itself is also stored
-        in the dataset under the variable 'mask'.
+        # Extract the current tidal constituents from the dataset
+        dataset_constituents = [
+            c.decode("utf-8").strip() for c in self.ds["ntides"].values
+        ]
 
-        Returns
-        -------
-        None
-            The dataset is modified in-place by applying the mask to each variable.
-        """
+        # Check if the dataset contains the expected constituents
+        if not all(c in dataset_constituents for c in expected_constituents):
+            raise ValueError(
+                f"The dataset's tidal constituents {dataset_constituents} do not cover the first {ntides} required constituents "
+                f"of the TPXO standard: {expected_constituents}. "
+                "Ensure the dataset includes the required constituents or reduce the 'ntides' parameter."
+            )
 
-        if "depth" in self.var_names.keys():
-            ds = self.ds
-            mask = xr.where(self.ds["depth"] > 0, 1, 0)
-            ds["mask"] = mask
-            ds = ds.drop_vars(["depth"])
+        # Select only the expected constituents from the dataset
+        filtered_constituents = [
+            c for c in dataset_constituents if c in expected_constituents
+        ]
 
-            self.ds = ds
+        # Encode the filtered constituents back to byte strings before selecting in xarray
+        filtered_constituents_bytes = [c.encode("utf-8") for c in filtered_constituents]
 
-            # Remove "depth" from var_names
-            updated_var_names = {**self.var_names}  # Create a copy of the dictionary
-            updated_var_names.pop("depth", None)  # Remove "depth" if it exists
-            self.var_names = updated_var_names
+        # Update the dataset with the filtered constituents
+        self.ds = self.ds.sel(ntides=filtered_constituents_bytes)
 
 
 @dataclass(kw_only=True)
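TPXODataset is now constructed explicitly per TPXO file and grid location rather than from hard-coded defaults. A minimal sketch of building one instance and trimming it to two constituents; the file names are hypothetical, and the Dataset base class may require additional arguments (e.g. use_dask) not shown here:

```python
from roms_tools.setup.datasets import TPXODataset

h = TPXODataset(
    filename="h_tpxo.nc",          # hypothetical elevation file
    grid_filename="grid_tpxo.nc",  # hypothetical grid file
    location="h",
    var_names={"ssh_Re": "hRe", "ssh_Im": "hIm"},
)
# Keep only m2 and s2; frequencies taken from TPXOManager.get_omega()
h.select_constituents(ntides=2, omega={"m2": 1.405189e-04, "s2": 1.454441e-04})
```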
@@ -1007,15 +1117,11 @@ class UnifiedDataset(Dataset):
             }
         )
 
-        object.__setattr__(
-            self,
-            "dim_names",
-            {
-                "latitude": "latitude",
-                "longitude": "longitude",
-                "depth": "depth",
-            },
-        )
+        self.dim_names = {
+            "latitude": "latitude",
+            "longitude": "longitude",
+            "depth": "depth",
+        }
 
         # Handle time dimension
         if "time" not in self.dim_names:
@@ -2012,6 +2118,497 @@ class DaiRiverDataset(RiverDataset):
         return ds
 
 
+@dataclass
+class TPXOManager:
+    """Manages multiple TPXODataset instances, selecting and processing tidal
+    constituents from the TPXO dataset.
+
+    This class handles multiple tidal constituents following the TPXO9v2a standard.
+    The self-attraction and loading (SAL) correction data is sourced internally from TPXO9v2a.
+
+    Parameters
+    ----------
+    filenames : dict
+        Dictionary containing paths to TPXO dataset files. Expected keys:
+        - "h" : Path to the elevation file.
+        - "u" : Path to the u-velocity component file.
+        - "grid" : Path to the grid file.
+    ntides : int
+        Number of tidal constituents to select for processing.
+    reference_date : datetime, optional
+        Reference date for the TPXO data. Defaults to January 1, 1992.
+        Used as the baseline for tidal time series calculations.
+    allan_factor : float, optional
+        Factor used in tidal model computations. Defaults to 2.0.
+    use_dask : bool, optional
+        Whether to use Dask for chunking. If True, data is loaded lazily; if False, data is loaded eagerly. Defaults to False.
+
+    Notes
+    -----
+    In TPXO products newer than 9v2a, the order of tidal constituents may change beyond the 10th constituent.
+    Before selecting the first `ntides` constituents, newer products are reordered to match the TPXO9v2a standard.
+    However, this reordering has minimal impact, as constituents beyond the 10th have negligible amplitudes.
+    """
+
+    filenames: dict
+    ntides: int
+    reference_date: datetime = datetime(1992, 1, 1)
+    allan_factor: float = 2.0
+    use_dask: Optional[bool] = False
+
+    def __post_init__(self):
+
+        fname_sal = download_sal_data("sal_tpxo9.v2a.nc")
+
+        # Initialize the data_dict with TPXODataset instances
+        data_dict = {
+            "h": TPXODataset(
+                filename=self.filenames["h"],
+                grid_filename=self.filenames["grid"],
+                location="h",
+                var_names={"ssh_Re": "hRe", "ssh_Im": "hIm"},
+                use_dask=self.use_dask,
+            ),
+            "sal": TPXODataset(
+                filename=fname_sal,
+                grid_filename=self.filenames["grid"],
+                location="h",
+                var_names={"sal_Re": "hRe", "sal_Im": "hIm"},
+                use_dask=self.use_dask,
+                tolerate_coord_mismatch=True,  # SAL is from TPXO9v2a and may not align exactly with newer grids
+            ),
+            "u": TPXODataset(
+                filename=self.filenames["u"],
+                grid_filename=self.filenames["grid"],
+                location="u",
+                var_names={"u_Re": "URe", "u_Im": "UIm"},
+                use_dask=self.use_dask,
+            ),
+            "v": TPXODataset(
+                filename=self.filenames["u"],
+                grid_filename=self.filenames["grid"],
+                location="v",
+                var_names={"v_Re": "VRe", "v_Im": "VIm"},
+                use_dask=self.use_dask,
+            ),
+        }
+
+        omega = self.get_omega()
+
+        for data in data_dict.values():
+            data.select_constituents(self.ntides, omega)
+
+        data_dict["omega"] = xr.DataArray(
+            data=list(omega.values())[: self.ntides],
+            dims="ntides",
+            attrs={"long_name": "angular frequency", "units": "radians per second"},
+        )
+
+        object.__setattr__(self, "datasets", data_dict)
+
+    def get_omega(self):
+        """Retrieve angular frequencies (omega) for tidal constituents from the
+        TPXO9.v2 atlas.
+
+        This method returns the angular frequencies (in radians per second) for 15 tidal constituents,
+        sourced from the TPXO tidal model and defined in the OTPSnc `constit.h` file (see https://www.tpxo.net/otps).
+        These values are essential for tidal modeling and analysis.
+
+        Returns
+        -------
+        dict
+            A dictionary where the keys are tidal constituent labels (str) and the values
+            are their respective angular frequencies (float, in radians per second).
+        """
+        omega = {
+            "m2": 1.405189e-04,  # Principal lunar semidiurnal
+            "s2": 1.454441e-04,  # Principal solar semidiurnal
+            "n2": 1.378797e-04,  # Larger lunar elliptic semidiurnal
+            "k2": 1.458423e-04,  # Lunisolar semidiurnal
+            "k1": 7.292117e-05,  # Lunar diurnal
+            "o1": 6.759774e-05,  # Lunar diurnal
+            "p1": 7.252295e-05,  # Solar diurnal
+            "q1": 6.495854e-05,  # Larger lunar elliptic diurnal
+            "mm": 0.026392e-04,  # Lunar monthly
+            "mf": 0.053234e-04,  # Lunar fortnightly
+            "m4": 2.810377e-04,  # Shallow water overtide of M2
+            "mn4": 2.783984e-04,  # Shallow water quarter diurnal
+            "ms4": 2.859630e-04,  # Shallow water quarter diurnal
+            "2n2": 1.352405e-04,  # Shallow water semidiurnal
+            "s1": 7.2722e-05,  # Solar diurnal
+        }
+        return omega
+
+    def compute_equilibrium_tide(self, lon, lat):
+        """Compute equilibrium tide for given longitudes and latitudes.
+
+        Parameters
+        ----------
+        lon : xr.DataArray
+            Longitudes in degrees.
+        lat : xr.DataArray
+            Latitudes in degrees.
+
+        Returns
+        -------
+        tpc : xr.DataArray
+            Equilibrium tide complex amplitude.
+
+        Notes
+        -----
+        This method calculates the equilibrium tide complex amplitude for specified
+        longitudes and latitudes, considering 15 tidal constituents and their corresponding
+        amplitudes and elasticity factors. The order of the tidal constituents corresponds
+        to the order in `self.get_omega()`, which must remain consistent for future use.
+
+        The tidal constituents are categorized as follows:
+        - **2**: Semidiurnal
+        - **1**: Diurnal
+        - **0**: Long-period
+
+        The amplitudes and elasticity factors are sourced from the `constit.h` file in the OTPSnc package.
+        """
+
+        # Amplitudes for 15 tidal constituents (from variable amp_d in constit.h of OTPSnc package)
+        A = xr.DataArray(
+            data=np.array(
+                [
+                    0.242334,  # M2
+                    0.112743,  # S2
+                    0.046397,  # N2
+                    0.030684,  # K2
+                    0.141565,  # K1
+                    0.100661,  # O1
+                    0.046848,  # P1
+                    0.019273,  # Q1
+                    0.022191,  # Mm
+                    0.042041,  # Mf
+                    0.0,  # M4
+                    0.0,  # Mn4
+                    0.0,  # Ms4
+                    0.006141,  # 2n2
+                    0.000764,  # S1
+                ]
+            ),
+            dims="ntides",
+        )
+
+        # Elasticity factors for 15 tidal constituents (from variable alpha_d in constit.h of OTPSnc package)
+        B = xr.DataArray(
+            data=np.array(
+                [
+                    0.693,  # M2
+                    0.693,  # S2
+                    0.693,  # N2
+                    0.693,  # K2
+                    0.736,  # K1
+                    0.695,  # O1
+                    0.706,  # P1
+                    0.695,  # Q1
+                    0.693,  # Mm
+                    0.693,  # Mf
+                    0.693,  # M4
+                    0.693,  # Mn4
+                    0.693,  # Ms4
+                    0.693,  # 2n2
+                    0.693,  # S1
+                ]
+            ),
+            dims="ntides",
+        )
+
+        # Tidal type (from variable ispec_d in constit.h of OTPSnc package)
+        # types: 2 = semidiurnal, 1 = diurnal, 0 = long-term
+        ityp = xr.DataArray(
+            data=np.array(
+                [
+                    2,  # M2
+                    2,  # S2
+                    2,  # N2
+                    2,  # K2
+                    1,  # K1
+                    1,  # O1
+                    1,  # P1
+                    1,  # Q1
+                    0,  # Mm
+                    0,  # Mf
+                    0,  # M4
+                    0,  # Mn4
+                    0,  # Ms4
+                    2,  # 2n2
+                    1,  # S1
+                ]
+            ),
+            dims="ntides",
+        )
+
+        d2r = np.pi / 180
+        coslat2 = np.cos(d2r * lat) ** 2
+        sin2lat = np.sin(2 * d2r * lat)
+
+        p_amp = (
+            xr.where(ityp == 2, 1, 0) * A * B * coslat2  # semidiurnal
+            + xr.where(ityp == 1, 1, 0) * A * B * sin2lat  # diurnal
+            + xr.where(ityp == 0, 1, 0) * A * B * (0.5 - 1.5 * coslat2)  # long-term
+        )
+        p_pha = (
+            xr.where(ityp == 2, 1, 0) * (-2 * lon * d2r)  # semidiurnal
+            + xr.where(ityp == 1, 1, 0) * (-lon * d2r)  # diurnal
+            + xr.where(ityp == 0, 1, 0) * xr.zeros_like(lon)  # long-term
+        )
+
+        tpc = p_amp * np.exp(-1j * p_pha)
+        tpc = tpc.isel(ntides=slice(None, self.ntides))
+
+        return tpc
+
+    def egbert_correction(self, date):
+        """Correct phases and amplitudes for real-time runs using parts of the
+        post-processing code from Egbert and Erofeeva's (OSU) TPXO model.
+
+        Parameters
+        ----------
+        date : datetime.datetime
+            The date and time for which corrections are to be applied.
+
+        Returns
+        -------
+        pf : xr.DataArray
+            Amplitude scaling factor for each of the 15 tidal constituents.
+        pu : xr.DataArray
+            Phase correction [radians] for each of the 15 tidal constituents.
+        aa : xr.DataArray
+            Astronomical arguments [radians] associated with the corrections.
+
+        Notes
+        -----
+        The order of the tidal constituents corresponds to the order in
+        `self.get_omega()`, which must remain consistent for future use.
+
+        References
+        ----------
+        - Egbert, G.D., and S.Y. Erofeeva. "Efficient inverse modeling of barotropic ocean
+          tides." Journal of Atmospheric and Oceanic Technology 19, no. 2 (2002): 183-204.
+        """
+
+        year = date.year
+        month = date.month
+        day = date.day
+        hour = date.hour
+        minute = date.minute
+        second = date.second
+
+        rad = np.pi / 180.0
+        deg = 180.0 / np.pi
+        mjd = modified_julian_days(year, month, day)
+        tstart = mjd + hour / 24 + minute / (60 * 24) + second / (60 * 60 * 24)
+
+        # Determine nodal corrections pu & pf: these expressions are valid for the period 1990-2010 (Cartwright 1990).
+        # Reset time origin for astronomical arguments to 4th of May 1860:
+        timetemp = tstart - 51544.4993
+
+        # mean longitude of lunar perigee
+        P = 83.3535 + 0.11140353 * timetemp
+        P = np.mod(P, 360.0)
+        if P < 0:
+            P += 360
+        P *= rad
+
+        # mean longitude of ascending lunar node
+        N = 125.0445 - 0.05295377 * timetemp
+        N = np.mod(N, 360.0)
+        if N < 0:
+            N += 360
+        N *= rad
+
+        sinn = np.sin(N)
+        cosn = np.cos(N)
+        sin2n = np.sin(2 * N)
+        cos2n = np.cos(2 * N)
+        sin3n = np.sin(3 * N)
+
+        pftmp = np.sqrt(
+            (1 - 0.03731 * cosn + 0.00052 * cos2n) ** 2
+            + (0.03731 * sinn - 0.00052 * sin2n) ** 2
+        )
+
+        pf = np.zeros(15)
+        pf[0] = pftmp  # M2
+        pf[1] = 1.0  # S2
+        pf[2] = pftmp  # N2
+        pf[3] = np.sqrt(
+            (1 + 0.2852 * cosn + 0.0324 * cos2n) ** 2
+            + (0.3108 * sinn + 0.0324 * sin2n) ** 2
+        )  # K2
+        pf[4] = np.sqrt(
+            (1 + 0.1158 * cosn - 0.0029 * cos2n) ** 2
+            + (0.1554 * sinn - 0.0029 * sin2n) ** 2
+        )  # K1
+        pf[5] = np.sqrt(
+            (1 + 0.189 * cosn - 0.0058 * cos2n) ** 2
+            + (0.189 * sinn - 0.0058 * sin2n) ** 2
+        )  # O1
+        pf[6] = 1.0  # P1
+        pf[7] = np.sqrt((1 + 0.188 * cosn) ** 2 + (0.188 * sinn) ** 2)  # Q1
+        pf[8] = 1.0 - 0.130 * cosn  # Mm
+        pf[9] = 1.043 + 0.414 * cosn  # Mf
+        pf[10] = pftmp**2  # M4
+        pf[11] = pftmp**2  # Mn4
+        pf[12] = pftmp**2  # Ms4
+        pf[13] = pftmp  # 2n2
+        pf[14] = 1.0  # S1
+        pf = xr.DataArray(pf, dims="ntides")
+
+        putmp = (
+            np.arctan(
+                (-0.03731 * sinn + 0.00052 * sin2n)
+                / (1.0 - 0.03731 * cosn + 0.00052 * cos2n)
+            )
+            * deg
+        )
+
+        pu = np.zeros(15)
+        pu[0] = putmp  # M2
+        pu[1] = 0.0  # S2
+        pu[2] = putmp  # N2
+        pu[3] = (
+            np.arctan(
+                -(0.3108 * sinn + 0.0324 * sin2n)
+                / (1.0 + 0.2852 * cosn + 0.0324 * cos2n)
+            )
+            * deg
+        )  # K2
+        pu[4] = (
+            np.arctan(
+                (-0.1554 * sinn + 0.0029 * sin2n)
+                / (1.0 + 0.1158 * cosn - 0.0029 * cos2n)
+            )
+            * deg
+        )  # K1
+        pu[5] = 10.8 * sinn - 1.3 * sin2n + 0.2 * sin3n  # O1
+        pu[6] = 0.0  # P1
+        pu[7] = np.arctan(0.189 * sinn / (1.0 + 0.189 * cosn)) * deg  # Q1
+        pu[8] = 0.0  # Mm
+        pu[9] = -23.7 * sinn + 2.7 * sin2n - 0.4 * sin3n  # Mf
+        pu[10] = putmp * 2.0  # M4
+        pu[11] = putmp * 2.0  # Mn4
+        pu[12] = putmp  # Ms4
+        pu[13] = putmp  # 2n2
+        pu[14] = 0.0  # S1
+        pu = xr.DataArray(pu, dims="ntides")
+        # convert from degrees to radians
+        pu = pu * rad
+
+        aa = xr.DataArray(
+            data=np.array(
+                [
+                    1.731557546,  # M2
+                    0.0,  # S2
+                    6.050721243,  # N2
+                    3.487600001,  # K2
+                    0.173003674,  # K1
+                    1.558553872,  # O1
+                    6.110181633,  # P1
+                    5.877717569,  # Q1
+                    1.964021610,  # Mm
+                    1.756042456,  # Mf
+                    3.463115091,  # M4
+                    1.499093481,  # Mn4
+                    1.731557546,  # Ms4
+                    4.086699633,  # 2n2
+                    0.0,  # S1
+                ]
+            ),
+            dims="ntides",
+        )
+        pf = pf.isel(ntides=slice(None, self.ntides))
+        pu = pu.isel(ntides=slice(None, self.ntides))
+        aa = aa.isel(ntides=slice(None, self.ntides))
+
+        return pf, pu, aa
+
+    def correct_tides(self, model_reference_date):
+        """Apply tidal corrections to the dataset. This method corrects the dataset for
+        equilibrium tides, self-attraction and loading (SAL) effects, and adjusts phases
+        and amplitudes of tidal elevations and transports using Egbert's correction.
+
+        Parameters
+        ----------
+        model_reference_date : datetime
+            The reference date for the ROMS simulation.
+
+        Returns
+        -------
+        None
+            The dataset is modified in-place with corrected real and imaginary components for ssh, u, v, and the
+            potential field ('pot_Re', 'pot_Im').
+        """
+
+        datasets = self.datasets
+        omega = self.datasets["omega"].isel(ntides=slice(None, self.ntides))
+
+        # Get equilibrium tides
+        lon = datasets["sal"].ds[datasets["sal"].dim_names["longitude"]]
+        lat = datasets["sal"].ds[datasets["sal"].dim_names["latitude"]]
+        tpc = self.compute_equilibrium_tide(lon, lat)
+
+        # Correct for SAL
+        tsc = self.allan_factor * (
+            datasets["sal"].ds[datasets["sal"].var_names["sal_Re"]]
+            + 1j * datasets["sal"].ds[datasets["sal"].var_names["sal_Im"]]
+        )
+        tpc = tpc - tsc
+
+        # Elevations and transports
+        thc = (
+            datasets["h"].ds[datasets["h"].var_names["ssh_Re"]]
+            + 1j * datasets["h"].ds[datasets["h"].var_names["ssh_Im"]]
+        )
+        tuc = (
+            datasets["u"].ds[datasets["u"].var_names["u_Re"]]
+            + 1j * datasets["u"].ds[datasets["u"].var_names["u_Im"]]
+        )
+        tvc = (
+            datasets["v"].ds[datasets["v"].var_names["v_Re"]]
+            + 1j * datasets["v"].ds[datasets["v"].var_names["v_Im"]]
+        )
+
+        # Apply correction for phases and amplitudes
+        pf, pu, aa = self.egbert_correction(model_reference_date)
+
+        dt = (model_reference_date - self.reference_date).days * 3600 * 24
+
+        thc = pf * thc * np.exp(1j * (omega * dt + pu + aa))
+        tuc = pf * tuc * np.exp(1j * (omega * dt + pu + aa))
+        tvc = pf * tvc * np.exp(1j * (omega * dt + pu + aa))
+        tpc = pf * tpc * np.exp(1j * (omega * dt + pu + aa))
+
+        datasets["h"].ds[datasets["h"].var_names["ssh_Re"]] = thc.real
+        datasets["h"].ds[datasets["h"].var_names["ssh_Im"]] = thc.imag
+        datasets["u"].ds[datasets["u"].var_names["u_Re"]] = tuc.real
+        datasets["u"].ds[datasets["u"].var_names["u_Im"]] = tuc.imag
+        datasets["v"].ds[datasets["v"].var_names["v_Re"]] = tvc.real
+        datasets["v"].ds[datasets["v"].var_names["v_Im"]] = tvc.imag
+        datasets["sal"].ds["pot_Re"] = tpc.real
+        datasets["sal"].ds["pot_Im"] = tpc.imag
+
+        object.__setattr__(self, "datasets", datasets)
+
+        # Update var_names dictionary
+        var_names = {
+            **datasets["sal"].var_names,
+            "pot_Re": "pot_Re",
+            "pot_Im": "pot_Im",
+        }
+        var_names.pop("sal_Re", None)  # Remove "sal_Re" if it exists
+        var_names.pop("sal_Im", None)  # Remove "sal_Im" if it exists
+        object.__setattr__(self.datasets["sal"], "var_names", var_names)
+
+
 # shared functions
 
 
@@ -2214,3 +2811,61 @@ def decode_string(byte_array):
     )
 
     return decoded_string
+
+
+def modified_julian_days(year, month, day, hour=0):
+    """Calculate the Modified Julian Day (MJD) for a given date and time.
+
+    The Modified Julian Day (MJD) is a modified Julian day count starting from
+    November 17, 1858 AD. It is commonly used in astronomy and geodesy.
+
+    Parameters
+    ----------
+    year : int
+        The year.
+    month : int
+        The month (1-12).
+    day : int
+        The day of the month.
+    hour : float, optional
+        The hour of the day as a fractional number (0 to 23.999...). Default is 0.
+
+    Returns
+    -------
+    mjd : float
+        The Modified Julian Day (MJD) corresponding to the input date and time.
+
+    Notes
+    -----
+    The algorithm assumes that the input date (year, month, day) is within the
+    Gregorian calendar, i.e., after October 15, 1582 (earlier dates are treated
+    as proleptic Gregorian). Negative MJD values are allowed for dates before
+    November 17, 1858.
+
+    References
+    ----------
+    - Wikipedia article on Julian Day: https://en.wikipedia.org/wiki/Julian_day
+    - Wikipedia article on Modified Julian Day: https://en.wikipedia.org/wiki/Modified_Julian_day
+
+    Examples
+    --------
+    >>> modified_julian_days(2024, 5, 20, 12)
+    60450.5
+    >>> modified_julian_days(1858, 11, 17)
+    0.0
+    >>> modified_julian_days(1582, 10, 4)
+    -100851.0
+    """
+
+    if month < 3:
+        year -= 1
+        month += 12
+
+    A = year // 100
+    B = A // 4
+    C = 2 - A + B
+    E = int(365.25 * (year + 4716))
+    F = int(30.6001 * (month + 1))
+    jd = C + day + hour / 24 + E + F - 1524.5
+    mjd = jd - 2400000.5
+
+    return mjd
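The doctest values can be checked by hand against the formula: for 1858-11-17 the terms reduce exactly to JD 2400000.5, i.e. MJD 0.0. A quick standalone verification, assuming the function above, plus one well-known anchor date:

```python
# Verify the docstring examples; 2000-01-01 00:00 UTC is a standard MJD anchor.
assert modified_julian_days(1858, 11, 17) == 0.0       # MJD epoch
assert modified_julian_days(2000, 1, 1) == 51544.0     # known anchor
assert modified_julian_days(2024, 5, 20, 12) == 60450.5
```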