roms-tools 2.5.0__py3-none-any.whl → 2.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. ci/environment-with-xesmf.yml +16 -0
  2. roms_tools/analysis/roms_output.py +521 -187
  3. roms_tools/analysis/utils.py +169 -0
  4. roms_tools/plot.py +351 -214
  5. roms_tools/regrid.py +161 -9
  6. roms_tools/setup/boundary_forcing.py +22 -22
  7. roms_tools/setup/datasets.py +40 -44
  8. roms_tools/setup/grid.py +28 -28
  9. roms_tools/setup/initial_conditions.py +23 -31
  10. roms_tools/setup/nesting.py +3 -3
  11. roms_tools/setup/river_forcing.py +22 -23
  12. roms_tools/setup/surface_forcing.py +14 -13
  13. roms_tools/setup/tides.py +7 -7
  14. roms_tools/setup/topography.py +2 -2
  15. roms_tools/tests/test_analysis/test_roms_output.py +299 -188
  16. roms_tools/tests/test_regrid.py +85 -2
  17. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/.zmetadata +2 -2
  18. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/.zmetadata +2 -2
  19. roms_tools/tests/test_setup/test_river_forcing.py +47 -51
  20. roms_tools/tests/test_vertical_coordinate.py +73 -0
  21. roms_tools/utils.py +11 -7
  22. roms_tools/vertical_coordinate.py +7 -0
  23. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/METADATA +22 -11
  24. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/RECORD +33 -30
  25. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/WHEEL +1 -1
  26. /roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/{river_location → river_flux}/.zarray +0 -0
  27. /roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/{river_location → river_flux}/.zattrs +0 -0
  28. /roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/{river_location → river_flux}/0.0 +0 -0
  29. /roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/{river_location → river_flux}/.zarray +0 -0
  30. /roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/{river_location → river_flux}/.zattrs +0 -0
  31. /roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/{river_location → river_flux}/0.0 +0 -0
  32. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info/licenses}/LICENSE +0 -0
  33. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/top_level.txt +0 -0
roms_tools/regrid.py CHANGED
@@ -1,7 +1,9 @@
1
+ import xgcm
1
2
  import xarray as xr
3
+ import warnings
2
4
 
3
5
 
4
- class LateralRegrid:
6
+ class LateralRegridToROMS:
5
7
  """Handles lateral regridding of data onto a new spatial grid."""
6
8
 
7
9
  def __init__(self, target_coords, source_dim_names):
@@ -49,7 +51,92 @@ class LateralRegrid:
49
51
  return regridded
50
52
 
51
53
 
52
- class VerticalRegrid:
54
class LateralRegridFromROMS:
    """Regrids data from a curvilinear ROMS grid onto latitude-longitude coordinates
    using xESMF.

    It requires the optional `xesmf` library, which is included when `roms-tools`
    is installed via conda.

    Parameters
    ----------
    ds_in : xarray.Dataset or xarray.DataArray
        The source dataset or data array on the curvilinear ROMS grid, with
        coordinates named 'lat' and 'lon'.

    target_coords : dict
        A dictionary containing 'lat' and 'lon' arrays representing the target
        latitude and longitude coordinates for regridding.

    method : str, optional
        The regridding method passed to `xesmf.Regridder`. Default is "bilinear".
        Other options include "nearest_s2d" and "conservative".

    Attributes
    ----------
    regridder : xesmf.Regridder
        The regridder built once from `ds_in` and `target_coords` and reused by
        every call to `apply`.

    Raises
    ------
    ImportError
        If xESMF is not installed.
    """

    def __init__(self, ds_in, target_coords, method="bilinear"):
        """Initializes the regridder with the source and target grids.

        Parameters
        ----------
        ds_in : xarray.Dataset or xarray.DataArray
            The source dataset or data array on the curvilinear ROMS grid, with
            coordinates named 'lat' and 'lon'.

        target_coords : dict
            A dictionary containing 'lat' and 'lon' arrays representing the target
            latitude and longitude coordinates for regridding.

        method : str, optional
            The regridding method to use. Default is "bilinear". Other options
            include "nearest_s2d" and "conservative".

        Raises
        ------
        ImportError
            If xESMF is not installed. The original import error is chained as
            the ``__cause__`` of the raised exception.
        """
        try:
            import xesmf as xe
        except ImportError as err:
            # Chain the original error so the underlying import failure
            # (missing package vs. broken ESMF install) stays visible.
            raise ImportError(
                "xesmf is required for this regridding task. Please install `roms-tools` via conda, which includes xesmf."
            ) from err

        # The destination grid is described only by the target 'lat'/'lon' arrays.
        ds_out = xr.Dataset()
        ds_out["lat"] = target_coords["lat"]
        ds_out["lon"] = target_coords["lon"]

        # Suppress UserWarnings emitted by xesmf while the weights are built.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning, module="xesmf")
            # unmapped_to_nan=True: target points not covered by the source
            # grid are set to NaN instead of 0.
            self.regridder = xe.Regridder(
                ds_in, ds_out, method=method, unmapped_to_nan=True
            )

    def apply(self, da):
        """Applies the regridding to the provided data array.

        Parameters
        ----------
        da : xarray.DataArray
            The data array to regrid. It should be defined on the same source grid
            (same dimension names) that the regridder was built from.

        Returns
        -------
        xarray.DataArray
            The data array regridded onto the target latitude-longitude
            coordinates, with the attributes of `da` preserved.
        """
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning, module="xesmf")
            # keep_attrs=True carries over the input's attributes (units, long_name, ...).
            regridded = self.regridder(da, keep_attrs=True)
        return regridded
137
+
138
+
139
+ class VerticalRegridToROMS:
53
140
  """Interpolates data onto new vertical (depth) coordinates.
54
141
 
55
142
  Parameters
@@ -106,13 +193,13 @@ class VerticalRegrid:
106
193
  }
107
194
  )
108
195
 
109
- def apply(self, var, fill_nans=True):
196
+ def apply(self, da, fill_nans=True):
110
197
  """Interpolates the variable onto the new depth grid using precomputed
111
198
  coefficients for linear interpolation between layers.
112
199
 
113
200
  Parameters
114
201
  ----------
115
- var : xarray.DataArray
202
+ da : xarray.DataArray
116
203
  The input data to be regridded along the depth dimension. This should be
117
204
  an array with the same depth coordinates as the original grid.
118
205
  fill_nans : bool, optional
@@ -130,16 +217,81 @@ class VerticalRegrid:
130
217
 
131
218
  dims = {"dim": self.depth_dim}
132
219
 
133
- var_below = var.where(self.coeff["is_below"]).sum(**dims)
134
- var_above = var.where(self.coeff["is_above"]).sum(**dims)
220
+ da_below = da.where(self.coeff["is_below"]).sum(**dims)
221
+ da_above = da.where(self.coeff["is_above"]).sum(**dims)
135
222
 
136
- result = var_below + (var_above - var_below) * self.coeff["factor"]
223
+ result = da_below + (da_above - da_below) * self.coeff["factor"]
137
224
  if fill_nans:
138
- result = result.where(self.coeff["upper_mask"], var.isel({dims["dim"]: 0}))
139
- result = result.where(self.coeff["lower_mask"], var.isel({dims["dim"]: -1}))
225
+ result = result.where(self.coeff["upper_mask"], da.isel({dims["dim"]: 0}))
226
+ result = result.where(self.coeff["lower_mask"], da.isel({dims["dim"]: -1}))
140
227
  else:
141
228
  result = result.where(self.coeff["upper_mask"]).where(
142
229
  self.coeff["lower_mask"]
143
230
  )
144
231
 
145
232
  return result
233
+
234
+
235
class VerticalRegridFromROMS:
    """A class for regridding data from the ROMS vertical coordinate system to target
    depth levels.

    This class uses the `xgcm` package to perform the transformation from the ROMS depth coordinates to
    a user-defined set of target depth levels. It builds an `xgcm.Grid` with a single,
    non-periodic axis named "s_rho" whose cell centers are the `s_rho` coordinate of
    the input dataset `ds`.

    Attributes
    ----------
    grid : xgcm.Grid
        The grid object used for regridding, initialized with the given dataset `ds`.
    """

    def __init__(self, ds):
        """Initializes the `VerticalRegridFromROMS` object by creating an `xgcm.Grid`
        instance.

        Parameters
        ----------
        ds : xarray.Dataset
            The dataset containing the ROMS output data, which must include the vertical coordinate `s_rho`.
        """
        # One non-periodic axis named "s_rho", centered on the `s_rho` coordinate.
        self.grid = xgcm.Grid(ds, coords={"s_rho": {"center": "s_rho"}}, periodic=False)

    def apply(self, da, depth_coords, target_depth_levels, mask_edges=True):
        """Applies vertical regridding from ROMS to the specified target depth levels.

        This method interpolates the input data array `da` along the ROMS vertical
        axis (`s_rho`) onto the depths given by `target_depth_levels`, using
        `depth_coords` as the vertical coordinate to interpolate against.

        Parameters
        ----------
        da : xarray.DataArray
            The data array containing the ROMS output field to be regridded. It must have a vertical
            dimension corresponding to `s_rho`.

        depth_coords : xarray.DataArray
            The depth of each point of `da` along the `s_rho` dimension (for example,
            layer depths). This is not the dimensionless `s_rho` coordinate itself.
            It is passed to `xgcm` as ``target_data``, so it must be expressed in the
            same units and sign convention as `target_depth_levels`.

        target_depth_levels : array-like
            The target depth levels to which the input data `da` will be regridded.

        mask_edges : bool, optional
            If True, target levels outside the range of `depth_coords` are masked
            with NaN. Defaults to True.

        Returns
        -------
        xarray.DataArray
            A new `xarray.DataArray` containing `da` interpolated onto
            `target_depth_levels`, with the `s_rho` dimension replaced by the
            target depth dimension.
        """

        # Suppress FutureWarnings emitted by xgcm during the transform.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning, module="xgcm")
            transformed = self.grid.transform(
                da,
                "s_rho",
                target_depth_levels,
                target_data=depth_coords,
                mask_edges=mask_edges,
            )

        return transformed
@@ -9,7 +9,7 @@ from datetime import datetime
9
9
  import matplotlib.pyplot as plt
10
10
  from pathlib import Path
11
11
  from roms_tools import Grid
12
- from roms_tools.regrid import LateralRegrid, VerticalRegrid
12
+ from roms_tools.regrid import LateralRegridToROMS, VerticalRegridToROMS
13
13
  from roms_tools.utils import save_datasets
14
14
  from roms_tools.vertical_coordinate import compute_depth
15
15
  from roms_tools.plot import _section_plot, _line_plot
@@ -35,7 +35,7 @@ from roms_tools.setup.utils import (
35
35
  )
36
36
 
37
37
 
38
- @dataclass(frozen=True, kw_only=True)
38
+ @dataclass(kw_only=True)
39
39
  class BoundaryForcing:
40
40
  """Represents boundary forcing input data for ROMS.
41
41
 
@@ -124,7 +124,7 @@ class BoundaryForcing:
124
124
 
125
125
  self._input_checks()
126
126
  # Dataset for depth coordinates
127
- object.__setattr__(self, "ds_depth_coords", xr.Dataset())
127
+ self.ds_depth_coords = xr.Dataset()
128
128
 
129
129
  target_coords = get_target_coords(self.grid)
130
130
 
@@ -185,7 +185,7 @@ class BoundaryForcing:
185
185
  lat = target_coords["lat"].isel(
186
186
  **self.bdry_coords["vector"][direction]
187
187
  )
188
- lateral_regrid = LateralRegrid(
188
+ lateral_regrid = LateralRegridToROMS(
189
189
  {"lat": lat, "lon": lon}, bdry_data.dim_names
190
190
  )
191
191
  for var_name in vector_var_names:
@@ -214,7 +214,7 @@ class BoundaryForcing:
214
214
  lat = target_coords["lat"].isel(
215
215
  **self.bdry_coords["rho"][direction]
216
216
  )
217
- lateral_regrid = LateralRegrid(
217
+ lateral_regrid = LateralRegridToROMS(
218
218
  {"lat": lat, "lon": lon}, bdry_data.dim_names
219
219
  )
220
220
  for var_name in tracer_var_names:
@@ -292,7 +292,7 @@ class BoundaryForcing:
292
292
  # vertical regridding
293
293
  for location in ["rho", "u", "v"]:
294
294
  if len(var_names_dict[location]) > 0:
295
- vertical_regrid = VerticalRegrid(
295
+ vertical_regrid = VerticalRegridToROMS(
296
296
  self.ds_depth_coords[f"layer_depth_{location}_{direction}"],
297
297
  bdry_data.ds[bdry_data.dim_names["depth"]],
298
298
  )
@@ -335,7 +335,7 @@ class BoundaryForcing:
335
335
  for var_name in ds.data_vars:
336
336
  ds[var_name] = substitute_nans_by_fillvalue(ds[var_name])
337
337
 
338
- object.__setattr__(self, "ds", ds)
338
+ self.ds = ds
339
339
 
340
340
  def _input_checks(self):
341
341
  # Check that start_time and end_time are both None or none of them is
@@ -361,11 +361,10 @@ class BoundaryForcing:
361
361
  raise ValueError("`source` must include a 'path'.")
362
362
 
363
363
  # Set 'climatology' to False if not provided in 'source'
364
- object.__setattr__(
365
- self,
366
- "source",
367
- {**self.source, "climatology": self.source.get("climatology", False)},
368
- )
364
+ self.source = {
365
+ **self.source,
366
+ "climatology": self.source.get("climatology", False),
367
+ }
369
368
 
370
369
  # Ensure adjust_depth_for_sea_surface_height is only used with type="physics"
371
370
  if self.type == "bgc" and self.adjust_depth_for_sea_surface_height:
@@ -373,7 +372,7 @@ class BoundaryForcing:
373
372
  "adjust_depth_for_sea_surface_height is not applicable for BGC fields. "
374
373
  "Setting it to False."
375
374
  )
376
- object.__setattr__(self, "adjust_depth_for_sea_surface_height", False)
375
+ self.adjust_depth_for_sea_surface_height = False
377
376
  elif self.adjust_depth_for_sea_surface_height:
378
377
  logging.info("Sea surface height will be used to adjust depth coordinates.")
379
378
  else:
@@ -495,7 +494,7 @@ class BoundaryForcing:
495
494
  else:
496
495
  variable_info[var_name] = {**default_info, "validate": False}
497
496
 
498
- object.__setattr__(self, "variable_info", variable_info)
497
+ self.variable_info = variable_info
499
498
 
500
499
  def _write_into_dataset(self, direction, processed_fields, ds=None):
501
500
  if ds is None:
@@ -563,7 +562,7 @@ class BoundaryForcing:
563
562
 
564
563
  bdry_coords = get_boundary_coords()
565
564
 
566
- object.__setattr__(self, "bdry_coords", bdry_coords)
565
+ self.bdry_coords = bdry_coords
567
566
 
568
567
  def _get_depth_coordinates(
569
568
  self,
@@ -861,12 +860,6 @@ class BoundaryForcing:
861
860
 
862
861
  field = self.ds[var_name].isel(bry_time=time)
863
862
 
864
- if self.use_dask:
865
- from dask.diagnostics import ProgressBar
866
-
867
- with ProgressBar():
868
- field = field.load()
869
-
870
863
  title = field.long_name
871
864
  var_name_wo_direction, direction = var_name.split("_")
872
865
  location = self.variable_info[var_name_wo_direction]["location"]
@@ -881,10 +874,17 @@ class BoundaryForcing:
881
874
 
882
875
  mask = mask.isel(**self.bdry_coords[location][direction])
883
876
 
877
+ # Load the data
878
+ if self.use_dask:
879
+ from dask.diagnostics import ProgressBar
880
+
881
+ with ProgressBar():
882
+ field = field.load()
883
+
884
884
  if "s_rho" in field.dims:
885
885
  layer_depth = self.ds_depth_coords[f"layer_depth_{location}_{direction}"]
886
886
  if self.adjust_depth_for_sea_surface_height:
887
- layer_depth = layer_depth.isel(time=time)
887
+ layer_depth = layer_depth.isel(time=time).load()
888
888
  field = field.assign_coords({"layer_depth": layer_depth})
889
889
  if var_name.startswith(("u", "v", "ubar", "vbar", "zeta")):
890
890
  vmax = max(field.max().values, -field.min().values)
@@ -25,7 +25,7 @@ from roms_tools.setup.fill import LateralFill
25
25
  # lat-lon datasets
26
26
 
27
27
 
28
- @dataclass(frozen=True, kw_only=True)
28
+ @dataclass(kw_only=True)
29
29
  class Dataset:
30
30
  """Represents forcing data on original grid.
31
31
 
@@ -132,8 +132,8 @@ class Dataset:
132
132
  self.infer_horizontal_resolution(ds)
133
133
 
134
134
  # Check whether the data covers the entire globe
135
- object.__setattr__(self, "is_global", self.check_if_global(ds))
136
- object.__setattr__(self, "ds", ds)
135
+ self.is_global = self.check_if_global(ds)
136
+ self.ds = ds
137
137
 
138
138
  if self.apply_post_processing:
139
139
  self.post_process()
@@ -354,7 +354,7 @@ class Dataset:
354
354
  resolution = np.mean([lat_resolution, lon_resolution])
355
355
 
356
356
  # Set the computed resolution as an attribute
357
- object.__setattr__(self, "resolution", resolution)
357
+ self.resolution = resolution
358
358
 
359
359
  def compute_minimal_grid_spacing(self, ds: xr.Dataset):
360
360
  """Compute the minimal grid spacing in a dataset based on latitude and longitude
@@ -528,7 +528,7 @@ class Dataset:
528
528
  This method modifies the dataset in place and does not return anything.
529
529
  """
530
530
  ds = self.ds.astype({var: "float64" for var in self.ds.data_vars})
531
- object.__setattr__(self, "ds", ds)
531
+ self.ds = ds
532
532
 
533
533
  def choose_subdomain(
534
534
  self,
@@ -671,7 +671,7 @@ class Dataset:
671
671
  if return_copy:
672
672
  return Dataset.from_ds(self, subdomain)
673
673
  else:
674
- object.__setattr__(self, "ds", subdomain)
674
+ self.ds = subdomain
675
675
 
676
676
  def apply_lateral_fill(self):
677
677
  """Apply lateral fill to variables using the dataset's mask and grid dimensions.
@@ -757,7 +757,7 @@ class Dataset:
757
757
  dataset = cls.__new__(cls)
758
758
 
759
759
  # Directly set the provided dataset as the 'ds' attribute
760
- object.__setattr__(dataset, "ds", ds)
760
+ dataset.ds = ds
761
761
 
762
762
  # Copy all other attributes from the original data instance
763
763
  for attr in vars(original_dataset):
@@ -767,7 +767,7 @@ class Dataset:
767
767
  return dataset
768
768
 
769
769
 
770
- @dataclass(frozen=True, kw_only=True)
770
+ @dataclass(kw_only=True)
771
771
  class TPXODataset(Dataset):
772
772
  """Represents tidal data on the original grid from the TPXO dataset.
773
773
 
@@ -858,15 +858,11 @@ class TPXODataset(Dataset):
858
858
  {"nx": "longitude", "ny": "latitude", self.dim_names["ntides"]: "ntides"}
859
859
  )
860
860
 
861
- object.__setattr__(
862
- self,
863
- "dim_names",
864
- {
865
- "latitude": "latitude",
866
- "longitude": "longitude",
867
- "ntides": "ntides",
868
- },
869
- )
861
+ self.dim_names = {
862
+ "latitude": "latitude",
863
+ "longitude": "longitude",
864
+ "ntides": "ntides",
865
+ }
870
866
 
871
867
  return ds
872
868
 
@@ -909,15 +905,15 @@ class TPXODataset(Dataset):
909
905
  ds["mask"] = mask
910
906
  ds = ds.drop_vars(["depth"])
911
907
 
912
- object.__setattr__(self, "ds", ds)
908
+ self.ds = ds
913
909
 
914
910
  # Remove "depth" from var_names
915
911
  updated_var_names = {**self.var_names} # Create a copy of the dictionary
916
912
  updated_var_names.pop("depth", None) # Remove "depth" if it exists
917
- object.__setattr__(self, "var_names", updated_var_names)
913
+ self.var_names = updated_var_names
918
914
 
919
915
 
920
- @dataclass(frozen=True, kw_only=True)
916
+ @dataclass(kw_only=True)
921
917
  class GLORYSDataset(Dataset):
922
918
  """Represents GLORYS data on original grid.
923
919
 
@@ -996,7 +992,7 @@ class GLORYSDataset(Dataset):
996
992
  self.ds["mask_vel"] = mask_vel
997
993
 
998
994
 
999
- @dataclass(frozen=True, kw_only=True)
995
+ @dataclass(kw_only=True)
1000
996
  class CESMDataset(Dataset):
1001
997
  """Represents CESM data on original grid.
1002
998
 
@@ -1077,12 +1073,12 @@ class CESMDataset(Dataset):
1077
1073
  # Update dimension names
1078
1074
  updated_dim_names = self.dim_names.copy()
1079
1075
  updated_dim_names["time"] = "time"
1080
- object.__setattr__(self, "dim_names", updated_dim_names)
1076
+ self.dim_names = updated_dim_names
1081
1077
 
1082
1078
  return ds
1083
1079
 
1084
1080
 
1085
- @dataclass(frozen=True, kw_only=True)
1081
+ @dataclass(kw_only=True)
1086
1082
  class CESMBGCDataset(CESMDataset):
1087
1083
  """Represents CESM BGC data on original grid.
1088
1084
 
@@ -1184,12 +1180,12 @@ class CESMBGCDataset(CESMDataset):
1184
1180
  if "z_t_150m" in ds.variables:
1185
1181
  ds = ds.drop_vars("z_t_150m")
1186
1182
  # update dataset
1187
- object.__setattr__(self, "ds", ds)
1183
+ self.ds = ds
1188
1184
 
1189
1185
  # Update dim_names with "depth": "depth" key-value pair
1190
1186
  updated_dim_names = self.dim_names.copy()
1191
1187
  updated_dim_names["depth"] = "depth"
1192
- object.__setattr__(self, "dim_names", updated_dim_names)
1188
+ self.dim_names = updated_dim_names
1193
1189
 
1194
1190
  mask = xr.where(
1195
1191
  self.ds[self.var_names["PO4"]]
@@ -1202,7 +1198,7 @@ class CESMBGCDataset(CESMDataset):
1202
1198
  self.ds["mask"] = mask
1203
1199
 
1204
1200
 
1205
- @dataclass(frozen=True, kw_only=True)
1201
+ @dataclass(kw_only=True)
1206
1202
  class CESMBGCSurfaceForcingDataset(CESMDataset):
1207
1203
  """Represents CESM BGC surface forcing data on original grid.
1208
1204
 
@@ -1258,7 +1254,7 @@ class CESMBGCSurfaceForcingDataset(CESMDataset):
1258
1254
 
1259
1255
  if "z_t" in self.ds.variables:
1260
1256
  ds = self.ds.drop_vars("z_t")
1261
- object.__setattr__(self, "ds", ds)
1257
+ self.ds = ds
1262
1258
 
1263
1259
  mask = xr.where(
1264
1260
  self.ds[self.var_names["pco2_air"]]
@@ -1271,7 +1267,7 @@ class CESMBGCSurfaceForcingDataset(CESMDataset):
1271
1267
  self.ds["mask"] = mask
1272
1268
 
1273
1269
 
1274
- @dataclass(frozen=True, kw_only=True)
1270
+ @dataclass(kw_only=True)
1275
1271
  class ERA5Dataset(Dataset):
1276
1272
  """Represents ERA5 data on original grid.
1277
1273
 
@@ -1371,27 +1367,27 @@ class ERA5Dataset(Dataset):
1371
1367
  ds["qair"].attrs["long_name"] = "Absolute humidity at 2m"
1372
1368
  ds["qair"].attrs["units"] = "kg/kg"
1373
1369
  ds = ds.drop_vars([self.var_names["d2m"]])
1374
- object.__setattr__(self, "ds", ds)
1370
+ self.ds = ds
1375
1371
 
1376
1372
  # Update var_names dictionary
1377
1373
  var_names = {**self.var_names, "qair": "qair"}
1378
1374
  var_names.pop("d2m")
1379
- object.__setattr__(self, "var_names", var_names)
1375
+ self.var_names = var_names
1380
1376
 
1381
1377
  if "mask" in self.var_names.keys():
1382
1378
  ds = self.ds
1383
1379
  mask = xr.where(self.ds[self.var_names["mask"]].isel(time=0).isnull(), 0, 1)
1384
1380
  ds["mask"] = mask
1385
1381
  ds = ds.drop_vars([self.var_names["mask"]])
1386
- object.__setattr__(self, "ds", ds)
1382
+ self.ds = ds
1387
1383
 
1388
1384
  # Remove mask from var_names dictionary
1389
1385
  var_names = self.var_names
1390
1386
  var_names.pop("mask")
1391
- object.__setattr__(self, "var_names", var_names)
1387
+ self.var_names = var_names
1392
1388
 
1393
1389
 
1394
- @dataclass(frozen=True, kw_only=True)
1390
+ @dataclass(kw_only=True)
1395
1391
  class ERA5Correction(Dataset):
1396
1392
  """Global dataset to correct ERA5 radiation. The dataset contains multiplicative
1397
1393
  correction factors for the ERA5 shortwave radiation, obtained by comparing the
@@ -1493,10 +1489,10 @@ class ERA5Correction(Dataset):
1493
1489
  raise ValueError(
1494
1490
  "The correction dataset does not contain all specified longitude values."
1495
1491
  )
1496
- object.__setattr__(self, "ds", subdomain)
1492
+ self.ds = subdomain
1497
1493
 
1498
1494
 
1499
- @dataclass(frozen=True, kw_only=True)
1495
+ @dataclass(kw_only=True)
1500
1496
  class ETOPO5Dataset(Dataset):
1501
1497
  """Represents topography data on the original grid from the ETOPO5 dataset.
1502
1498
 
@@ -1553,7 +1549,7 @@ class ETOPO5Dataset(Dataset):
1553
1549
  return ds
1554
1550
 
1555
1551
 
1556
- @dataclass(frozen=True, kw_only=True)
1552
+ @dataclass(kw_only=True)
1557
1553
  class SRTM15Dataset(Dataset):
1558
1554
  """Represents topography data on the original grid from the SRTM15 dataset.
1559
1555
 
@@ -1589,7 +1585,7 @@ class SRTM15Dataset(Dataset):
1589
1585
 
1590
1586
 
1591
1587
  # river datasets
1592
- @dataclass(frozen=True, kw_only=True)
1588
+ @dataclass(kw_only=True)
1593
1589
  class RiverDataset:
1594
1590
  """Represents river data.
1595
1591
 
@@ -1647,7 +1643,7 @@ class RiverDataset:
1647
1643
 
1648
1644
  # Select relevant times
1649
1645
  ds = self.add_time_info(ds)
1650
- object.__setattr__(self, "ds", ds)
1646
+ self.ds = ds
1651
1647
 
1652
1648
  def load_data(self) -> xr.Dataset:
1653
1649
  """Load dataset from the specified file.
@@ -1785,13 +1781,13 @@ class RiverDataset:
1785
1781
 
1786
1782
  ds = assign_dates_to_climatology(self.ds, "month")
1787
1783
  ds = ds.swap_dims({"month": "time"})
1788
- object.__setattr__(self, "ds", ds)
1784
+ self.ds = ds
1789
1785
 
1790
1786
  updated_dim_names = {**self.dim_names}
1791
1787
  updated_dim_names["time"] = "time"
1792
- object.__setattr__(self, "dim_names", updated_dim_names)
1788
+ self.dim_names = updated_dim_names
1793
1789
 
1794
- object.__setattr__(self, "climatology", True)
1790
+ self.climatology = True
1795
1791
 
1796
1792
  def sort_by_river_volume(self, ds: xr.Dataset) -> xr.Dataset:
1797
1793
  """Sorts the dataset by river volume in descending order (largest rivers first),
@@ -1911,7 +1907,7 @@ class RiverDataset:
1911
1907
  ds = xr.Dataset()
1912
1908
  river_indices = {}
1913
1909
 
1914
- object.__setattr__(self, "ds", ds)
1910
+ self.ds = ds
1915
1911
 
1916
1912
  return river_indices
1917
1913
 
@@ -1961,10 +1957,10 @@ class RiverDataset:
1961
1957
  )
1962
1958
 
1963
1959
  # Set the filtered dataset as the new `ds`
1964
- object.__setattr__(self, "ds", ds_filtered)
1960
+ self.ds = ds_filtered
1965
1961
 
1966
1962
 
1967
- @dataclass(frozen=True, kw_only=True)
1963
+ @dataclass(kw_only=True)
1968
1964
  class DaiRiverDataset(RiverDataset):
1969
1965
  """Represents river data from the Dai river dataset.
1970
1966