roms-tools 0.20-py3-none-any.whl → 1.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
@@ -0,0 +1,118 @@
+ import pooch
+ import xarray as xr
+
+ # Create a Pooch object to manage the global topography data
+ topo_data = pooch.create(
+     # Use the default cache folder for the operating system
+     path=pooch.os_cache("roms-tools"),
+     base_url="https://github.com/CWorthy-ocean/roms-tools-data/raw/main/",
+     # The registry specifies the files that can be fetched
+     registry={
+         "etopo5.nc": "sha256:23600e422d59bbf7c3666090166a0d468c8ee16092f4f14e32c4e928fbcd627b",
+     },
+ )
+
+ # Create a Pooch object to manage the global SWR correction data
+ correction_data = pooch.create(
+     # Use the default cache folder for the operating system
+     path=pooch.os_cache("roms-tools"),
+     base_url="https://github.com/CWorthy-ocean/roms-tools-data/raw/main/",
+     # The registry specifies the files that can be fetched
+     registry={
+         "etopo5.nc": "sha256:23600e422d59bbf7c3666090166a0d468c8ee16092f4f14e32c4e928fbcd627b",
+         "SSR_correction.nc": "sha256:a170c1698e6cc2765b3f0bb51a18c6a979bc796ac3a4c014585aeede1f1f8ea0",
+     },
+ )
+
+ # Create a Pooch object to manage the test data
+ pup_test_data = pooch.create(
+     # Use the default cache folder for the operating system
+     path=pooch.os_cache("roms-tools"),
+     base_url="https://github.com/CWorthy-ocean/roms-tools-test-data/raw/main/",
+     # The registry specifies the files that can be fetched
+     registry={
+         "GLORYS_test_data.nc": "648f88ec29c433bcf65f257c1fb9497bd3d5d3880640186336b10ed54f7129d2",
+         "ERA5_regional_test_data.nc": "bd12ce3b562fbea2a80a3b79ba74c724294043c28dc98ae092ad816d74eac794",
+         "ERA5_global_test_data.nc": "8ed177ab64c02caf509b9fb121cf6713f286cc603b1f302f15f3f4eb0c21dc4f",
+         "TPXO_global_test_data.nc": "457bfe87a7b247ec6e04e3c7d3e741ccf223020c41593f8ae33a14f2b5255e60",
+         "TPXO_regional_test_data.nc": "11739245e2286d9c9d342dce5221e6435d2072b50028bef2e86a30287b3b4032",
+         "CESM_regional_test_data_one_time_slice.nc": "43b578ecc067c85f95d6b97ed7b9dc8da7846f07c95331c6ba7f4a3161036a17",
+         "CESM_regional_test_data_climatology.nc": "986a200029d9478fd43e6e4a8bc43e8a8f4407554893c59b5fcc2e86fd203272",
+         "CESM_surface_global_test_data_climatology.nc": "a072757110c6f7b716a98f867688ef4195a5966741d2f368201ac24617254e35",
+         "CESM_surface_global_test_data.nc": "874106ffbc8b1b220db09df1551bbb89d22439d795b4d1e5a24ee775e9a7bf6e",
+     },
+ )
+
+
+ def fetch_topo(topography_source: str) -> xr.Dataset:
+     """
+     Load the global topography data as an xarray Dataset.
+
+     Parameters
+     ----------
+     topography_source : str
+         The source of the topography data to be loaded. Available options:
+         - "ETOPO5"
+
+     Returns
+     -------
+     xr.Dataset
+         The global topography data as an xarray Dataset.
+     """
+     # Mapping from user-specified topography options to corresponding filenames in the registry
+     topo_dict = {"ETOPO5": "etopo5.nc"}
+
+     # Fetch the file using Pooch, downloading if necessary
+     fname = topo_data.fetch(topo_dict[topography_source])
+
+     # Load the dataset using xarray and return it
+     ds = xr.open_dataset(fname)
+     return ds
+
+
+ def download_correction_data(filename: str) -> str:
+     """
+     Download the correction data file.
+
+     Parameters
+     ----------
+     filename : str
+         The name of the correction data file to be downloaded. Available options:
+         - "SSR_correction.nc"
+
+     Returns
+     -------
+     str
+         The path to the downloaded correction data file.
+     """
+     # Fetch the file using Pooch, downloading if necessary
+     fname = correction_data.fetch(filename)
+
+     return fname
+
+
+ def download_test_data(filename: str) -> str:
+     """
+     Download the test data file.
+
+     Parameters
+     ----------
+     filename : str
+         The name of the test data file to be downloaded. Available options:
+         - "GLORYS_test_data.nc"
+         - "ERA5_regional_test_data.nc"
+         - "ERA5_global_test_data.nc"
+         - "TPXO_global_test_data.nc"
+         - "TPXO_regional_test_data.nc"
+         - "CESM_regional_test_data_one_time_slice.nc"
+         - "CESM_regional_test_data_climatology.nc"
+         - "CESM_surface_global_test_data_climatology.nc"
+         - "CESM_surface_global_test_data.nc"
+
+     Returns
+     -------
+     str
+         The path to the downloaded test data file.
+     """
+     # Fetch the file using Pooch, downloading if necessary
+     fname = pup_test_data.fetch(filename)
+
+     return fname
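
A minimal usage sketch of the new download module (the import path is an assumption; the diff does not show where this file lives inside the package):

    # Hypothetical import path; not shown in this diff.
    from roms_tools.setup.download import fetch_topo, download_test_data

    # Download on first use, verify the SHA-256 checksum, cache the file,
    # and open it as an xarray Dataset.
    topo = fetch_topo("ETOPO5")

    # Fetch a registered test file and return its local path in the cache.
    path = download_test_data("GLORYS_test_data.nc")
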
@@ -3,103 +3,80 @@ import numpy as np
  import yaml
  import importlib.metadata
  from dataclasses import dataclass, field, asdict
+ from typing import Optional, Dict, Union
  from roms_tools.setup.grid import Grid
  from roms_tools.setup.vertical_coordinate import VerticalCoordinate
  from datetime import datetime
- from roms_tools.setup.datasets import Dataset
- from roms_tools.setup.fill import fill_and_interpolate
+ from roms_tools.setup.datasets import GLORYSDataset, CESMBGCDataset
  from roms_tools.setup.utils import (
      nan_check,
-     interpolate_from_rho_to_u,
-     interpolate_from_rho_to_v,
-     extrapolate_deepest_to_bottom,
  )
+ from roms_tools.setup.mixins import ROMSToolsMixins
  from roms_tools.setup.plot import _plot, _section_plot, _profile_plot, _line_plot
  import matplotlib.pyplot as plt


  @dataclass(frozen=True, kw_only=True)
- class InitialConditions:
+ class InitialConditions(ROMSToolsMixins):
      """
-     Represents initial conditions for ROMS.
+     Represents initial conditions for ROMS, including physical and biogeochemical data.

      Parameters
      ----------
      grid : Grid
-         Object representing the grid information.
-     vertical_coordinate: VerticalCoordinate
-         Object representing the vertical coordinate information
+         Object representing the grid information used for the model.
+     vertical_coordinate : VerticalCoordinate
+         Object representing the vertical coordinate system.
      ini_time : datetime
-         Desired initialization time.
+         The date and time at which the initial conditions are set.
+     physics_source : Dict[str, Union[str, None]]
+         Dictionary specifying the source of the physical initial condition data:
+         - "name" (str): Name of the data source (e.g., "GLORYS").
+         - "path" (str): Path to the physical data file. Can contain wildcards.
+         - "climatology" (bool): Indicates if the physical data is climatology data. Defaults to False.
+     bgc_source : Optional[Dict[str, Union[str, None]]]
+         Dictionary specifying the source of the biogeochemical (BGC) initial condition data:
+         - "name" (str): Name of the BGC data source (e.g., "CESM_REGRIDDED").
+         - "path" (str): Path to the BGC data file. Can contain wildcards.
+         - "climatology" (bool): Indicates if the BGC data is climatology data. Defaults to True.
      model_reference_date : datetime, optional
-         Reference date for the model. Default is January 1, 2000.
-     source : str, optional
-         Source of the initial condition data. Default is "GLORYS".
-     filename: str
-         Path to the source data file. Can contain wildcards.
+         The reference date for the model. Defaults to January 1, 2000.

      Attributes
      ----------
      ds : xr.Dataset
-         Xarray Dataset containing the initial condition data.
-
+         Xarray Dataset containing the initial condition data loaded from the specified files.
+
+     Examples
+     --------
+     >>> initial_conditions = InitialConditions(
+     ...     grid=grid,
+     ...     vertical_coordinate=vertical_coordinate,
+     ...     ini_time=datetime(2022, 1, 1),
+     ...     physics_source={"name": "GLORYS", "path": "physics_data.nc"},
+     ...     bgc_source={
+     ...         "name": "CESM_REGRIDDED",
+     ...         "path": "bgc_data.nc",
+     ...         "climatology": True,
+     ...     },
+     ... )
      """

      grid: Grid
      vertical_coordinate: VerticalCoordinate
      ini_time: datetime
+     physics_source: Dict[str, Union[str, None]]
+     bgc_source: Optional[Dict[str, Union[str, None]]] = None
      model_reference_date: datetime = datetime(2000, 1, 1)
-     source: str = "GLORYS"
-     filename: str
+
      ds: xr.Dataset = field(init=False, repr=False)

      def __post_init__(self):

-         # Check that the source is "GLORYS"
-         if self.source != "GLORYS":
-             raise ValueError('Only "GLORYS" is a valid option for source.')
-         if self.source == "GLORYS":
-             dims = {
-                 "longitude": "longitude",
-                 "latitude": "latitude",
-                 "depth": "depth",
-                 "time": "time",
-             }
-
-             varnames = {
-                 "temp": "thetao",
-                 "salt": "so",
-                 "u": "uo",
-                 "v": "vo",
-                 "ssh": "zos",
-             }
+         self._input_checks()
+         lon, lat, angle, straddle = super().get_target_lon_lat()

-         data = Dataset(
-             filename=self.filename,
-             start_time=self.ini_time,
-             var_names=varnames.values(),
-             dim_names=dims,
-         )
-
-         lon = self.grid.ds.lon_rho
-         lat = self.grid.ds.lat_rho
-         angle = self.grid.ds.angle
-
-         # operate on longitudes between -180 and 180 unless ROMS domain lies at least 5 degrees in lontitude away from Greenwich meridian
-         lon = xr.where(lon > 180, lon - 360, lon)
-         straddle = True
-         if not self.grid.straddle and abs(lon).min() > 5:
-             lon = xr.where(lon < 0, lon + 360, lon)
-             straddle = False
-
-         # The following consists of two steps:
-         # Step 1: Choose subdomain of forcing data including safety margin for interpolation, and Step 2: Convert to the proper longitude range.
-         # We perform these two steps for two reasons:
-         # A) Since the horizontal dimensions consist of a single chunk, selecting a subdomain before interpolation is a lot more performant.
-         # B) Step 1 is necessary to avoid discontinuous longitudes that could be introduced by Step 2. Specifically, discontinuous longitudes
-         # can lead to artifacts in the interpolation process. Specifically, if there is a data gap if data is not global,
-         # discontinuous longitudes could result in values that appear to come from a distant location instead of producing NaNs.
-         # These NaNs are important as they can be identified and handled appropriately by the nan_check function.
+         data = self._get_data()
          data.choose_subdomain(
              latitude_range=[lat.min().values, lat.max().values],
              longitude_range=[lon.min().values, lon.max().values],
@@ -107,115 +84,134 @@ class InitialConditions:
              straddle=straddle,
          )

-         # interpolate onto desired grid
-         fill_dims = [dims["latitude"], dims["longitude"]]
+         vars_2d = ["zeta"]
+         vars_3d = ["temp", "salt", "u", "v"]
+         data_vars = super().regrid_data(data, vars_2d, vars_3d, lon, lat)
+         data_vars = super().process_velocities(data_vars, angle)
+
+         if self.bgc_source is not None:
+             bgc_data = self._get_bgc_data()
+             bgc_data.choose_subdomain(
+                 latitude_range=[lat.min().values, lat.max().values],
+                 longitude_range=[lon.min().values, lon.max().values],
+                 margin=2,
+                 straddle=straddle,
+             )

-         # 2d interpolation
-         mask = xr.where(data.ds[varnames["ssh"]].isel(time=0).isnull(), 0, 1)
-         coords = {dims["latitude"]: lat, dims["longitude"]: lon}
+             vars_2d = []
+             vars_3d = bgc_data.var_names.keys()
+             bgc_data_vars = super().regrid_data(bgc_data, vars_2d, vars_3d, lon, lat)

-         ssh = fill_and_interpolate(
-             data.ds[varnames["ssh"]].astype(np.float64),
-             mask,
-             fill_dims=fill_dims,
-             coords=coords,
-             method="linear",
-         )
+             # Ensure time coordinate matches if climatology is applied in one case but not the other
+             if (
+                 not self.physics_source["climatology"]
+                 and self.bgc_source["climatology"]
+             ):
+                 for var in bgc_data_vars.keys():
+                     bgc_data_vars[var] = bgc_data_vars[var].assign_coords(
+                         {"time": data_vars["temp"]["time"]}
+                     )

-         # 3d interpolation
+             # Combine data variables from physical and biogeochemical sources
+             data_vars.update(bgc_data_vars)

-         # extrapolate deepest value all the way to bottom ("flooding")
-         for var in ["temp", "salt", "u", "v"]:
-             data.ds[varnames[var]] = extrapolate_deepest_to_bottom(
-                 data.ds[varnames[var]], dims["depth"]
-             )
+         d_meta = super().get_variable_metadata()
+         ds = self._write_into_dataset(data_vars, d_meta)

-         mask = xr.where(data.ds[varnames["temp"]].isel(time=0).isnull(), 0, 1)
-         coords = {
-             dims["latitude"]: lat,
-             dims["longitude"]: lon,
-             dims["depth"]: self.vertical_coordinate.ds["layer_depth_rho"],
-         }
+         ds = self._add_global_metadata(ds)

-         # setting fillvalue_interp to None means that we allow extrapolation in the
-         # interpolation step to avoid NaNs at the surface if the lowest depth in original
-         # data is greater than zero
-         data_vars = {}
-         for var in ["temp", "salt", "u", "v"]:
-             data_vars[var] = fill_and_interpolate(
-                 data.ds[varnames[var]].astype(np.float64),
-                 mask,
-                 fill_dims=fill_dims,
-                 coords=coords,
-                 method="linear",
-                 fillvalue_interp=None,
-             )
+         ds["zeta"].load()
+         nan_check(ds["zeta"].squeeze(), self.grid.ds.mask_rho)

-         # rotate to grid orientation
-         u_rot = data_vars["u"] * np.cos(angle) + data_vars["v"] * np.sin(angle)
-         v_rot = data_vars["v"] * np.cos(angle) - data_vars["u"] * np.sin(angle)
+         object.__setattr__(self, "ds", ds)

-         # interpolate to u- and v-points
-         u = interpolate_from_rho_to_u(u_rot)
-         v = interpolate_from_rho_to_v(v_rot)
+     def _input_checks(self):

-         # 3d masks for ROMS domain
-         umask = self.grid.ds.mask_u.expand_dims({"s_rho": u.s_rho})
-         vmask = self.grid.ds.mask_v.expand_dims({"s_rho": v.s_rho})
+         if "name" not in self.physics_source.keys():
+             raise ValueError("`physics_source` must include a 'name'.")
+         if "path" not in self.physics_source.keys():
+             raise ValueError("`physics_source` must include a 'path'.")
+         # set self.physics_source["climatology"] to False if not provided
+         object.__setattr__(
+             self,
+             "physics_source",
+             {
+                 **self.physics_source,
+                 "climatology": self.physics_source.get("climatology", False),
+             },
+         )
+         if self.bgc_source is not None:
+             if "name" not in self.bgc_source.keys():
+                 raise ValueError(
+                     "`bgc_source` must include a 'name' if it is provided."
+                 )
+             if "path" not in self.bgc_source.keys():
+                 raise ValueError(
+                     "`bgc_source` must include a 'path' if it is provided."
+                 )
+             # set self.bgc_source["climatology"] to True if not provided
+             object.__setattr__(
+                 self,
+                 "bgc_source",
+                 {
+                     **self.bgc_source,
+                     "climatology": self.bgc_source.get("climatology", True),
+                 },
+             )

-         u = u * umask
-         v = v * vmask
+     def _get_data(self):

-         # Compute barotropic velocity
-         # thicknesses
-         dz = -self.vertical_coordinate.ds["interface_depth_rho"].diff(dim="s_w")
-         dz = dz.rename({"s_w": "s_rho"})
-         # thicknesses at u- and v-points
-         dzu = interpolate_from_rho_to_u(dz)
-         dzv = interpolate_from_rho_to_v(dz)
+         if self.physics_source["name"] == "GLORYS":
+             data = GLORYSDataset(
+                 filename=self.physics_source["path"],
+                 start_time=self.ini_time,
+                 climatology=self.physics_source["climatology"],
+             )
+         else:
+             raise ValueError(
+                 'Only "GLORYS" is a valid option for physics_source["name"].'
+             )
+         return data

-         ubar = (dzu * u).sum(dim="s_rho") / dzu.sum(dim="s_rho")
-         vbar = (dzv * v).sum(dim="s_rho") / dzv.sum(dim="s_rho")
+     def _get_bgc_data(self):

-         # save in new dataset
-         ds = xr.Dataset()
+         if self.bgc_source["name"] == "CESM_REGRIDDED":

-         ds["temp"] = data_vars["temp"].astype(np.float32)
-         ds["temp"].attrs["long_name"] = "Potential temperature"
-         ds["temp"].attrs["units"] = "Celsius"
+             bgc_data = CESMBGCDataset(
+                 filename=self.bgc_source["path"],
+                 start_time=self.ini_time,
+                 climatology=self.bgc_source["climatology"],
+             )
+             bgc_data.post_process()
+         else:
+             raise ValueError(
+                 'Only "CESM_REGRIDDED" is a valid option for bgc_source["name"].'
+             )

-         ds["salt"] = data_vars["salt"].astype(np.float32)
-         ds["salt"].attrs["long_name"] = "Salinity"
-         ds["salt"].attrs["units"] = "PSU"
+         return bgc_data

-         ds["zeta"] = ssh.astype(np.float32)
-         ds["zeta"].attrs["long_name"] = "Free surface"
-         ds["zeta"].attrs["units"] = "m"
+     def _write_into_dataset(self, data_vars, d_meta):

-         ds["u"] = u.astype(np.float32)
-         ds["u"].attrs["long_name"] = "u-flux component"
-         ds["u"].attrs["units"] = "m/s"
+         # save in new dataset
+         ds = xr.Dataset()

-         ds["v"] = v.astype(np.float32)
-         ds["v"].attrs["long_name"] = "v-flux component"
-         ds["v"].attrs["units"] = "m/s"
+         for var in data_vars.keys():
+             ds[var] = data_vars[var].astype(np.float32)
+             ds[var].attrs["long_name"] = d_meta[var]["long_name"]
+             ds[var].attrs["units"] = d_meta[var]["units"]

          # initialize vertical velocity to zero
          ds["w"] = xr.zeros_like(
              self.vertical_coordinate.ds["interface_depth_rho"].expand_dims(
-                 time=ds[dims["time"]]
+                 time=data_vars["u"].time
              )
          ).astype(np.float32)
-         ds["w"].attrs["long_name"] = "w-flux component"
-         ds["w"].attrs["units"] = "m/s"
+         ds["w"].attrs["long_name"] = d_meta["w"]["long_name"]
+         ds["w"].attrs["units"] = d_meta["w"]["units"]

-         ds["ubar"] = ubar.transpose(dims["time"], "eta_rho", "xi_u").astype(np.float32)
-         ds["ubar"].attrs["long_name"] = "vertically integrated u-flux component"
-         ds["ubar"].attrs["units"] = "m/s"
+         return ds

-         ds["vbar"] = vbar.transpose(dims["time"], "eta_v", "xi_rho").astype(np.float32)
-         ds["vbar"].attrs["long_name"] = "vertically integrated v-flux component"
-         ds["vbar"].attrs["units"] = "m/s"
+     def _add_global_metadata(self, ds):

          ds = ds.assign_coords(
              {
@@ -235,10 +231,9 @@ class InitialConditions:
          ds.attrs["roms_tools_version"] = roms_tools_version
          ds.attrs["ini_time"] = str(self.ini_time)
          ds.attrs["model_reference_date"] = str(self.model_reference_date)
-         ds.attrs["source"] = self.source
-
-         if dims["time"] != "time":
-             ds = ds.rename({dims["time"]: "time"})
+         ds.attrs["physical_source"] = self.physics_source["name"]
+         if self.bgc_source is not None:
+             ds.attrs["bgc_source"] = self.bgc_source["name"]

          # Translate the time coordinate to days since the model reference date
          model_reference_date = np.datetime64(self.model_reference_date)
@@ -248,22 +243,19 @@ class InitialConditions:
          ds = ds.assign_coords(ocean_time=("time", np.float32(ocean_time)))
          ds["ocean_time"].attrs[
              "long_name"
-         ] = f"time since {np.datetime_as_string(model_reference_date, unit='D')}"
+         ] = f"seconds since {np.datetime_as_string(model_reference_date, unit='s')}"
          ds["ocean_time"].attrs["units"] = "seconds"

-         ds["theta_s"] = self.vertical_coordinate.ds["theta_s"]
-         ds["theta_b"] = self.vertical_coordinate.ds["theta_b"]
-         ds["Tcline"] = self.vertical_coordinate.ds["Tcline"]
-         ds["hc"] = self.vertical_coordinate.ds["hc"]
+         ds.attrs["theta_s"] = self.vertical_coordinate.ds["theta_s"].item()
+         ds.attrs["theta_b"] = self.vertical_coordinate.ds["theta_b"].item()
+         ds.attrs["Tcline"] = self.vertical_coordinate.ds["Tcline"].item()
+         ds.attrs["hc"] = self.vertical_coordinate.ds["hc"].item()
          ds["sc_r"] = self.vertical_coordinate.ds["sc_r"]
          ds["Cs_r"] = self.vertical_coordinate.ds["Cs_r"]

          ds = ds.drop_vars(["s_rho"])

-         object.__setattr__(self, "ds", ds)
-
-         ds["zeta"].load()
-         nan_check(ds["zeta"].squeeze(), self.grid.ds.mask_rho)
+         return ds

      def plot(
          self,
@@ -289,6 +281,38 @@ class InitialConditions:
          - "w": w-flux component.
          - "ubar": Vertically integrated u-flux component.
          - "vbar": Vertically integrated v-flux component.
+         - "PO4": Dissolved Inorganic Phosphate (mmol/m³).
+         - "NO3": Dissolved Inorganic Nitrate (mmol/m³).
+         - "SiO3": Dissolved Inorganic Silicate (mmol/m³).
+         - "NH4": Dissolved Ammonia (mmol/m³).
+         - "Fe": Dissolved Inorganic Iron (mmol/m³).
+         - "Lig": Iron Binding Ligand (mmol/m³).
+         - "O2": Dissolved Oxygen (mmol/m³).
+         - "DIC": Dissolved Inorganic Carbon (mmol/m³).
+         - "DIC_ALT_CO2": Dissolved Inorganic Carbon, Alternative CO2 (mmol/m³).
+         - "ALK": Alkalinity (meq/m³).
+         - "ALK_ALT_CO2": Alkalinity, Alternative CO2 (meq/m³).
+         - "DOC": Dissolved Organic Carbon (mmol/m³).
+         - "DON": Dissolved Organic Nitrogen (mmol/m³).
+         - "DOP": Dissolved Organic Phosphorus (mmol/m³).
+         - "DOPr": Refractory Dissolved Organic Phosphorus (mmol/m³).
+         - "DONr": Refractory Dissolved Organic Nitrogen (mmol/m³).
+         - "DOCr": Refractory Dissolved Organic Carbon (mmol/m³).
+         - "zooC": Zooplankton Carbon (mmol/m³).
+         - "spChl": Small Phytoplankton Chlorophyll (mg/m³).
+         - "spC": Small Phytoplankton Carbon (mmol/m³).
+         - "spP": Small Phytoplankton Phosphorus (mmol/m³).
+         - "spFe": Small Phytoplankton Iron (mmol/m³).
+         - "spCaCO3": Small Phytoplankton CaCO3 (mmol/m³).
+         - "diatChl": Diatom Chlorophyll (mg/m³).
+         - "diatC": Diatom Carbon (mmol/m³).
+         - "diatP": Diatom Phosphorus (mmol/m³).
+         - "diatFe": Diatom Iron (mmol/m³).
+         - "diatSi": Diatom Silicate (mmol/m³).
+         - "diazChl": Diazotroph Chlorophyll (mg/m³).
+         - "diazC": Diazotroph Carbon (mmol/m³).
+         - "diazP": Diazotroph Phosphorus (mmol/m³).
+         - "diazFe": Diazotroph Iron (mmol/m³).
      s : int, optional
          The index of the vertical layer to plot. Default is None.
      eta : int, optional
@@ -306,9 +330,9 @@ class InitialConditions:
      Raises
      ------
      ValueError
-         If the specified varname is not one of the valid options.
-         If field is 3D and none of s_rho, eta, xi are specified.
-         If field is 2D and both eta and xi are specified.
+         If the specified `varname` is not one of the valid options.
+         If the field specified by `varname` is 3D and none of `s`, `eta`, or `xi` are specified.
+         If the field specified by `varname` is 2D and both `eta` and `xi` are specified.
      """

      if len(self.ds[varname].squeeze().dims) == 3 and not any(
@@ -376,7 +400,10 @@ class InitialConditions:
          else:
              vmax = field.max().values
              vmin = field.min().values
-             cmap = plt.colormaps.get_cmap("YlOrRd")
+             if varname in ["temp", "salt"]:
+                 cmap = plt.colormaps.get_cmap("YlOrRd")
+             else:
+                 cmap = plt.colormaps.get_cmap("YlGn")
          cmap.set_bad(color="gray")
          kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}

@@ -454,12 +481,14 @@ class InitialConditions:

          initial_conditions_data = {
              "InitialConditions": {
-                 "filename": self.filename,
+                 "physics_source": self.physics_source,
                  "ini_time": self.ini_time.isoformat(),
                  "model_reference_date": self.model_reference_date.isoformat(),
-                 "source": self.source,
              }
          }
+         # Include bgc_source if it's not None
+         if self.bgc_source is not None:
+             initial_conditions_data["InitialConditions"]["bgc_source"] = self.bgc_source

          yaml_data = {
              **grid_yaml_data,
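
Taken together, this diff replaces the old `source`/`filename` arguments of `InitialConditions` with a required `physics_source` dictionary and an optional `bgc_source` dictionary. A migration sketch based on the docstring example above (the import path is an assumption, and the `grid` and `vertical_coordinate` objects are assumed to exist already; their construction is outside this diff):

    from datetime import datetime
    # Hypothetical import path; not shown in this diff.
    from roms_tools.setup.initial_conditions import InitialConditions

    # Old API (0.20):
    #   InitialConditions(grid=grid, vertical_coordinate=vertical_coordinate,
    #                     ini_time=datetime(2022, 1, 1),
    #                     source="GLORYS", filename="physics_data.nc")

    # New API (1.0.1): the physical source is required; the BGC source is
    # optional and its "climatology" flag defaults to True when omitted.
    initial_conditions = InitialConditions(
        grid=grid,
        vertical_coordinate=vertical_coordinate,
        ini_time=datetime(2022, 1, 1),
        physics_source={"name": "GLORYS", "path": "physics_data.nc"},
        bgc_source={"name": "CESM_REGRIDDED", "path": "bgc_data.nc"},
    )
    initial_conditions.plot(varname="temp", s=0)  # plot a single vertical layer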