roms-tools 1.0.1__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
roms_tools/__init__.py CHANGED
@@ -11,6 +11,5 @@ except ImportError: # pragma: no cover
11
11
  from roms_tools.setup.grid import Grid # noqa: F401
12
12
  from roms_tools.setup.tides import TidalForcing # noqa: F401
13
13
  from roms_tools.setup.surface_forcing import SurfaceForcing # noqa: F401
14
- from roms_tools.setup.vertical_coordinate import VerticalCoordinate # noqa: F401
15
14
  from roms_tools.setup.initial_conditions import InitialConditions # noqa: F401
16
15
  from roms_tools.setup.boundary_forcing import BoundaryForcing # noqa: F401
roms_tools/_version.py CHANGED
@@ -1,2 +1,2 @@
1
1
  # Do not change! Do not track in version control!
2
- __version__ = "1.0.1"
2
+ __version__ = "1.2.0"
@@ -7,7 +7,6 @@ import importlib.metadata
7
7
  from typing import Dict, Union, Optional
8
8
  from dataclasses import dataclass, field, asdict
9
9
  from roms_tools.setup.grid import Grid
10
- from roms_tools.setup.vertical_coordinate import VerticalCoordinate
11
10
  from roms_tools.setup.mixins import ROMSToolsMixins
12
11
  from datetime import datetime
13
12
  from roms_tools.setup.datasets import GLORYSDataset, CESMBGCDataset
@@ -29,8 +28,6 @@ class BoundaryForcing(ROMSToolsMixins):
29
28
  ----------
30
29
  grid : Grid
31
30
  Object representing the grid information.
32
- vertical_coordinate: VerticalCoordinate
33
- Object representing the vertical coordinate information.
34
31
  start_time : datetime
35
32
  Start time of the desired boundary forcing data.
36
33
  end_time : datetime
@@ -59,7 +56,6 @@ class BoundaryForcing(ROMSToolsMixins):
59
56
  --------
60
57
  >>> boundary_forcing = BoundaryForcing(
61
58
  ... grid=grid,
62
- ... vertical_coordinate=vertical_coordinate,
63
59
  ... boundaries={"south": True, "east": True, "north": False, "west": True},
64
60
  ... start_time=datetime(2022, 1, 1),
65
61
  ... end_time=datetime(2022, 1, 2),
@@ -73,7 +69,6 @@ class BoundaryForcing(ROMSToolsMixins):
73
69
  """
74
70
 
75
71
  grid: Grid
76
- vertical_coordinate: VerticalCoordinate
77
72
  start_time: datetime
78
73
  end_time: datetime
79
74
  boundaries: Dict[str, bool] = field(
@@ -106,7 +101,7 @@ class BoundaryForcing(ROMSToolsMixins):
106
101
  vars_2d = ["zeta"]
107
102
  vars_3d = ["temp", "salt", "u", "v"]
108
103
  data_vars = super().regrid_data(data, vars_2d, vars_3d, lon, lat)
109
- data_vars = super().process_velocities(data_vars, angle)
104
+ data_vars = super().process_velocities(data_vars, angle, "u", "v")
110
105
  object.__setattr__(data, "data_vars", data_vars)
111
106
 
112
107
  if self.bgc_source is not None:
@@ -126,9 +121,9 @@ class BoundaryForcing(ROMSToolsMixins):
126
121
  bgc_data = None
127
122
 
128
123
  d_meta = super().get_variable_metadata()
129
- bdry_coords, rename = super().get_boundary_info()
124
+ bdry_coords = super().get_boundary_info()
130
125
 
131
- ds = self._write_into_datatree(data, bgc_data, d_meta, bdry_coords, rename)
126
+ ds = self._write_into_datatree(data, bgc_data, d_meta, bdry_coords)
132
127
 
133
128
  for direction in ["south", "east", "north", "west"]:
134
129
  if self.boundaries[direction]:
@@ -207,7 +202,7 @@ class BoundaryForcing(ROMSToolsMixins):
207
202
 
208
203
  return data
209
204
 
210
- def _write_into_dataset(self, data, d_meta, bdry_coords, rename):
205
+ def _write_into_dataset(self, data, d_meta, bdry_coords):
211
206
 
212
207
  # save in new dataset
213
208
  ds = xr.Dataset()
@@ -220,21 +215,18 @@ class BoundaryForcing(ROMSToolsMixins):
220
215
  ds[f"{var}_{direction}"] = (
221
216
  data.data_vars[var]
222
217
  .isel(**bdry_coords["u"][direction])
223
- .rename(**rename["u"][direction])
224
218
  .astype(np.float32)
225
219
  )
226
220
  elif var in ["v", "vbar"]:
227
221
  ds[f"{var}_{direction}"] = (
228
222
  data.data_vars[var]
229
223
  .isel(**bdry_coords["v"][direction])
230
- .rename(**rename["v"][direction])
231
224
  .astype(np.float32)
232
225
  )
233
226
  else:
234
227
  ds[f"{var}_{direction}"] = (
235
228
  data.data_vars[var]
236
229
  .isel(**bdry_coords["rho"][direction])
237
- .rename(**rename["rho"][direction])
238
230
  .astype(np.float32)
239
231
  )
240
232
  ds[f"{var}_{direction}"].attrs[
@@ -293,127 +285,63 @@ class BoundaryForcing(ROMSToolsMixins):
293
285
 
294
286
  return ds
295
287
 
296
- def _write_into_datatree(self, data, bgc_data, d_meta, bdry_coords, rename):
288
+ def _write_into_datatree(self, data, bgc_data, d_meta, bdry_coords):
297
289
 
298
290
  ds = self._add_global_metadata()
299
- ds["sc_r"] = self.vertical_coordinate.ds["sc_r"]
300
- ds["Cs_r"] = self.vertical_coordinate.ds["Cs_r"]
291
+ ds["sc_r"] = self.grid.ds["sc_r"]
292
+ ds["Cs_r"] = self.grid.ds["Cs_r"]
301
293
 
302
294
  ds = DataTree(name="root", data=ds)
303
295
 
304
- ds_physics = self._write_into_dataset(data, d_meta, bdry_coords, rename)
305
- ds_physics = self._add_coordinates(bdry_coords, rename, ds_physics)
296
+ ds_physics = self._write_into_dataset(data, d_meta, bdry_coords)
306
297
  ds_physics = self._add_global_metadata(ds_physics)
307
298
  ds_physics.attrs["physics_source"] = self.physics_source["name"]
308
299
 
309
300
  ds_physics = DataTree(name="physics", parent=ds, data=ds_physics)
310
301
 
311
302
  if bgc_data:
312
- ds_bgc = self._write_into_dataset(bgc_data, d_meta, bdry_coords, rename)
313
- ds_bgc = self._add_coordinates(bdry_coords, rename, ds_bgc)
303
+ ds_bgc = self._write_into_dataset(bgc_data, d_meta, bdry_coords)
314
304
  ds_bgc = self._add_global_metadata(ds_bgc)
315
305
  ds_bgc.attrs["bgc_source"] = self.bgc_source["name"]
316
306
  ds_bgc = DataTree(name="bgc", parent=ds, data=ds_bgc)
317
307
 
318
308
  return ds
319
309
 
320
- def _add_coordinates(self, bdry_coords, rename, ds=None):
321
-
322
- if ds is None:
323
- ds = xr.Dataset()
324
-
325
- for direction in ["south", "east", "north", "west"]:
326
-
327
- if self.boundaries[direction]:
310
+ def _get_coordinates(self, direction, point):
311
+ """
312
+ Retrieve layer and interface depth coordinates for a specified grid boundary.
328
313
 
329
- lat_rho = self.grid.ds.lat_rho.isel(
330
- **bdry_coords["rho"][direction]
331
- ).rename(**rename["rho"][direction])
332
- lon_rho = self.grid.ds.lon_rho.isel(
333
- **bdry_coords["rho"][direction]
334
- ).rename(**rename["rho"][direction])
335
- layer_depth_rho = (
336
- self.vertical_coordinate.ds["layer_depth_rho"]
337
- .isel(**bdry_coords["rho"][direction])
338
- .rename(**rename["rho"][direction])
339
- )
340
- interface_depth_rho = (
341
- self.vertical_coordinate.ds["interface_depth_rho"]
342
- .isel(**bdry_coords["rho"][direction])
343
- .rename(**rename["rho"][direction])
344
- )
314
+ This method extracts the layer depth and interface depth coordinates along
315
+ a specified boundary (north, south, east, or west) and for a specified point
316
+ type (rho, u, or v) from the grid dataset.
345
317
 
346
- lat_u = self.grid.ds.lat_u.isel(**bdry_coords["u"][direction]).rename(
347
- **rename["u"][direction]
348
- )
349
- lon_u = self.grid.ds.lon_u.isel(**bdry_coords["u"][direction]).rename(
350
- **rename["u"][direction]
351
- )
352
- layer_depth_u = (
353
- self.vertical_coordinate.ds["layer_depth_u"]
354
- .isel(**bdry_coords["u"][direction])
355
- .rename(**rename["u"][direction])
356
- )
357
- interface_depth_u = (
358
- self.vertical_coordinate.ds["interface_depth_u"]
359
- .isel(**bdry_coords["u"][direction])
360
- .rename(**rename["u"][direction])
361
- )
318
+ Parameters
319
+ ----------
320
+ direction : str
321
+ The direction of the boundary to retrieve coordinates for. Valid options
322
+ are "north", "south", "east", and "west".
323
+ point : str
324
+ The type of grid point to retrieve coordinates for. Valid options are
325
+ "rho" for the grid's central points, "u" for the u-flux points, and "v"
326
+ for the v-flux points.
362
327
 
363
- lat_v = self.grid.ds.lat_v.isel(**bdry_coords["v"][direction]).rename(
364
- **rename["v"][direction]
365
- )
366
- lon_v = self.grid.ds.lon_v.isel(**bdry_coords["v"][direction]).rename(
367
- **rename["v"][direction]
368
- )
369
- layer_depth_v = (
370
- self.vertical_coordinate.ds["layer_depth_v"]
371
- .isel(**bdry_coords["v"][direction])
372
- .rename(**rename["v"][direction])
373
- )
374
- interface_depth_v = (
375
- self.vertical_coordinate.ds["interface_depth_v"]
376
- .isel(**bdry_coords["v"][direction])
377
- .rename(**rename["v"][direction])
378
- )
328
+ Returns
329
+ -------
330
+ xarray.DataArray, xarray.DataArray
331
+ The layer depth and interface depth coordinates for the specified grid
332
+ boundary and point type.
333
+ """
379
334
 
380
- ds = ds.assign_coords(
381
- {
382
- f"layer_depth_rho_{direction}": layer_depth_rho,
383
- f"layer_depth_u_{direction}": layer_depth_u,
384
- f"layer_depth_v_{direction}": layer_depth_v,
385
- f"interface_depth_rho_{direction}": interface_depth_rho,
386
- f"interface_depth_u_{direction}": interface_depth_u,
387
- f"interface_depth_v_{direction}": interface_depth_v,
388
- f"lat_rho_{direction}": lat_rho,
389
- f"lat_u_{direction}": lat_u,
390
- f"lat_v_{direction}": lat_v,
391
- f"lon_rho_{direction}": lon_rho,
392
- f"lon_u_{direction}": lon_u,
393
- f"lon_v_{direction}": lon_v,
394
- }
395
- )
335
+ bdry_coords = super().get_boundary_info()
396
336
 
397
- # Gracefully handle dropping variables that might not be present
398
- variables_to_drop = [
399
- "s_rho",
400
- "layer_depth_rho",
401
- "layer_depth_u",
402
- "layer_depth_v",
403
- "interface_depth_rho",
404
- "interface_depth_u",
405
- "interface_depth_v",
406
- "lat_rho",
407
- "lon_rho",
408
- "lat_u",
409
- "lon_u",
410
- "lat_v",
411
- "lon_v",
412
- ]
413
- existing_vars = [var for var in variables_to_drop if var in ds]
414
- ds = ds.drop_vars(existing_vars)
337
+ layer_depth = self.grid.ds[f"layer_depth_{point}"].isel(
338
+ **bdry_coords[point][direction]
339
+ )
340
+ interface_depth = self.grid.ds[f"interface_depth_{point}"].isel(
341
+ **bdry_coords[point][direction]
342
+ )
415
343
 
416
- return ds
344
+ return layer_depth, interface_depth
417
345
 
418
346
  def _add_global_metadata(self, ds=None):
419
347
 
@@ -430,10 +358,9 @@ class BoundaryForcing(ROMSToolsMixins):
430
358
  ds.attrs["end_time"] = str(self.end_time)
431
359
  ds.attrs["model_reference_date"] = str(self.model_reference_date)
432
360
 
433
- ds.attrs["theta_s"] = self.vertical_coordinate.ds["theta_s"].item()
434
- ds.attrs["theta_b"] = self.vertical_coordinate.ds["theta_b"].item()
435
- ds.attrs["Tcline"] = self.vertical_coordinate.ds["Tcline"].item()
436
- ds.attrs["hc"] = self.vertical_coordinate.ds["hc"].item()
361
+ ds.attrs["theta_s"] = self.grid.ds.attrs["theta_s"]
362
+ ds.attrs["theta_b"] = self.grid.ds.attrs["theta_b"]
363
+ ds.attrs["hc"] = self.grid.ds.attrs["hc"]
437
364
 
438
365
  return ds
439
366
 
@@ -492,8 +419,9 @@ class BoundaryForcing(ROMSToolsMixins):
492
419
  time : int, optional
493
420
  The time index to plot. Default is 0.
494
421
  layer_contours : bool, optional
495
- Whether to include layer contours in the plot. This can help visualize the depth levels
496
- of the field. Default is False.
422
+ If True, contour lines representing the boundaries between vertical layers will
423
+ be added to the plot. For clarity, the number of layer
424
+ contours displayed is limited to a maximum of 10. Default is False.
497
425
 
498
426
  Returns
499
427
  -------
@@ -519,6 +447,19 @@ class BoundaryForcing(ROMSToolsMixins):
519
447
  field = ds[varname].isel(bry_time=time).load()
520
448
  title = field.long_name
521
449
 
450
+ if "s_rho" in field.dims:
451
+ if varname.startswith(("u_", "ubar_")):
452
+ point = "u"
453
+ elif varname.startswith(("v_", "vbar_")):
454
+ point = "v"
455
+ else:
456
+ point = "rho"
457
+ direction = varname.split("_")[-1]
458
+
459
+ layer_depth, interface_depth = self._get_coordinates(direction, point)
460
+
461
+ field = field.assign_coords({"layer_depth": layer_depth})
462
+
522
463
  # chose colorbar
523
464
  if varname.startswith(("u", "v", "ubar", "vbar", "zeta")):
524
465
  vmax = max(field.max().values, -field.min().values)
@@ -536,27 +477,6 @@ class BoundaryForcing(ROMSToolsMixins):
536
477
 
537
478
  if len(field.dims) == 2:
538
479
  if layer_contours:
539
- depths_to_check = [
540
- "interface_depth_rho",
541
- "interface_depth_u",
542
- "interface_depth_v",
543
- ]
544
- try:
545
- interface_depth = next(
546
- ds[depth_label]
547
- for depth_label in ds.coords
548
- if any(
549
- depth_label.startswith(prefix) for prefix in depths_to_check
550
- )
551
- and (
552
- set(ds[depth_label].dims) - {"s_w"}
553
- == set(field.dims) - {"s_rho"}
554
- )
555
- )
556
- except StopIteration:
557
- raise ValueError(
558
- f"None of the expected depths ({', '.join(depths_to_check)}) have dimensions matching field.dims"
559
- )
560
480
  # restrict number of layer_contours to 10 for the sake of plot clearity
561
481
  nr_layers = len(interface_depth["s_w"])
562
482
  selected_layers = np.linspace(
@@ -664,11 +584,6 @@ class BoundaryForcing(ROMSToolsMixins):
664
584
  grid_data.pop("ds", None) # Exclude non-serializable fields
665
585
  grid_data.pop("straddle", None)
666
586
 
667
- # Serialize VerticalCoordinate data
668
- vertical_coordinate_data = asdict(self.vertical_coordinate)
669
- vertical_coordinate_data.pop("ds", None) # Exclude non-serializable fields
670
- vertical_coordinate_data.pop("grid", None) # Exclude non-serializable fields
671
-
672
587
  # Include the version of roms-tools
673
588
  try:
674
589
  roms_tools_version = importlib.metadata.version("roms-tools")
@@ -679,7 +594,6 @@ class BoundaryForcing(ROMSToolsMixins):
679
594
  header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
680
595
 
681
596
  grid_yaml_data = {"Grid": grid_data}
682
- vertical_coordinate_yaml_data = {"VerticalCoordinate": vertical_coordinate_data}
683
597
 
684
598
  boundary_forcing_data = {
685
599
  "BoundaryForcing": {
@@ -694,7 +608,6 @@ class BoundaryForcing(ROMSToolsMixins):
694
608
 
695
609
  yaml_data = {
696
610
  **grid_yaml_data,
697
- **vertical_coordinate_yaml_data,
698
611
  **boundary_forcing_data,
699
612
  }
700
613
 
@@ -745,13 +658,10 @@ class BoundaryForcing(ROMSToolsMixins):
745
658
  boundary_forcing_data[date_string]
746
659
  )
747
660
 
748
- # Create VerticalCoordinate instance from the YAML file
749
- vertical_coordinate = VerticalCoordinate.from_yaml(filepath)
750
- grid = vertical_coordinate.grid
661
+ grid = Grid.from_yaml(filepath)
751
662
 
752
663
  # Create and return an instance of InitialConditions
753
664
  return cls(
754
665
  grid=grid,
755
- vertical_coordinate=vertical_coordinate,
756
666
  **boundary_forcing_data,
757
667
  )
@@ -584,9 +584,14 @@ class TPXODataset(Dataset):
584
584
  "ntides": self.dim_names["ntides"],
585
585
  },
586
586
  )
587
+ self.check_dataset(ds)
588
+
587
589
  # Select relevant fields
588
590
  ds = super().select_relevant_fields(ds)
589
591
 
592
+ # Make sure that latitude is ascending
593
+ ds = super().ensure_latitude_ascending(ds)
594
+
590
595
  # Check whether the data covers the entire globe
591
596
  object.__setattr__(self, "is_global", super().check_if_global(ds))
592
597
 
@@ -769,6 +774,8 @@ class CESMDataset(Dataset):
769
774
  ds = assign_dates_to_climatology(ds, time_dim)
770
775
  # rename dimension
771
776
  ds = ds.swap_dims({time_dim: "time"})
777
+ if time_dim in ds.variables:
778
+ ds = ds.drop_vars(time_dim)
772
779
  # Update dimension names
773
780
  updated_dim_names = self.dim_names.copy()
774
781
  updated_dim_names["time"] = "time"
@@ -872,9 +879,9 @@ class CESMBGCDataset(CESMDataset):
872
879
  ds["depth"].attrs["long_name"] = "Depth"
873
880
  ds["depth"].attrs["units"] = "m"
874
881
  ds = ds.swap_dims({"z_t": "depth"})
875
- if "z_t" in ds:
882
+ if "z_t" in ds.variables:
876
883
  ds = ds.drop_vars("z_t")
877
- if "z_t_150m" in ds:
884
+ if "z_t_150m" in ds.variables:
878
885
  ds = ds.drop_vars("z_t_150m")
879
886
  # update dataset
880
887
  object.__setattr__(self, "ds", ds)
@@ -932,6 +939,19 @@ class CESMBGCSurfaceForcingDataset(CESMDataset):
932
939
 
933
940
  climatology: Optional[bool] = False
934
941
 
942
+ def post_process(self):
943
+ """
944
+ Perform post-processing on the dataset to remove specific variables.
945
+
946
+ This method checks if the variable "z_t" exists in the dataset. If it does,
947
+ the variable is removed from the dataset. The modified dataset is then
948
+ reassigned to the `ds` attribute of the object.
949
+ """
950
+
951
+ if "z_t" in self.ds.variables:
952
+ ds = self.ds.drop_vars("z_t")
953
+ object.__setattr__(self, "ds", ds)
954
+
935
955
 
936
956
  @dataclass(frozen=True, kw_only=True)
937
957
  class ERA5Dataset(Dataset):