roms-tools 2.5.0__py3-none-any.whl → 2.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. ci/environment-with-xesmf.yml +16 -0
  2. roms_tools/analysis/roms_output.py +521 -187
  3. roms_tools/analysis/utils.py +169 -0
  4. roms_tools/plot.py +351 -214
  5. roms_tools/regrid.py +161 -9
  6. roms_tools/setup/boundary_forcing.py +22 -22
  7. roms_tools/setup/datasets.py +40 -44
  8. roms_tools/setup/grid.py +28 -28
  9. roms_tools/setup/initial_conditions.py +23 -31
  10. roms_tools/setup/nesting.py +3 -3
  11. roms_tools/setup/river_forcing.py +22 -23
  12. roms_tools/setup/surface_forcing.py +14 -13
  13. roms_tools/setup/tides.py +7 -7
  14. roms_tools/setup/topography.py +2 -2
  15. roms_tools/tests/test_analysis/test_roms_output.py +299 -188
  16. roms_tools/tests/test_regrid.py +85 -2
  17. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/.zmetadata +2 -2
  18. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/.zmetadata +2 -2
  19. roms_tools/tests/test_setup/test_river_forcing.py +47 -51
  20. roms_tools/tests/test_vertical_coordinate.py +73 -0
  21. roms_tools/utils.py +11 -7
  22. roms_tools/vertical_coordinate.py +7 -0
  23. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/METADATA +22 -11
  24. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/RECORD +33 -30
  25. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/WHEEL +1 -1
  26. /roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/{river_location → river_flux}/.zarray +0 -0
  27. /roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/{river_location → river_flux}/.zattrs +0 -0
  28. /roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/{river_location → river_flux}/0.0 +0 -0
  29. /roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/{river_location → river_flux}/.zarray +0 -0
  30. /roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/{river_location → river_flux}/.zattrs +0 -0
  31. /roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/{river_location → river_flux}/0.0 +0 -0
  32. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info/licenses}/LICENSE +0 -0
  33. {roms_tools-2.5.0.dist-info → roms_tools-2.6.1.dist-info}/top_level.txt +0 -0
roms_tools/setup/grid.py CHANGED
@@ -24,7 +24,7 @@ from roms_tools.setup.utils import extract_single_value
24
24
  from pathlib import Path
25
25
 
26
26
 
27
- @dataclass(frozen=True, kw_only=True)
27
+ @dataclass(kw_only=True)
28
28
  class Grid:
29
29
  """A single ROMS grid, used for creating, plotting, and then saving a new ROMS
30
30
  domain grid.
@@ -131,7 +131,7 @@ class Grid:
131
131
 
132
132
  def _input_checks(self):
133
133
  if self.topography_source is None:
134
- object.__setattr__(self, "topography_source", {"name": "ETOPO5"})
134
+ self.topography_source = {"name": "ETOPO5"}
135
135
 
136
136
  if "name" not in self.topography_source:
137
137
  raise ValueError(
@@ -157,7 +157,7 @@ class Grid:
157
157
  "========================================================================================================"
158
158
  )
159
159
 
160
- object.__setattr__(self, "ds", ds)
160
+ self.ds = ds
161
161
 
162
162
  def update_topography(
163
163
  self, topography_source=None, hmin=None, verbose=False
@@ -223,9 +223,9 @@ class Grid:
223
223
  )
224
224
 
225
225
  # Update the grid's dataset and related attributes
226
- object.__setattr__(self, "ds", ds)
227
- object.__setattr__(self, "topography_source", topography_source)
228
- object.__setattr__(self, "hmin", hmin)
226
+ self.ds = ds
227
+ self.topography_source = topography_source
228
+ self.hmin = hmin
229
229
 
230
230
  def update_vertical_coordinate(
231
231
  self, N=None, theta_s=None, theta_b=None, hc=None, verbose=False
@@ -321,11 +321,11 @@ class Grid:
321
321
  "========================================================================================================"
322
322
  )
323
323
 
324
- object.__setattr__(self, "ds", ds)
325
- object.__setattr__(self, "theta_s", theta_s)
326
- object.__setattr__(self, "theta_b", theta_b)
327
- object.__setattr__(self, "hc", hc)
328
- object.__setattr__(self, "N", N)
324
+ self.ds = ds
325
+ self.theta_s = theta_s
326
+ self.theta_b = theta_b
327
+ self.hc = hc
328
+ self.N = N
329
329
 
330
330
  def _straddle(self) -> None:
331
331
  """Check if the Greenwich meridian goes through the domain.
@@ -342,9 +342,9 @@ class Grid:
342
342
  np.abs(self.ds.lon_rho.diff("xi_rho")).max() > 300
343
343
  or np.abs(self.ds.lon_rho.diff("eta_rho")).max() > 300
344
344
  ):
345
- object.__setattr__(self, "straddle", True)
345
+ self.straddle = True
346
346
  else:
347
- object.__setattr__(self, "straddle", False)
347
+ self.straddle = False
348
348
 
349
349
  def _coarsen(self):
350
350
  """Update the grid by adding grid variables that are coarsened versions of the
@@ -391,7 +391,7 @@ class Grid:
391
391
  ds[coarse_var].attrs["long_name"] = f"{long_name} on coarsened grid"
392
392
  ds[coarse_var].attrs["units"] = ds[fine_var].attrs["units"]
393
393
 
394
- object.__setattr__(self, "ds", ds)
394
+ self.ds = ds
395
395
 
396
396
  def plot(
397
397
  self, bathymetry: bool = True, title: str = None, with_dim_names: bool = False
@@ -461,7 +461,7 @@ class Grid:
461
461
  xi : int, optional
462
462
  The xi-index to plot. Default is None.
463
463
  ax : matplotlib.axes.Axes, optional
464
- The axes to plot on. If None, a new figure is created. Note that this argument does not work for horizontal plots that display the eta- and xi-dimensions at the same time.
464
+ The axes to plot on. If None, a new figure is created. Note that this argument does not work for 2D horizontal plots.
465
465
 
466
466
  Returns
467
467
  -------
@@ -581,14 +581,14 @@ class Grid:
581
581
  grid = cls.__new__(cls)
582
582
 
583
583
  # Set the dataset for the grid instance
584
- object.__setattr__(grid, "ds", ds)
584
+ grid.ds = ds
585
585
 
586
586
  # Check if the Greenwich meridian goes through the domain.
587
587
  grid._straddle()
588
588
 
589
589
  if not all(coord in grid.ds for coord in ["lat_u", "lon_u", "lat_v", "lon_v"]):
590
590
  ds = _add_lat_lon_at_velocity_points(grid.ds, grid.straddle)
591
- object.__setattr__(grid, "ds", ds)
591
+ grid.ds = ds
592
592
 
593
593
  # Coarsen the grid if necessary
594
594
  if not all(
@@ -606,7 +606,7 @@ class Grid:
606
606
  for var in ["lat_rho", "lon_rho", "lat_coarse", "lon_coarse"]:
607
607
  if var not in ds.coords:
608
608
  ds = grid.ds.set_coords(var)
609
- object.__setattr__(grid, "ds", ds)
609
+ grid.ds = ds
610
610
 
611
611
  # Update vertical coordinate if necessary
612
612
  if not all(var in grid.ds for var in ["Cs_r", "Cs_w"]):
@@ -620,14 +620,14 @@ class Grid:
620
620
  N=N, theta_s=theta_s, theta_b=theta_b, hc=hc, verbose=True
621
621
  )
622
622
  else:
623
- object.__setattr__(grid, "theta_s", ds.attrs["theta_s"].item())
624
- object.__setattr__(grid, "theta_b", ds.attrs["theta_b"].item())
625
- object.__setattr__(grid, "hc", ds.attrs["hc"].item())
626
- object.__setattr__(grid, "N", len(ds.s_rho))
623
+ grid.theta_s = ds.attrs["theta_s"].item()
624
+ grid.theta_b = ds.attrs["theta_b"].item()
625
+ grid.hc = ds.attrs["hc"].item()
626
+ grid.N = len(ds.s_rho)
627
627
 
628
628
  # Manually set the remaining attributes by extracting parameters from dataset
629
- object.__setattr__(grid, "nx", ds.sizes["xi_rho"] - 2)
630
- object.__setattr__(grid, "ny", ds.sizes["eta_rho"] - 2)
629
+ grid.nx = ds.sizes["xi_rho"] - 2
630
+ grid.ny = ds.sizes["eta_rho"] - 2
631
631
  if "center_lon" in ds.attrs:
632
632
  center_lon = ds.attrs["center_lon"]
633
633
  elif "tra_lon" in ds:
@@ -637,7 +637,7 @@ class Grid:
637
637
  "Missing grid information: 'center_lon' attribute or 'tra_lon' variable "
638
638
  "must be present in the dataset."
639
639
  )
640
- object.__setattr__(grid, "center_lon", center_lon)
640
+ grid.center_lon = center_lon
641
641
  if "center_lat" in ds.attrs:
642
642
  center_lat = ds.attrs["center_lat"]
643
643
  elif "tra_lat" in ds:
@@ -647,7 +647,7 @@ class Grid:
647
647
  "Missing grid information: 'center_lat' attribute or 'tra_lat' variable "
648
648
  "must be present in the dataset."
649
649
  )
650
- object.__setattr__(grid, "center_lat", center_lat)
650
+ grid.center_lat = center_lat
651
651
  if "rot" in ds.attrs:
652
652
  rot = ds.attrs["rot"]
653
653
  elif "rotate" in ds:
@@ -657,7 +657,7 @@ class Grid:
657
657
  "Missing grid information: 'rot' attribute or 'rotate' variable "
658
658
  "must be present in the dataset."
659
659
  )
660
- object.__setattr__(grid, "rot", rot)
660
+ grid.rot = rot
661
661
 
662
662
  for attr in [
663
663
  "size_x",
@@ -848,7 +848,7 @@ class Grid:
848
848
  "========================================================================================================"
849
849
  )
850
850
 
851
- object.__setattr__(self, "ds", ds)
851
+ self.ds = ds
852
852
 
853
853
  def _add_global_metadata(self, ds):
854
854
  """Add global metadata and attributes to the dataset.
roms_tools/setup/initial_conditions.py CHANGED
@@ -8,7 +8,7 @@ from pathlib import Path
8
8
  import logging
9
9
  from datetime import datetime
10
10
  from roms_tools import Grid
11
- from roms_tools.regrid import LateralRegrid, VerticalRegrid
11
+ from roms_tools.regrid import LateralRegridToROMS, VerticalRegridToROMS
12
12
  from roms_tools.plot import _plot, _section_plot, _profile_plot, _line_plot
13
13
  from roms_tools.utils import (
14
14
  transpose_dimensions,
@@ -34,7 +34,7 @@ from roms_tools.setup.utils import (
34
34
  )
35
35
 
36
36
 
37
- @dataclass(frozen=True, kw_only=True)
37
+ @dataclass(kw_only=True)
38
38
  class InitialConditions:
39
39
  """Represents initial conditions for ROMS, including physical and biogeochemical
40
40
  data.
@@ -115,7 +115,7 @@ class InitialConditions:
115
115
 
116
116
  self._input_checks()
117
117
  # Dataset for depth coordinates
118
- object.__setattr__(self, "ds_depth_coords", xr.Dataset())
118
+ self.ds_depth_coords = xr.Dataset()
119
119
 
120
120
  processed_fields = {}
121
121
  processed_fields = self._process_data(processed_fields, type="physics")
@@ -135,7 +135,7 @@ class InitialConditions:
135
135
  for var_name in ds.data_vars:
136
136
  ds[var_name] = substitute_nans_by_fillvalue(ds[var_name])
137
137
 
138
- object.__setattr__(self, "ds", ds)
138
+ self.ds = ds
139
139
 
140
140
  def _process_data(self, processed_fields, type="physics"):
141
141
 
@@ -161,7 +161,7 @@ class InitialConditions:
161
161
  var_names = variable_info.keys()
162
162
 
163
163
  # lateral regridding
164
- lateral_regrid = LateralRegrid(target_coords, data.dim_names)
164
+ lateral_regrid = LateralRegridToROMS(target_coords, data.dim_names)
165
165
 
166
166
  for var_name in var_names:
167
167
  if var_name in data.var_names.keys():
@@ -206,7 +206,7 @@ class InitialConditions:
206
206
  # Vertical regridding
207
207
  for location in ["rho", "u", "v"]:
208
208
  if len(var_names_dict[location]) > 0:
209
- vertical_regrid = VerticalRegrid(
209
+ vertical_regrid = VerticalRegridToROMS(
210
210
  self.ds_depth_coords[f"layer_depth_{location}"],
211
211
  data.ds[data.dim_names["depth"]],
212
212
  )
@@ -247,14 +247,10 @@ class InitialConditions:
247
247
  if "path" not in self.source.keys():
248
248
  raise ValueError("`source` must include a 'path'.")
249
249
  # set self.source["climatology"] to False if not provided
250
- object.__setattr__(
251
- self,
252
- "source",
253
- {
254
- **self.source,
255
- "climatology": self.source.get("climatology", False),
256
- },
257
- )
250
+ self.source = {
251
+ **self.source,
252
+ "climatology": self.source.get("climatology", False),
253
+ }
258
254
  if self.bgc_source is not None:
259
255
  if "name" not in self.bgc_source.keys():
260
256
  raise ValueError(
@@ -265,14 +261,10 @@ class InitialConditions:
265
261
  "`bgc_source` must include a 'path' if it is provided."
266
262
  )
267
263
  # set self.bgc_source["climatology"] to False if not provided
268
- object.__setattr__(
269
- self,
270
- "bgc_source",
271
- {
272
- **self.bgc_source,
273
- "climatology": self.bgc_source.get("climatology", False),
274
- },
275
- )
264
+ self.bgc_source = {
265
+ **self.bgc_source,
266
+ "climatology": self.bgc_source.get("climatology", False),
267
+ }
276
268
  if self.adjust_depth_for_sea_surface_height:
277
269
  logging.info("Sea surface height will be used to adjust depth coordinates.")
278
270
  else:
@@ -655,7 +647,7 @@ class InitialConditions:
655
647
  visualize the layering of the water column. For clarity, the number of layer
656
648
  contours displayed is limited to a maximum of 10. Default is False.
657
649
  ax : matplotlib.axes.Axes, optional
658
- The axes to plot on. If None, a new figure is created. Note that this argument does not work for horizontal plots that display the eta- and xi-dimensions at the same time.
650
+ The axes to plot on. If None, a new figure is created. Note that this argument does not work for 2D horizontal plots. Default is None.
659
651
 
660
652
  Returns
661
653
  -------
@@ -687,13 +679,6 @@ class InitialConditions:
687
679
  "Conflicting input: For 2D fields, specify only one dimension, either 'eta' or 'xi', not both."
688
680
  )
689
681
 
690
- # Load the data
691
- if self.use_dask:
692
- from dask.diagnostics import ProgressBar
693
-
694
- with ProgressBar():
695
- self.ds[var_name].load()
696
-
697
682
  field = self.ds[var_name].squeeze()
698
683
 
699
684
  # Get correct mask and horizontal coordinates
@@ -715,11 +700,18 @@ class InitialConditions:
715
700
 
716
701
  field = field.assign_coords({"lon": lon_deg, "lat": lat_deg})
717
702
 
703
+ # Load the data
704
+ if self.use_dask:
705
+ from dask.diagnostics import ProgressBar
706
+
707
+ with ProgressBar():
708
+ self.ds[var_name].load()
709
+
718
710
  # Retrieve depth coordinates
719
711
  if s is not None:
720
712
  layer_contours = False
721
713
  # Note that `layer_depth_{loc}` has already been computed during `__post_init__`.
722
- layer_depth = self.ds_depth_coords[f"layer_depth_{loc}"].squeeze()
714
+ layer_depth = self.ds_depth_coords[f"layer_depth_{loc}"].squeeze().load()
723
715
 
724
716
  # Slice the field as desired
725
717
  def _slice_and_assign(
roms_tools/setup/nesting.py CHANGED
@@ -20,7 +20,7 @@ from roms_tools.setup.utils import (
20
20
  )
21
21
 
22
22
 
23
- @dataclass(frozen=True, kw_only=True)
23
+ @dataclass(kw_only=True)
24
24
  class ChildGrid(Grid):
25
25
  """Represents a ROMS child grid that is compatible with the provided parent grid.
26
26
 
@@ -87,7 +87,7 @@ class ChildGrid(Grid):
87
87
  self.metadata["period"],
88
88
  )
89
89
 
90
- object.__setattr__(self, "ds_nesting", ds_nesting)
90
+ self.ds_nesting = ds_nesting
91
91
 
92
92
  def _modify_child_topography_and_mask(self):
93
93
  """Adjust the child grid's topography and mask to align with the parent grid.
@@ -107,7 +107,7 @@ class ChildGrid(Grid):
107
107
  parent_grid_ds, child_grid_ds
108
108
  )
109
109
 
110
- object.__setattr__(self, "ds", child_grid_ds)
110
+ self.ds = child_grid_ds
111
111
 
112
112
  def update_topography(
113
113
  self, topography_source=None, hmin=None, verbose=False
roms_tools/setup/river_forcing.py CHANGED
@@ -23,7 +23,7 @@ from roms_tools.setup.utils import (
23
23
  )
24
24
 
25
25
 
26
- @dataclass(frozen=True, kw_only=True)
26
+ @dataclass(kw_only=True)
27
27
  class RiverForcing:
28
28
  """Represents river forcing input data for ROMS.
29
29
 
@@ -118,13 +118,13 @@ class RiverForcing:
118
118
  raise ValueError(
119
119
  "No relevant rivers found. Consider increasing domain size or using a different river dataset."
120
120
  )
121
- object.__setattr__(self, "original_indices", original_indices)
121
+ self.original_indices = original_indices
122
122
  updated_indices = self._move_rivers_to_closest_coast(target_coords, data)
123
- object.__setattr__(self, "indices", updated_indices)
123
+ self.indices = updated_indices
124
124
 
125
125
  else:
126
126
  logging.info("Use provided river indices.")
127
- object.__setattr__(self, "original_indices", self.indices)
127
+ self.original_indices = self.indices
128
128
  check_river_locations_are_along_coast(self.grid.ds.mask_rho, self.indices)
129
129
  data.extract_named_rivers(self.indices)
130
130
 
@@ -135,11 +135,11 @@ class RiverForcing:
135
135
  for var_name in ds.data_vars:
136
136
  ds[var_name] = substitute_nans_by_fillvalue(ds[var_name], fill_value=0.0)
137
137
 
138
- object.__setattr__(self, "ds", ds)
138
+ self.ds = ds
139
139
 
140
140
  def _input_checks(self):
141
141
  if self.source is None:
142
- object.__setattr__(self, "source", {"name": "DAI"})
142
+ self.source = {"name": "DAI"}
143
143
 
144
144
  if "name" not in self.source:
145
145
  raise ValueError("`source` must include a 'name'.")
@@ -148,11 +148,10 @@ class RiverForcing:
148
148
  raise ValueError("`source` must include a 'path'.")
149
149
 
150
150
  # Set 'climatology' to False if not provided in 'source'
151
- object.__setattr__(
152
- self,
153
- "source",
154
- {**self.source, "climatology": self.source.get("climatology", False)},
155
- )
151
+ self.source = {
152
+ **self.source,
153
+ "climatology": self.source.get("climatology", False),
154
+ }
156
155
 
157
156
  # Check if 'indices' is provided and has the correct format
158
157
  if self.indices is not None:
@@ -255,23 +254,23 @@ class RiverForcing:
255
254
  - `river_tracer`: A `DataArray` representing tracer data for temperature, salinity and BGC tracers (if specified) for each river over time.
256
255
  """
257
256
  if self.source["climatology"]:
258
- object.__setattr__(self, "climatology", True)
257
+ self.climatology = True
259
258
  else:
260
259
  if self.convert_to_climatology in ["never", "if_any_missing"]:
261
260
  data_ds = data.select_relevant_times(data.ds)
262
261
  if self.convert_to_climatology == "if_any_missing":
263
262
  if data_ds[data.var_names["flux"]].isnull().any():
264
263
  data.compute_climatology()
265
- object.__setattr__(self, "climatology", True)
264
+ self.climatology = True
266
265
  else:
267
- object.__setattr__(data, "ds", data_ds)
268
- object.__setattr__(self, "climatology", False)
266
+ data.ds = data_ds
267
+ self.climatology = False
269
268
  else:
270
- object.__setattr__(data, "ds", data_ds)
271
- object.__setattr__(self, "climatology", False)
269
+ data.ds = data_ds
270
+ self.climatology = False
272
271
  elif self.convert_to_climatology == "always":
273
272
  data.compute_climatology()
274
- object.__setattr__(self, "climatology", True)
273
+ self.climatology = True
275
274
 
276
275
  ds = xr.Dataset()
277
276
 
@@ -442,21 +441,21 @@ class RiverForcing:
442
441
  return river_indices
443
442
 
444
443
  def _write_indices_into_dataset(self, ds):
445
- """Adds river location indices to the dataset as the "river_location" variable.
444
+ """Adds river location indices to the dataset as the "river_flux" variable.
446
445
 
447
- This method creates a new "river_location" variable
446
+ This method creates a new "river_flux" variable
448
447
  using river station indices from `self.indices` and assigns it to the dataset.
449
448
  The indices specify the river station locations in terms of eta_rho and xi_rho grid cell indices.
450
449
 
451
450
  Parameters
452
451
  ----------
453
452
  ds : xarray.Dataset
454
- The dataset to which the "river_location" variable will be added.
453
+ The dataset to which the "river_flux" variable will be added.
455
454
 
456
455
  Returns
457
456
  -------
458
457
  xarray.Dataset
459
- The modified dataset with the "river_location" variable added.
458
+ The modified dataset with the "river_flux" variable added.
460
459
  """
461
460
 
462
461
  river_locations = xr.zeros_like(self.grid.ds.h)
@@ -475,7 +474,7 @@ class RiverForcing:
475
474
 
476
475
  river_locations.attrs["long_name"] = "River ID plus local volume fraction"
477
476
  river_locations.attrs["units"] = "none"
478
- ds["river_location"] = river_locations
477
+ ds["river_flux"] = river_locations
479
478
 
480
479
  ds = ds.drop_vars(["lat_rho", "lon_rho"])
481
480
 
roms_tools/setup/surface_forcing.py CHANGED
@@ -9,7 +9,7 @@ import logging
9
9
  from typing import Dict, Union, List, Optional
10
10
  from roms_tools import Grid
11
11
  from roms_tools.utils import save_datasets
12
- from roms_tools.regrid import LateralRegrid
12
+ from roms_tools.regrid import LateralRegridToROMS
13
13
  from roms_tools.plot import _plot
14
14
  from roms_tools.setup.datasets import (
15
15
  ERA5Dataset,
@@ -30,7 +30,7 @@ from roms_tools.setup.utils import (
30
30
  )
31
31
 
32
32
 
33
- @dataclass(frozen=True, kw_only=True)
33
+ @dataclass(kw_only=True)
34
34
  class SurfaceForcing:
35
35
  """Represents surface forcing input data for ROMS.
36
36
 
@@ -121,10 +121,10 @@ class SurfaceForcing:
121
121
  logging.info("Data will be interpolated onto grid coarsened by factor 2.")
122
122
  else:
123
123
  logging.info("Data will be interpolated onto fine grid.")
124
- object.__setattr__(self, "use_coarse_grid", use_coarse_grid)
124
+ self.use_coarse_grid = use_coarse_grid
125
125
 
126
126
  target_coords = get_target_coords(self.grid, self.use_coarse_grid)
127
- object.__setattr__(self, "target_coords", target_coords)
127
+ self.target_coords = target_coords
128
128
 
129
129
  data.choose_subdomain(
130
130
  target_coords,
@@ -140,7 +140,7 @@ class SurfaceForcing:
140
140
 
141
141
  processed_fields = {}
142
142
  # lateral regridding
143
- lateral_regrid = LateralRegrid(target_coords, data.dim_names)
143
+ lateral_regrid = LateralRegridToROMS(target_coords, data.dim_names)
144
144
  for var_name in var_names:
145
145
  if var_name in data.var_names.keys():
146
146
  processed_fields[var_name] = lateral_regrid.apply(
@@ -171,7 +171,7 @@ class SurfaceForcing:
171
171
  for var_name in ds.data_vars:
172
172
  ds[var_name] = substitute_nans_by_fillvalue(ds[var_name])
173
173
 
174
- object.__setattr__(self, "ds", ds)
174
+ self.ds = ds
175
175
 
176
176
  def _input_checks(self):
177
177
  # Check that start_time and end_time are both None or none of them is
@@ -197,11 +197,10 @@ class SurfaceForcing:
197
197
  raise ValueError("`source` must include a 'path'.")
198
198
 
199
199
  # Set 'climatology' to False if not provided in 'source'
200
- object.__setattr__(
201
- self,
202
- "source",
203
- {**self.source, "climatology": self.source.get("climatology", False)},
204
- )
200
+ self.source = {
201
+ **self.source,
202
+ "climatology": self.source.get("climatology", False),
203
+ }
205
204
 
206
205
  # Validate 'coarse_grid_mode'
207
206
  valid_modes = ["auto", "always", "never"]
@@ -339,7 +338,7 @@ class SurfaceForcing:
339
338
  else:
340
339
  variable_info[var_name] = {**default_info, "validate": False}
341
340
 
342
- object.__setattr__(self, "variable_info", variable_info)
341
+ self.variable_info = variable_info
343
342
 
344
343
  def _apply_correction(self, processed_fields, data):
345
344
 
@@ -382,7 +381,9 @@ class SurfaceForcing:
382
381
  )
383
382
 
384
383
  # Spatial regridding
385
- lateral_regrid = LateralRegrid(self.target_coords, correction_data.dim_names)
384
+ lateral_regrid = LateralRegridToROMS(
385
+ self.target_coords, correction_data.dim_names
386
+ )
386
387
  corr_factor = lateral_regrid.apply(corr_factor)
387
388
 
388
389
  processed_fields["swrad"] = processed_fields["swrad"] * corr_factor
roms_tools/setup/tides.py CHANGED
@@ -8,7 +8,7 @@ from pathlib import Path
8
8
  from dataclasses import dataclass, field
9
9
  from roms_tools import Grid
10
10
  from roms_tools.plot import _plot
11
- from roms_tools.regrid import LateralRegrid
11
+ from roms_tools.regrid import LateralRegridToROMS
12
12
  from roms_tools.utils import save_datasets
13
13
  from roms_tools.setup.datasets import TPXODataset
14
14
  from roms_tools.setup.utils import (
@@ -25,7 +25,7 @@ from roms_tools.setup.utils import (
25
25
  )
26
26
 
27
27
 
28
- @dataclass(frozen=True, kw_only=True)
28
+ @dataclass(kw_only=True)
29
29
  class TidalForcing:
30
30
  """Represents tidal forcing for ROMS.
31
31
 
@@ -88,7 +88,7 @@ class TidalForcing:
88
88
  data.convert_to_float64()
89
89
 
90
90
  # select desired number of constituents
91
- object.__setattr__(data, "ds", data.ds.isel(ntides=slice(None, self.ntides)))
91
+ data.ds = data.ds.isel(ntides=slice(None, self.ntides))
92
92
  self._correct_tides(data)
93
93
 
94
94
  data.apply_lateral_fill()
@@ -98,7 +98,7 @@ class TidalForcing:
98
98
 
99
99
  processed_fields = {}
100
100
  # lateral regridding
101
- lateral_regrid = LateralRegrid(target_coords, data.dim_names)
101
+ lateral_regrid = LateralRegridToROMS(target_coords, data.dim_names)
102
102
  for var_name in var_names:
103
103
  if var_name in data.var_names.keys():
104
104
  processed_fields[var_name] = lateral_regrid.apply(
@@ -144,7 +144,7 @@ class TidalForcing:
144
144
  for var_name in ds.data_vars:
145
145
  ds[var_name] = substitute_nans_by_fillvalue(ds[var_name])
146
146
 
147
- object.__setattr__(self, "ds", ds)
147
+ self.ds = ds
148
148
 
149
149
  def _input_checks(self):
150
150
 
@@ -218,7 +218,7 @@ class TidalForcing:
218
218
  },
219
219
  }
220
220
 
221
- object.__setattr__(self, "variable_info", variable_info)
221
+ self.variable_info = variable_info
222
222
 
223
223
  def _write_into_dataset(self, processed_fields, d_meta):
224
224
 
@@ -504,7 +504,7 @@ class TidalForcing:
504
504
  var_names.pop("sal_Re", None) # Remove "sal_Re" if it exists
505
505
  var_names.pop("sal_Im", None) # Remove "sal_Im" if it exists
506
506
 
507
- object.__setattr__(data, "var_names", var_names)
507
+ data.var_names = var_names
508
508
 
509
509
 
510
510
  def modified_julian_days(year, month, day, hour=0):
roms_tools/setup/topography.py CHANGED
@@ -6,7 +6,7 @@ import gcm_filters
6
6
  from roms_tools.setup.utils import handle_boundaries
7
7
  import warnings
8
8
  from itertools import count
9
- from roms_tools.regrid import LateralRegrid
9
+ from roms_tools.regrid import LateralRegridToROMS
10
10
  from roms_tools.setup.datasets import ETOPO5Dataset, SRTM15Dataset
11
11
 
12
12
 
@@ -150,7 +150,7 @@ def _make_raw_topography(
150
150
 
151
151
  if verbose:
152
152
  start_time = time.time()
153
- lateral_regrid = LateralRegrid(target_coords, data.dim_names)
153
+ lateral_regrid = LateralRegridToROMS(target_coords, data.dim_names)
154
154
  hraw = lateral_regrid.apply(data.ds[data.var_names["topo"]], method=method)
155
155
  if verbose:
156
156
  logging.info(