roms-tools 2.2.1-py3-none-any.whl → 2.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. ci/environment.yml +1 -0
  2. roms_tools/__init__.py +2 -0
  3. roms_tools/analysis/roms_output.py +590 -0
  4. roms_tools/{setup/download.py → download.py} +3 -0
  5. roms_tools/{setup/plot.py → plot.py} +34 -28
  6. roms_tools/setup/boundary_forcing.py +199 -203
  7. roms_tools/setup/datasets.py +60 -136
  8. roms_tools/setup/grid.py +40 -67
  9. roms_tools/setup/initial_conditions.py +249 -247
  10. roms_tools/setup/nesting.py +6 -27
  11. roms_tools/setup/river_forcing.py +41 -76
  12. roms_tools/setup/surface_forcing.py +125 -75
  13. roms_tools/setup/tides.py +31 -51
  14. roms_tools/setup/topography.py +1 -1
  15. roms_tools/setup/utils.py +44 -224
  16. roms_tools/tests/test_analysis/test_roms_output.py +269 -0
  17. roms_tools/tests/{test_setup/test_regrid.py → test_regrid.py} +1 -1
  18. roms_tools/tests/test_setup/test_boundary_forcing.py +221 -58
  19. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/.zattrs +5 -3
  20. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/.zmetadata +156 -121
  21. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/abs_time/.zarray +2 -2
  22. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/abs_time/.zattrs +2 -1
  23. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/abs_time/0 +0 -0
  24. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/bry_time/.zarray +2 -2
  25. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/bry_time/.zattrs +1 -1
  26. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/bry_time/0 +0 -0
  27. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/salt_east/.zarray +4 -4
  28. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/salt_east/0.0.0 +0 -0
  29. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/salt_north/.zarray +4 -4
  30. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/salt_north/0.0.0 +0 -0
  31. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/salt_south/.zarray +4 -4
  32. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/salt_south/0.0.0 +0 -0
  33. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/salt_west/.zarray +4 -4
  34. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/salt_west/0.0.0 +0 -0
  35. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/temp_east/.zarray +4 -4
  36. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/temp_east/0.0.0 +0 -0
  37. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/temp_north/.zarray +4 -4
  38. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/temp_north/0.0.0 +0 -0
  39. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/temp_south/.zarray +4 -4
  40. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/temp_south/0.0.0 +0 -0
  41. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/temp_west/.zarray +4 -4
  42. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/temp_west/0.0.0 +0 -0
  43. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/u_east/.zarray +4 -4
  44. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/u_east/0.0.0 +0 -0
  45. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/u_north/.zarray +4 -4
  46. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/u_north/0.0.0 +0 -0
  47. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/u_south/.zarray +4 -4
  48. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/u_south/0.0.0 +0 -0
  49. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/u_west/.zarray +4 -4
  50. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/u_west/0.0.0 +0 -0
  51. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/ubar_east/.zarray +4 -4
  52. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/ubar_east/0.0 +0 -0
  53. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/ubar_north/.zarray +4 -4
  54. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/ubar_north/0.0 +0 -0
  55. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/ubar_south/.zarray +4 -4
  56. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/ubar_south/0.0 +0 -0
  57. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/ubar_west/.zarray +4 -4
  58. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/ubar_west/0.0 +0 -0
  59. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/v_east/.zarray +4 -4
  60. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/v_east/0.0.0 +0 -0
  61. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/v_north/.zarray +4 -4
  62. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/v_north/0.0.0 +0 -0
  63. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/v_south/.zarray +4 -4
  64. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/v_south/0.0.0 +0 -0
  65. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/v_west/.zarray +4 -4
  66. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/v_west/0.0.0 +0 -0
  67. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/vbar_east/.zarray +4 -4
  68. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/vbar_east/0.0 +0 -0
  69. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/vbar_north/.zarray +4 -4
  70. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/vbar_north/0.0 +0 -0
  71. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/vbar_south/.zarray +4 -4
  72. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/vbar_south/0.0 +0 -0
  73. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/vbar_west/.zarray +4 -4
  74. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/vbar_west/0.0 +0 -0
  75. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/zeta_east/.zarray +4 -4
  76. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/zeta_east/.zattrs +8 -0
  77. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/zeta_east/0.0 +0 -0
  78. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/zeta_north/.zarray +4 -4
  79. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/zeta_north/.zattrs +8 -0
  80. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/zeta_north/0.0 +0 -0
  81. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/zeta_south/.zarray +4 -4
  82. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/zeta_south/.zattrs +8 -0
  83. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/zeta_south/0.0 +0 -0
  84. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/zeta_west/.zarray +4 -4
  85. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/zeta_west/.zattrs +8 -0
  86. roms_tools/tests/test_setup/test_data/boundary_forcing.zarr/zeta_west/0.0 +0 -0
  87. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/.zattrs +4 -4
  88. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/.zmetadata +4 -4
  89. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/angle/0.0 +0 -0
  90. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/angle_coarse/0.0 +0 -0
  91. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/f/0.0 +0 -0
  92. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/h/0.0 +0 -0
  93. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/lat_coarse/0.0 +0 -0
  94. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/lat_rho/0.0 +0 -0
  95. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/lat_u/0.0 +0 -0
  96. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/lat_v/0.0 +0 -0
  97. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/lon_coarse/0.0 +0 -0
  98. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/lon_rho/0.0 +0 -0
  99. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/lon_u/0.0 +0 -0
  100. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/lon_v/0.0 +0 -0
  101. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/mask_coarse/0.0 +0 -0
  102. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/mask_rho/0.0 +0 -0
  103. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/mask_u/0.0 +0 -0
  104. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/mask_v/0.0 +0 -0
  105. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/pm/0.0 +0 -0
  106. roms_tools/tests/test_setup/test_data/grid_that_straddles_dateline.zarr/pn/0.0 +0 -0
  107. roms_tools/tests/test_setup/test_data/initial_conditions_with_bgc_from_climatology.zarr/.zattrs +2 -1
  108. roms_tools/tests/test_setup/test_data/initial_conditions_with_bgc_from_climatology.zarr/.zmetadata +6 -4
  109. roms_tools/tests/test_setup/test_data/initial_conditions_with_bgc_from_climatology.zarr/Cs_r/.zattrs +1 -1
  110. roms_tools/tests/test_setup/test_data/initial_conditions_with_bgc_from_climatology.zarr/Cs_w/.zattrs +1 -1
  111. roms_tools/tests/test_setup/test_data/initial_conditions_with_bgc_from_climatology.zarr/NH4/0.0.0.0 +0 -0
  112. roms_tools/tests/test_setup/test_data/initial_conditions_with_bgc_from_climatology.zarr/NO3/0.0.0.0 +0 -0
  113. roms_tools/tests/test_setup/test_data/initial_conditions_with_bgc_from_climatology.zarr/PO4/0.0.0.0 +0 -0
  114. roms_tools/tests/test_setup/test_data/initial_conditions_with_bgc_from_climatology.zarr/abs_time/.zattrs +1 -0
  115. roms_tools/tests/test_setup/test_data/initial_conditions_with_bgc_from_climatology.zarr/diatSi/0.0.0.0 +0 -0
  116. roms_tools/tests/test_setup/test_data/initial_conditions_with_bgc_from_climatology.zarr/ocean_time/.zattrs +1 -1
  117. roms_tools/tests/test_setup/test_data/initial_conditions_with_bgc_from_climatology.zarr/spC/0.0.0.0 +0 -0
  118. roms_tools/tests/test_setup/test_data/initial_conditions_with_bgc_from_climatology.zarr/spCaCO3/0.0.0.0 +0 -0
  119. roms_tools/tests/test_setup/test_data/initial_conditions_with_bgc_from_climatology.zarr/spFe/0.0.0.0 +0 -0
  120. roms_tools/tests/test_setup/test_data/initial_conditions_with_bgc_from_climatology.zarr/temp/0.0.0.0 +0 -0
  121. roms_tools/tests/test_setup/test_data/initial_conditions_with_bgc_from_climatology.zarr/u/0.0.0.0 +0 -0
  122. roms_tools/tests/test_setup/test_data/initial_conditions_with_bgc_from_climatology.zarr/ubar/0.0.0 +0 -0
  123. roms_tools/tests/test_setup/test_data/initial_conditions_with_bgc_from_climatology.zarr/v/0.0.0.0 +0 -0
  124. roms_tools/tests/test_setup/test_data/initial_conditions_with_bgc_from_climatology.zarr/vbar/0.0.0 +0 -0
  125. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/.zmetadata +30 -0
  126. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/river_location/.zarray +22 -0
  127. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/river_location/.zattrs +8 -0
  128. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/river_location/0.0 +0 -0
  129. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/.zmetadata +30 -0
  130. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/river_location/.zarray +22 -0
  131. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/river_location/.zattrs +8 -0
  132. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/river_location/0.0 +0 -0
  133. roms_tools/tests/test_setup/test_datasets.py +1 -1
  134. roms_tools/tests/test_setup/test_grid.py +1 -14
  135. roms_tools/tests/test_setup/test_initial_conditions.py +205 -67
  136. roms_tools/tests/test_setup/test_nesting.py +0 -16
  137. roms_tools/tests/test_setup/test_river_forcing.py +9 -37
  138. roms_tools/tests/test_setup/test_surface_forcing.py +103 -74
  139. roms_tools/tests/test_setup/test_tides.py +5 -17
  140. roms_tools/tests/test_setup/test_topography.py +1 -1
  141. roms_tools/tests/test_setup/test_utils.py +57 -1
  142. roms_tools/tests/{test_utils.py → test_tiling/test_partition.py} +1 -1
  143. roms_tools/tiling/partition.py +338 -0
  144. roms_tools/utils.py +310 -276
  145. roms_tools/vertical_coordinate.py +227 -0
  146. {roms_tools-2.2.1.dist-info → roms_tools-2.4.0.dist-info}/METADATA +1 -1
  147. {roms_tools-2.2.1.dist-info → roms_tools-2.4.0.dist-info}/RECORD +151 -142
  148. roms_tools/setup/vertical_coordinate.py +0 -109
  149. /roms_tools/{setup/regrid.py → regrid.py} +0 -0
  150. {roms_tools-2.2.1.dist-info → roms_tools-2.4.0.dist-info}/LICENSE +0 -0
  151. {roms_tools-2.2.1.dist-info → roms_tools-2.4.0.dist-info}/WHEEL +0 -0
  152. {roms_tools-2.2.1.dist-info → roms_tools-2.4.0.dist-info}/top_level.txt +0 -0
@@ -3,28 +3,35 @@ import numpy as np
  import importlib.metadata
  from dataclasses import dataclass, field
  from typing import Dict, Union, List, Optional
- from roms_tools.setup.grid import Grid
+ import matplotlib.pyplot as plt
+ from pathlib import Path
+ import logging
  from datetime import datetime
+ from roms_tools import Grid
+ from roms_tools.regrid import LateralRegrid, VerticalRegrid
+ from roms_tools.plot import _plot, _section_plot, _profile_plot, _line_plot
+ from roms_tools.utils import (
+ transpose_dimensions,
+ save_datasets,
+ get_dask_chunks,
+ interpolate_from_rho_to_u,
+ interpolate_from_rho_to_v,
+ )
+ from roms_tools.vertical_coordinate import (
+ compute_depth_coordinates,
+ compute_depth,
+ )
  from roms_tools.setup.datasets import GLORYSDataset, CESMBGCDataset
- from roms_tools.setup.vertical_coordinate import compute_depth
  from roms_tools.setup.utils import (
  nan_check,
  substitute_nans_by_fillvalue,
  get_variable_metadata,
- save_datasets,
  get_target_coords,
  rotate_velocities,
  compute_barotropic_velocity,
- transpose_dimensions,
- interpolate_from_rho_to_u,
- interpolate_from_rho_to_v,
  _to_yaml,
  _from_yaml,
  )
- from roms_tools.setup.regrid import LateralRegrid, VerticalRegrid
- from roms_tools.setup.plot import _plot, _section_plot, _profile_plot, _line_plot
- import matplotlib.pyplot as plt
- from pathlib import Path


  @dataclass(frozen=True, kw_only=True)
@@ -62,10 +69,17 @@ class InitialConditions:
  - A list of strings or Path objects containing multiple files.
  - "climatology" (bool): Indicates if the data is climatology data. Defaults to False.

+ adjust_depth_for_sea_surface_height : bool, optional
+ Whether to account for sea surface height variations when computing depth coordinates.
+ Defaults to `False`.
  model_reference_date : datetime, optional
  The reference date for the model. Defaults to January 1, 2000.
  use_dask: bool, optional
  Indicates whether to use dask for processing. If True, data is processed with dask; if False, data is processed eagerly. Defaults to False.
+ horizontal_chunk_size : int, optional
+ The chunk size used for horizontal partitioning for the vertical regridding when `use_dask = True`. Defaults to 50.
+ A larger number results in a bigger memory footprint but faster computations.
+ A smaller number results in a smaller memory footprint but slower computations.
  bypass_validation: bool, optional
  Indicates whether to skip validation checks in the processed data. When set to True,
  the validation process that ensures no NaN values exist at wet points
@@ -90,7 +104,9 @@ class InitialConditions:
  source: Dict[str, Union[str, Path, List[Union[str, Path]]]]
  bgc_source: Optional[Dict[str, Union[str, Path, List[Union[str, Path]]]]] = None
  model_reference_date: datetime = datetime(2000, 1, 1)
+ adjust_depth_for_sea_surface_height: bool = False
  use_dask: bool = False
+ horizontal_chunk_size: int = 50
  bypass_validation: bool = False

  ds: xr.Dataset = field(init=False, repr=False)
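The two dataclass fields added above are the user-facing knobs introduced in 2.4.0. A minimal, hypothetical usage sketch follows; it is not taken from the package, it assumes InitialConditions is exported at the package top level like Grid is in the import hunk, and the grid and source arguments are illustrative placeholders.

    # Hedged sketch, not part of this diff: exercising the new 2.4.0 options.
    from datetime import datetime
    from roms_tools import Grid, InitialConditions

    grid = Grid(
        nx=100, ny=100, size_x=1800, size_y=2400,
        center_lon=-21, center_lat=61, rot=20,  # illustrative grid parameters
    )

    ic = InitialConditions(
        grid=grid,
        ini_time=datetime(2012, 1, 1),
        source={"name": "GLORYS", "path": "glorys_subset.nc"},  # illustrative source
        adjust_depth_for_sea_surface_height=True,  # new: use zeta when computing depths
        use_dask=True,
        horizontal_chunk_size=50,  # new: horizontal chunking for the vertical regrid
    )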
@@ -98,6 +114,8 @@ class InitialConditions:
  def __post_init__(self):

  self._input_checks()
+ # Dataset for depth coordinates
+ object.__setattr__(self, "ds_depth_coords", xr.Dataset())

  processed_fields = {}
  processed_fields = self._process_data(processed_fields, type="physics")
@@ -132,7 +150,6 @@ class InitialConditions:
  target_coords,
  buffer_points=20, # lateral fill needs good buffer from data margin
  )
-
  data.extrapolate_deepest_to_bottom()
  data.apply_lateral_fill()

@@ -143,6 +160,7 @@ class InitialConditions:

  # lateral regridding
  lateral_regrid = LateralRegrid(target_coords, data.dim_names)
+
  for var_name in var_names:
  if var_name in data.var_names.keys():
  processed_fields[var_name] = lateral_regrid.apply(
@@ -151,60 +169,61 @@ class InitialConditions:

  # rotation of velocities and interpolation to u/v points
  if "u" in variable_info and "v" in variable_info:
- (processed_fields["u"], processed_fields["v"],) = rotate_velocities(
+ processed_fields["u"], processed_fields["v"] = rotate_velocities(
  processed_fields["u"],
  processed_fields["v"],
  target_coords["angle"],
  interpolate=True,
  )

- var_names_dict = {}
- for location in ["rho", "u", "v"]:
- var_names_dict[location] = [
+ var_names_dict = {
+ location: [
  name
  for name, info in variable_info.items()
  if info["location"] == location and info["is_3d"]
  ]
+ for location in ["rho", "u", "v"]
+ }

- # compute layer depth coordinates
- if len(var_names_dict["u"]) > 0 or len(var_names_dict["v"]) > 0:
- self._get_vertical_coordinates(
- type="layer",
- additional_locations=["u", "v"],
- )
- else:
- if len(var_names_dict["rho"]) > 0:
- self._get_vertical_coordinates(type="layer", additional_locations=[])
- # vertical regridding
+ if type == "bgc":
+ # Ensure time coordinate matches that of physical variables
+ for var_name in variable_info.keys():
+ processed_fields[var_name] = processed_fields[var_name].assign_coords(
+ {"time": processed_fields["temp"]["time"]}
+ )
+
+ # Get depth coordinates
+ zeta = (
+ processed_fields["zeta"] if self.adjust_depth_for_sea_surface_height else 0
+ )
+
+ for location in ["rho", "u", "v"]:
+ if len(var_names_dict[location]) > 0:
+ self._get_depth_coordinates(zeta, location, "layer")
+
+ # Vertical regridding
  for location in ["rho", "u", "v"]:
  if len(var_names_dict[location]) > 0:
  vertical_regrid = VerticalRegrid(
- self.grid.ds[f"layer_depth_{location}"],
+ self.ds_depth_coords[f"layer_depth_{location}"],
  data.ds[data.dim_names["depth"]],
  )
  for var_name in var_names_dict[location]:
  if var_name in processed_fields:
- processed_fields[var_name] = vertical_regrid.apply(
- processed_fields[var_name]
- )
-
- # compute barotropic velocities
+ field = processed_fields[var_name]
+ if self.use_dask:
+ field = field.chunk(
+ get_dask_chunks(location, self.horizontal_chunk_size)
+ )
+ processed_fields[var_name] = vertical_regrid.apply(field)
+
+ # Compute barotropic velocities
  if "u" in variable_info and "v" in variable_info:
- self._get_vertical_coordinates(
- type="interface",
- additional_locations=["u", "v"],
- )
  for location in ["u", "v"]:
+ self._get_depth_coordinates(zeta, location, "interface")
  processed_fields[f"{location}bar"] = compute_barotropic_velocity(
  processed_fields[location],
- self.grid.ds[f"interface_depth_{location}"],
- )
-
- if type == "bgc":
- # Ensure time coordinate matches that of physical variables
- for var_name in variable_info.keys():
- processed_fields[var_name] = processed_fields[var_name].assign_coords(
- {"time": processed_fields["temp"]["time"]}
+ self.ds_depth_coords[f"interface_depth_{location}"],
  )

  for var_name in processed_fields.keys():
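The get_dask_chunks helper imported from roms_tools.utils in the first hunk is used above to chunk each field before vertical regridding, but its implementation is not part of this diff. Purely as an illustration of the kind of mapping that xarray's .chunk() accepts in those calls, a hypothetical stand-in might look like this (dimension names taken from the grid locations used elsewhere in the file):

    # Hypothetical stand-in, NOT the roms_tools.utils implementation.
    def get_dask_chunks_sketch(location: str, chunk_size: int) -> dict:
        # Map each grid location to its horizontal dimension names.
        horizontal_dims = {
            "rho": ("eta_rho", "xi_rho"),
            "u": ("eta_rho", "xi_u"),
            "v": ("eta_v", "xi_rho"),
        }[location]
        return {dim: chunk_size for dim in horizontal_dims}

    # e.g. get_dask_chunks_sketch("u", 50) -> {"eta_rho": 50, "xi_u": 50}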
@@ -247,6 +266,12 @@ class InitialConditions:
  "climatology": self.bgc_source.get("climatology", False),
  },
  )
+ if self.adjust_depth_for_sea_surface_height:
+ logging.info("Sea surface height will be used to adjust depth coordinates.")
+ else:
+ logging.info(
+ "Sea surface height will NOT be used to adjust depth coordinates."
+ )

  def _get_data(self):

@@ -368,82 +393,61 @@ class InitialConditions:

  object.__setattr__(self, f"variable_info_{type}", variable_info)

- def _get_vertical_coordinates(self, type, additional_locations=["u", "v"]):
- """Retrieve layer and interface depth coordinates.
-
- This method computes and updates the layer and interface depth coordinates. It handles depth calculations for rho points and
- additional specified locations (u and v).
+ def _get_depth_coordinates(
+ self, zeta: xr.DataArray | float, location: str, depth_type: str = "layer"
+ ) -> None:
+ """Ensure depth coordinates are computed and stored for a given location and
+ depth type.

  Parameters
  ----------
- type : str
- The type of depth coordinate to retrieve. Valid options are:
- - "layer": Retrieves layer depth coordinates.
- - "interface": Retrieves interface depth coordinates.
-
- additional_locations : list of str, optional
- Specifies additional locations to compute depth coordinates for. Default is ["u", "v"].
- Valid options include:
- - "u": Computes depth coordinates for u points.
- - "v": Computes depth coordinates for v points.
-
- Updates
- -------
- self.grid.ds : xarray.Dataset
- The dataset is updated with the following vertical depth coordinates:
- - f"{type}_depth_rho": Depth coordinates at rho points.
- - f"{type}_depth_u": Depth coordinates at u points (if applicable).
- - f"{type}_depth_v": Depth coordinates at v points (if applicable).
- """
+ zeta : xr.DataArray or float
+ Free-surface elevation (can be a scalar or a DataArray).
+ location : str
+ Grid location for depth computation ("rho", "u", or "v").
+ depth_type : str, optional
+ Type of depth coordinates to compute, by default "layer".

- layer_vars = []
- for location in ["rho"] + additional_locations:
- layer_vars.append(f"{type}_depth_{location}")
-
- if all(layer_var in self.grid.ds for layer_var in layer_vars):
- # Vertical coordinate data already exists
- pass
-
- elif f"{type}_depth_rho" in self.grid.ds:
- depth = self.grid.ds[f"{type}_depth_rho"]
-
- if "u" in additional_locations or "v" in additional_locations:
- # interpolation
- if "u" in additional_locations:
- depth_u = interpolate_from_rho_to_u(depth)
- depth_u.attrs["long_name"] = f"{type} depth at u-points"
- depth_u.attrs["units"] = "m"
- self.grid.ds[f"{type}_depth_u"] = depth_u
- if "v" in additional_locations:
- depth_v = interpolate_from_rho_to_v(depth)
- depth_v.attrs["long_name"] = f"{type} depth at v-points"
- depth_v.attrs["units"] = "m"
- self.grid.ds[f"{type}_depth_v"] = depth_v
- else:
- h = self.grid.ds["h"]
- if type == "layer":
- depth = compute_depth(
- 0, h, self.grid.hc, self.grid.ds.Cs_r, self.grid.ds.sigma_r
- )
+ Notes
+ ------
+ Rather than calling compute_depth_coordinates from the vertical_coordinate.py module,
+ this method computes the depth coordinates from scratch because of optional chunking.
+ """
+ key = f"{depth_type}_depth_{location}"
+
+ if key not in self.ds_depth_coords:
+ # Select the appropriate depth computation parameters
+ if depth_type == "layer":
+ Cs = self.grid.ds["Cs_r"]
+ sigma = self.grid.ds["sigma_r"]
+ elif depth_type == "interface":
+ Cs = self.grid.ds["Cs_w"]
+ sigma = self.grid.ds["sigma_w"]
  else:
- depth = compute_depth(
- 0, h, self.grid.hc, self.grid.ds.Cs_w, self.grid.ds.sigma_w
+ raise ValueError(
+ f"Invalid depth_type: {depth_type}. Choose 'layer' or 'interface'."
  )

- depth.attrs["long_name"] = f"{type} depth at rho-points"
- depth.attrs["units"] = "m"
- self.grid.ds[f"{type}_depth_rho"] = depth
-
- if "u" in additional_locations or "v" in additional_locations:
- # interpolation
- depth_u = interpolate_from_rho_to_u(depth)
- depth_u.attrs["long_name"] = f"{type} depth at u-points"
- depth_u.attrs["units"] = "m"
- depth_v = interpolate_from_rho_to_v(depth)
- depth_v.attrs["long_name"] = f"{type} depth at v-points"
- depth_v.attrs["units"] = "m"
- self.grid.ds[f"{type}_depth_u"] = depth_u
- self.grid.ds[f"{type}_depth_v"] = depth_v
+ h = self.grid.ds["h"]
+
+ # Interpolate h and zeta to the specified location
+ if location == "u":
+ h = interpolate_from_rho_to_u(h)
+ if isinstance(zeta, xr.DataArray):
+ zeta = interpolate_from_rho_to_u(zeta)
+ elif location == "v":
+ h = interpolate_from_rho_to_v(h)
+ if isinstance(zeta, xr.DataArray):
+ zeta = interpolate_from_rho_to_v(zeta)
+
+ if self.use_dask:
+ h = h.chunk(get_dask_chunks(location, self.horizontal_chunk_size))
+ if self.adjust_depth_for_sea_surface_height:
+ zeta = zeta.chunk(
+ get_dask_chunks(location, self.horizontal_chunk_size)
+ )
+ depth = compute_depth(zeta, h, self.grid.ds.attrs["hc"], Cs, sigma)
+ self.ds_depth_coords[key] = depth

  def _write_into_dataset(self, processed_fields, d_meta):

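compute_depth(zeta, h, hc, Cs, sigma) now comes from roms_tools.vertical_coordinate (see the import hunk) and is called above with the free surface, the bathymetry, the critical depth hc, and the stretching arrays. Its body is not shown in this diff; for context, here is a hedged sketch of the standard ROMS stretched-coordinate transform that a function with this signature typically evaluates.

    # Hedged sketch only; the actual roms_tools implementation (including its
    # sign convention for returning z versus positive depth) is not part of this diff.
    def compute_depth_sketch(zeta, h, hc, Cs, sigma):
        # Standard ROMS transform: z = zeta + (zeta + h) * S,
        # with stretching S = (hc * sigma + h * Cs) / (hc + h).
        S = (hc * sigma + h * Cs) / (hc + h)
        return zeta + (zeta + h) * S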
@@ -457,7 +461,7 @@ class InitialConditions:

  # initialize vertical velocity to zero
  ds["w"] = xr.zeros_like(
- self.grid.ds["interface_depth_rho"].expand_dims(
+ (self.grid.ds["Cs_w"] * self.grid.ds["h"]).expand_dims(
  time=processed_fields["u"].time
  )
  ).astype(np.float32)
@@ -554,6 +558,9 @@ class InitialConditions:
  ds.attrs["roms_tools_version"] = roms_tools_version
  ds.attrs["ini_time"] = str(self.ini_time)
  ds.attrs["model_reference_date"] = str(self.model_reference_date)
+ ds.attrs["adjust_depth_for_sea_surface_height"] = str(
+ self.adjust_depth_for_sea_surface_height
+ )
  ds.attrs["source"] = self.source["name"]
  if self.bgc_source is not None:
  ds.attrs["bgc_source"] = self.bgc_source["name"]
@@ -656,18 +663,24 @@ class InitialConditions:
  If the field specified by `var_name` is 2D and both `eta` and `xi` are specified.
  """

- if len(self.ds[var_name].squeeze().dims) == 3 and not any(
- [s is not None, eta is not None, xi is not None]
- ):
+ field = self.ds[var_name].squeeze()
+
+ if len(field.dims) == 3:
+ if not any([s is not None, eta is not None, xi is not None]):
+ raise ValueError(
+ "Invalid input: For 3D fields, you must specify at least one of the dimensions 's', 'eta', or 'xi'."
+ )
+ if all([s is not None, eta is not None, xi is not None]):
+ raise ValueError(
+ "Ambiguous input: For 3D fields, specify at most two of 's', 'eta', or 'xi'. Specifying all three is not allowed."
+ )
+
+ if len(field.dims) == 2 and all([eta is not None, xi is not None]):
  raise ValueError(
- "For 3D fields, at least one of s, eta, or xi must be specified."
+ "Conflicting input: For 2D fields, specify only one dimension, either 'eta' or 'xi', not both."
  )

- if len(self.ds[var_name].squeeze().dims) == 2 and all(
- [eta is not None, xi is not None]
- ):
- raise ValueError("For 2D fields, specify either eta or xi, not both.")
-
+ # Load the data
  if self.use_dask:
  from dask.diagnostics import ProgressBar

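The reworked checks above define which slicing arguments the plotting method accepts. A hedged usage sketch, assuming this hunk belongs to the public plot method of InitialConditions, that ic is the instance from the earlier sketch, and that the keyword names match those in the surrounding hunks:

    # Hedged usage sketch based on the validation rules above.
    ic.plot(var_name="temp", s=-1)          # 3D field: map of the surface layer
    ic.plot(var_name="temp", eta=50)        # 3D field: vertical section at fixed eta
    ic.plot(var_name="temp", s=-1, eta=50)  # 3D field: at most two of s/eta/xi
    ic.plot(var_name="zeta", xi=100)        # 2D field: give eta or xi, but not both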
@@ -675,54 +688,55 @@ class InitialConditions:
  self.ds[var_name].load()

  field = self.ds[var_name].squeeze()
- if s is not None:
- layer_contours = False

+ # Get correct mask and horizontal coordinates
  if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
- if layer_contours:
- if "interface_depth_rho" in self.grid.ds:
- interface_depth = self.grid.ds.interface_depth_rho
- else:
- self.get_vertical_coordinates(
- type="interface", additional_locations=[]
- )
- layer_depth = self.grid.ds.layer_depth_rho
- mask = self.grid.ds.mask_rho
- field = field.assign_coords(
- {"lon": self.grid.ds.lon_rho, "lat": self.grid.ds.lat_rho}
- )
-
+ loc = "rho"
  elif all(dim in field.dims for dim in ["eta_rho", "xi_u"]):
- if layer_contours:
- if "interface_depth_u" in self.grid.ds:
- interface_depth = self.grid.ds.interface_depth_u
- else:
- self.get_vertical_coordinates(
- type="interface", additional_locations=["u", "v"]
- )
- layer_depth = self.grid.ds.layer_depth_u
- mask = self.grid.ds.mask_u
- field = field.assign_coords(
- {"lon": self.grid.ds.lon_u, "lat": self.grid.ds.lat_u}
- )
-
+ loc = "u"
  elif all(dim in field.dims for dim in ["eta_v", "xi_rho"]):
- if layer_contours:
- if "interface_depth_v" in self.grid.ds:
- interface_depth = self.grid.ds.interface_depth_v
- else:
- self.get_vertical_coordinates(
- type="interface", additional_locations=["u", "v"]
- )
- layer_depth = self.grid.ds.layer_depth_v
- mask = self.grid.ds.mask_v
- field = field.assign_coords(
- {"lon": self.grid.ds.lon_v, "lat": self.grid.ds.lat_v}
- )
+ loc = "v"
  else:
  ValueError("provided field does not have two horizontal dimension")

- # slice the field as desired
+ mask = self.grid.ds[f"mask_{loc}"]
+ lat_deg = self.grid.ds[f"lat_{loc}"]
+ lon_deg = self.grid.ds[f"lon_{loc}"]
+
+ if self.grid.straddle:
+ lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
+
+ field = field.assign_coords({"lon": lon_deg, "lat": lat_deg})
+
+ # Retrieve depth coordinates
+ if s is not None:
+ layer_contours = False
+ # Note that `layer_depth_{loc}` has already been computed during `__post_init__`.
+ layer_depth = self.ds_depth_coords[f"layer_depth_{loc}"].squeeze()
+
+ # Slice the field as desired
+ def _slice_and_assign(
+ field,
+ mask,
+ layer_depth,
+ title,
+ dim_name,
+ dim_values,
+ idx,
+ ):
+ if dim_name in field.dims:
+ title = title + f", {dim_name} = {dim_values[idx].item()}"
+ field = field.isel(**{dim_name: idx})
+ mask = mask.isel(**{dim_name: idx})
+ layer_depth = layer_depth.isel(**{dim_name: idx})
+ if "s_rho" in field.dims:
+ field = field.assign_coords({"layer_depth": layer_depth})
+ else:
+ raise ValueError(
+ f"None of the expected dimensions ({dim_name}) found in field."
+ )
+ return field, mask, layer_depth, title
+
  title = field.long_name
  if s is not None:
  title = title + f", s_rho = {field.s_rho[s].item()}"
@@ -733,49 +747,28 @@ class InitialConditions:
  depth_contours = False

  if eta is not None:
- if "eta_rho" in field.dims:
- title = title + f", eta_rho = {field.eta_rho[eta].item()}"
- field = field.isel(eta_rho=eta)
- layer_depth = layer_depth.isel(eta_rho=eta)
- if layer_contours:
- interface_depth = interface_depth.isel(eta_rho=eta)
- if "s_rho" in field.dims:
- field = field.assign_coords({"layer_depth": layer_depth})
- elif "eta_v" in field.dims:
- title = title + f", eta_v = {field.eta_v[eta].item()}"
- field = field.isel(eta_v=eta)
- layer_depth = layer_depth.isel(eta_v=eta)
- if layer_contours:
- interface_depth = interface_depth.isel(eta_v=eta)
- if "s_rho" in field.dims:
- field = field.assign_coords({"layer_depth": layer_depth})
- else:
- raise ValueError(
- f"None of the expected dimensions (eta_rho, eta_v) found in ds[{var_name}]."
- )
+ field, mask, layer_depth, title = _slice_and_assign(
+ field,
+ mask,
+ layer_depth,
+ title,
+ "eta_rho" if "eta_rho" in field.dims else "eta_v",
+ field.eta_rho if "eta_rho" in field.dims else field.eta_v,
+ eta,
+ )
+
  if xi is not None:
- if "xi_rho" in field.dims:
- title = title + f", xi_rho = {field.xi_rho[xi].item()}"
- field = field.isel(xi_rho=xi)
- layer_depth = layer_depth.isel(xi_rho=xi)
- if layer_contours:
- interface_depth = interface_depth.isel(xi_rho=xi)
- if "s_rho" in field.dims:
- field = field.assign_coords({"layer_depth": layer_depth})
- elif "xi_u" in field.dims:
- title = title + f", xi_u = {field.xi_u[xi].item()}"
- field = field.isel(xi_u=xi)
- layer_depth = layer_depth.isel(xi_u=xi)
- if layer_contours:
- interface_depth = interface_depth.isel(xi_u=xi)
- if "s_rho" in field.dims:
- field = field.assign_coords({"layer_depth": layer_depth})
- else:
- raise ValueError(
- f"None of the expected dimensions (xi_rho, xi_u) found in ds[{var_name}]."
- )
+ field, mask, layer_depth, title = _slice_and_assign(
+ field,
+ mask,
+ layer_depth,
+ title,
+ "xi_rho" if "xi_rho" in field.dims else "xi_u",
+ field.xi_rho if "xi_rho" in field.dims else field.xi_u,
+ xi,
+ )

- # chose colorbar
+ # Choose colorbar
  if var_name in ["u", "v", "w", "ubar", "vbar", "zeta"]:
  vmax = max(field.max().values, -field.min().values)
  vmin = -vmax
@@ -792,26 +785,59 @@ class InitialConditions:

  if eta is None and xi is None:
  _plot(
- self.grid.ds,
  field=field.where(mask),
- straddle=self.grid.straddle,
  depth_contours=depth_contours,
  title=title,
  kwargs=kwargs,
  c="g",
  )
  else:
- if not layer_contours:
- interface_depth = None
- else:
- # restrict number of layer_contours to 10 for the sake of plot clearity
- nr_layers = len(interface_depth["s_w"])
- selected_layers = np.linspace(
- 0, nr_layers - 1, min(nr_layers, 10), dtype=int
- )
- interface_depth = interface_depth.isel(s_w=selected_layers)
-
  if len(field.dims) == 2:
+ if layer_contours:
+ if loc == "rho":
+ # interface_depth_rho has not been computed yet
+ interface_depth = compute_depth_coordinates(
+ self.grid.ds,
+ self.ds.zeta,
+ depth_type="interface",
+ location=loc,
+ eta=eta,
+ xi=xi,
+ )
+ elif loc == "u":
+ index_kwargs = {}
+ if eta is not None:
+ index_kwargs["eta_rho"] = eta
+ if xi is not None:
+ index_kwargs["xi_u"] = xi
+
+ interface_depth = (
+ self.ds_depth_coords[f"interface_depth_{loc}"]
+ .isel(**index_kwargs)
+ .squeeze()
+ )
+ elif loc == "v":
+ index_kwargs = {}
+ if eta is not None:
+ index_kwargs["eta_v"] = eta
+ if xi is not None:
+ index_kwargs["xi_rho"] = xi
+
+ interface_depth = (
+ self.ds_depth_coords[f"interface_depth_{loc}"]
+ .isel(**index_kwargs)
+ .squeeze()
+ )
+
+ # restrict number of layer_contours to 10 for the sake of plot clearity
+ nr_layers = len(interface_depth["s_w"])
+ selected_layers = np.linspace(
+ 0, nr_layers - 1, min(nr_layers, 10), dtype=int
+ )
+ interface_depth = interface_depth.isel(s_w=selected_layers)
+ else:
+ interface_depth = None
+
  _section_plot(
  field,
  interface_depth=interface_depth,
@@ -821,40 +847,22 @@ class InitialConditions:
  )
  else:
  if "s_rho" in field.dims:
- _profile_plot(field, title=title, ax=ax)
+ _profile_plot(field.where(mask), title=title, ax=ax)
  else:
- _line_plot(field, title=title, ax=ax)
-
- def save(
- self, filepath: Union[str, Path], np_eta: int = None, np_xi: int = None
- ) -> None:
- """Save the initial conditions information to a netCDF4 file.
-
- This method supports saving the dataset in two modes:
+ _line_plot(field.where(mask), title=title, ax=ax)

- 1. **Single File Mode (default)**:
-
- If both `np_eta` and `np_xi` are `None`, the entire dataset is saved as a single netCDF4 file
- with the base filename specified by `filepath.nc`.
-
- 2. **Partitioned Mode**:
-
- - If either `np_eta` or `np_xi` is specified, the dataset is divided into spatial tiles along the eta-axis and xi-axis.
- - Each spatial tile is saved as a separate netCDF4 file.
+ def save(self, filepath: Union[str, Path]) -> None:
+ """Save the initial conditions information to one netCDF4 file.

  Parameters
  ----------
  filepath : Union[str, Path]
  The base path or filename where the dataset should be saved.
- np_eta : int, optional
- The number of partitions along the `eta` direction. If `None`, no spatial partitioning is performed.
- np_xi : int, optional
- The number of partitions along the `xi` direction. If `None`, no spatial partitioning is performed.

  Returns
  -------
- List[Path]
- A list of Path objects for the filenames that were saved.
+ Path
+ A `Path` object representing the location of the saved file.
  """

  # Ensure filepath is a Path object
@@ -864,17 +872,11 @@ class InitialConditions:
  if filepath.suffix == ".nc":
  filepath = filepath.with_suffix("")

- if self.use_dask:
- from dask.diagnostics import ProgressBar
-
- with ProgressBar():
- self.ds.load()
-
  dataset_list = [self.ds]
  output_filenames = [str(filepath)]

  saved_filenames = save_datasets(
- dataset_list, output_filenames, np_eta=np_eta, np_xi=np_xi
+ dataset_list, output_filenames, use_dask=self.use_dask
  )

  return saved_filenames
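With the 2.4.0 signature above, save always writes a single netCDF file; the old np_eta/np_xi tiling arguments are gone, and partitioning now lives in the new roms_tools/tiling/partition.py listed in the file table (its public API is not shown in this diff). Continuing the earlier hypothetical sketch:

    # Hedged sketch; `ic` is the InitialConditions instance from the earlier example.
    saved = ic.save("initial_conditions")  # writes initial_conditions.nc, no tiling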