roms-tools 2.6.2__py3-none-any.whl → 2.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. roms_tools/__init__.py +1 -0
  2. roms_tools/analysis/roms_output.py +11 -77
  3. roms_tools/analysis/utils.py +0 -66
  4. roms_tools/constants.py +2 -0
  5. roms_tools/download.py +46 -3
  6. roms_tools/plot.py +22 -5
  7. roms_tools/setup/cdr_forcing.py +1126 -0
  8. roms_tools/setup/datasets.py +742 -87
  9. roms_tools/setup/grid.py +42 -4
  10. roms_tools/setup/river_forcing.py +11 -84
  11. roms_tools/setup/tides.py +81 -411
  12. roms_tools/setup/utils.py +241 -37
  13. roms_tools/tests/test_setup/test_cdr_forcing.py +772 -0
  14. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/.zmetadata +53 -1
  15. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/river_tracer/.zattrs +1 -1
  16. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_long_name/.zarray +20 -0
  17. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_long_name/.zattrs +6 -0
  18. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_long_name/0 +0 -0
  19. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_unit/.zarray +20 -0
  20. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_unit/.zattrs +6 -0
  21. roms_tools/tests/test_setup/test_data/river_forcing_no_climatology.zarr/tracer_unit/0 +0 -0
  22. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/.zmetadata +53 -1
  23. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/river_tracer/.zattrs +1 -1
  24. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/tracer_long_name/.zarray +20 -0
  25. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/tracer_long_name/.zattrs +6 -0
  26. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/tracer_long_name/0 +0 -0
  27. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/tracer_unit/.zarray +20 -0
  28. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/tracer_unit/.zattrs +6 -0
  29. roms_tools/tests/test_setup/test_data/river_forcing_with_bgc.zarr/tracer_unit/0 +0 -0
  30. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/.zattrs +1 -2
  31. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/.zmetadata +27 -5
  32. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/ntides/.zarray +20 -0
  33. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/ntides/.zattrs +5 -0
  34. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/ntides/0 +0 -0
  35. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/omega/.zattrs +1 -3
  36. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/pot_Im/0.0.0 +0 -0
  37. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/pot_Re/0.0.0 +0 -0
  38. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/ssh_Im/0.0.0 +0 -0
  39. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/ssh_Re/0.0.0 +0 -0
  40. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/u_Im/0.0.0 +0 -0
  41. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/u_Re/0.0.0 +0 -0
  42. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/v_Im/0.0.0 +0 -0
  43. roms_tools/tests/test_setup/test_data/tidal_forcing.zarr/v_Re/0.0.0 +0 -0
  44. roms_tools/tests/test_setup/test_datasets.py +103 -1
  45. roms_tools/tests/test_setup/test_tides.py +112 -47
  46. roms_tools/utils.py +115 -1
  47. {roms_tools-2.6.2.dist-info → roms_tools-2.7.0.dist-info}/METADATA +1 -1
  48. {roms_tools-2.6.2.dist-info → roms_tools-2.7.0.dist-info}/RECORD +51 -33
  49. {roms_tools-2.6.2.dist-info → roms_tools-2.7.0.dist-info}/WHEEL +1 -1
  50. {roms_tools-2.6.2.dist-info → roms_tools-2.7.0.dist-info}/licenses/LICENSE +0 -0
  51. {roms_tools-2.6.2.dist-info → roms_tools-2.7.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1126 @@
1
+ from dataclasses import dataclass, field
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+ from typing import Optional, List, Dict, Union
5
+ import numpy as np
6
+ import xarray as xr
7
+ import cartopy.crs as ccrs
8
+ import logging
9
+ import matplotlib.pyplot as plt
10
+ import matplotlib.cm as cm
11
+ from roms_tools import Grid
12
+ from roms_tools.constants import NUM_TRACERS
13
+ from roms_tools.plot import _plot, _get_projection
14
+ from roms_tools.regrid import LateralRegridFromROMS
15
+ from roms_tools.utils import (
16
+ _generate_coordinate_range,
17
+ _remove_edge_nans,
18
+ save_datasets,
19
+ )
20
+ from roms_tools.setup.utils import (
21
+ gc_dist,
22
+ get_tracer_defaults,
23
+ get_tracer_metadata_dict,
24
+ add_tracer_metadata_to_ds,
25
+ to_float,
26
+ _to_yaml,
27
+ _from_yaml,
28
+ )
29
+
30
+
31
+ @dataclass(kw_only=True)
32
+ class CDRVolumePointSource:
33
+ """Represents one or several volume sources of water with tracers at specific
34
+ location(s). This class is particularly useful for modeling point sources of Carbon
35
+ Dioxide Removal (CDR) forcing data, such as the injection of water and
36
+ biogeochemical tracers, e.g., alkalinity (ALK) or dissolved inorganic carbon (DIC),
37
+ through a pipe.
38
+
39
+ Parameters
40
+ ----------
41
+ grid : Grid, optional
42
+ Object representing the grid for spatial context.
43
+ start_time : datetime
44
+ Start time of the ROMS model simulation.
45
+ end_time : datetime
46
+ End time of the ROMS model simulation.
47
+ model_reference_date : datetime, optional
48
+ Reference date for converting absolute times to model-relative time. Defaults to Jan 1, 2000.
49
+ releases : dict, optional
50
+ A dictionary of existing releases. Defaults to empty dictionary.
51
+
52
+ Attributes
53
+ ----------
54
+ ds : xr.Dataset
55
+ The xarray dataset containing release metadata and forcing variables.
56
+ """
57
+
58
+ grid: Optional["Grid"] = None
59
+ start_time: datetime
60
+ end_time: datetime
61
+ model_reference_date: datetime = datetime(2000, 1, 1)
62
+ releases: Optional[dict] = field(default_factory=dict)
63
+
64
+ ds: xr.Dataset = field(init=False, repr=False)
65
+
66
def __post_init__(self):
    """Validate the simulation window and build the initial forcing dataset.

    An empty dataset representing zero releases (``ncdr`` of length 0) is
    created and tracer metadata is attached to it. The current tracer metadata
    dictionary is then stored under ``self.releases["_tracer_metadata"]``, and
    every other entry of ``self.releases`` is validated and added to ``self.ds``.

    Raises
    ------
    ValueError
        If ``start_time`` is not strictly earlier than ``end_time``.
    """
    if self.start_time >= self.end_time:
        raise ValueError("`start_time` must be earlier than `end_time`.")

    # ``releases`` is declared Optional; treat an explicit ``None`` like an
    # empty dict instead of failing on the metadata assignment below.
    if self.releases is None:
        self.releases = {}

    # Start with an empty dataset representing zero releases
    ds = xr.Dataset(
        {
            "cdr_time": (["time"], np.empty(0)),
            "cdr_lon": (["ncdr"], np.empty(0)),
            "cdr_lat": (["ncdr"], np.empty(0)),
            "cdr_dep": (["ncdr"], np.empty(0)),
            "cdr_hsc": (["ncdr"], np.empty(0)),
            "cdr_vsc": (["ncdr"], np.empty(0)),
            "cdr_volume": (["time", "ncdr"], np.empty((0, 0))),
            "cdr_tracer": (
                ["time", "ntracers", "ncdr"],
                np.empty((0, NUM_TRACERS, 0)),
            ),
        },
        coords={
            "time": (["time"], np.empty(0)),
            "release_name": (["ncdr"], np.empty(0, dtype=str)),
        },
    )
    self.ds = add_tracer_metadata_to_ds(ds)

    # Attach the current tracer metadata, overwriting any existing entry.
    self.releases["_tracer_metadata"] = get_tracer_metadata_dict()

    # Register the releases that were supplied at construction time.
    for name, params in self.releases.items():
        if name == "_tracer_metadata":
            continue  # skip metadata entry
        self._validate_release_location(
            name=name,
            lat=params["lat"],
            lon=params["lon"],
            depth=params["depth"],
        )
        self._add_release_to_ds(name=name, **params)
111
+
112
def add_release(
    self,
    *,
    name: str,
    lat: float,
    lon: float,
    depth: float,
    times: Optional[List[datetime]] = None,
    volume_fluxes: Union[float, List[float]] = 0.0,
    tracer_concentrations: Optional[Dict[str, Union[float, List[float]]]] = None,
    fill_values: str = "auto",
):
    """Adds a release (point source) of water with tracers to the forcing dataset
    and dictionary.

    This method registers a point source at a specific location (latitude, longitude, and depth).
    The release includes both a volume flux of water and tracer
    concentrations, which can be constant or time-varying.

    Parameters
    ----------
    name : str
        Unique identifier for the release.
    lat : float or int
        Latitude of the release location in degrees North. Must be between -90 and 90.
    lon : float or int
        Longitude of the release location in degrees East. No restrictions on bounds.
    depth : float or int
        Depth of the release in meters. Must be non-negative.
    times : list of datetime.datetime, optional
        Explicit time points for volume fluxes and tracer concentrations. Defaults to [self.start_time, self.end_time] if None.

        Example: `times=[datetime(2022, 1, 1), datetime(2022, 1, 2), datetime(2022, 1, 3)]`

    volume_fluxes : float, int, or list of float/int, optional

        Volume flux(es) of the release in m³/s over time.

        - Constant: applies uniformly across the entire simulation period.
        - Time-varying: must match the length of `times`.

        Example:

        - Constant: `volume_fluxes=1000.0` (uniform across the entire simulation period).
        - Time-varying: `volume_fluxes=[1000.0, 1500.0, 2000.0]` (corresponds to each `times` entry).

    tracer_concentrations : dict, optional

        Dictionary of tracer names and their concentration values. The concentration values can be either
        a float/int (constant in time) or a list of float/int (time-varying).

        - Constant: applies uniformly across the entire simulation period.
        - Time-varying: must match the length of `times`.

        Default is an empty dictionary (`{}`) if not provided. The dictionary
        passed in is not modified; missing tracers are filled in on a copy.
        Example:

        - Constant: `{"ALK": 2000.0, "DIC": 1900.0}`
        - Time-varying: `{"ALK": [2000.0, 2050.0, 2013.3], "DIC": [1900.0, 1920.0, 1910.2]}`
        - Mixed: `{"ALK": 2000.0, "DIC": [1900.0, 1920.0, 1910.2]}`

    fill_values : str, optional

        Strategy for filling missing tracer concentration values. Options:

        - "auto" (default): automatically set values to non-zero defaults
        - "zero": fill missing values with 0.0

        Missing "temp" and "salt" are always filled with their defaults,
        regardless of this option.

    Raises
    ------
    ValueError
        If a release called `name` already exists, if `fill_values` is not
        "auto" or "zero", or if the input checks / location validation fail.
    """
    # Check that the name is unique
    if name in self.releases:
        raise ValueError(f"A release with the name '{name}' already exists.")

    # Check that fill_values has proper string
    if fill_values not in ("auto", "zero"):
        raise ValueError(
            f"Invalid fill_values option: '{fill_values}'. "
            "Must be 'auto' or 'zero'."
        )

    # Set default for times if None
    if times is None:
        times = []

    # Work on a shallow copy so that filling in missing tracers below never
    # leaks into the caller's dictionary (which may be reused for other
    # releases, possibly with a different ``fill_values``).
    tracer_concentrations = (
        {} if tracer_concentrations is None else dict(tracer_concentrations)
    )

    # Fill in missing tracer concentrations: temperature and salinity always
    # fall back to their defaults, all other tracers follow ``fill_values``.
    defaults = get_tracer_defaults()
    for tracer_name in map(str, self.ds.tracer_name.values):
        if tracer_name in tracer_concentrations:
            continue
        if tracer_name in ("temp", "salt") or fill_values == "auto":
            tracer_concentrations[tracer_name] = defaults[tracer_name]
        else:  # fill_values == "zero"
            tracer_concentrations[tracer_name] = 0.0

    # Check input parameters
    self._input_checks(
        name=name,
        lat=lat,
        lon=lon,
        depth=depth,
        times=times,
        volume_fluxes=volume_fluxes,
        tracer_concentrations=tracer_concentrations,
    )

    # Convert integers to floats
    lat = float(lat)
    lon = float(lon)
    depth = float(depth)
    volume_fluxes = to_float(volume_fluxes)
    tracer_concentrations = {
        tracer: to_float(vals) for tracer, vals in tracer_concentrations.items()
    }

    # Extend volume fluxes and tracer_concentrations across simulation period if necessary
    times, volume_fluxes, tracer_concentrations = self._handle_simulation_endpoints(
        times, volume_fluxes, tracer_concentrations
    )

    # Validate release location
    self._validate_release_location(name=name, lat=lat, lon=lon, depth=depth)

    self._add_release_to_dict(
        name=name,
        lat=lat,
        lon=lon,
        depth=depth,
        times=times,
        volume_fluxes=volume_fluxes,
        tracer_concentrations=tracer_concentrations,
    )

    self._add_release_to_ds(
        name=name,
        lat=lat,
        lon=lon,
        depth=depth,
        times=times,
        volume_fluxes=volume_fluxes,
        tracer_concentrations=tracer_concentrations,
    )
259
+
260
def _add_release_to_ds(
    self,
    *,
    name: str,
    lat: float,
    lon: float,
    depth: float,
    times: Optional[List[datetime]] = None,
    tracer_concentrations: Optional[Dict[str, Union[float, List[float]]]] = None,
    volume_fluxes: Union[float, List[float]] = 0.0,
):
    """Add the release data for a specific release to the forcing dataset.

    The dataset is rebuilt with one extra entry along ``ncdr``. Its time axis
    becomes the union of the existing times and ``times``. Volume fluxes and
    tracer concentrations of previously added releases are linearly
    interpolated (``np.interp``) onto the merged time axis. The new release's
    series are interpolated too when given as lists, or broadcast when given
    as scalars.

    Parameters
    ----------
    name : str
        Name of the release; appended to the ``release_name`` coordinate.
    lat, lon, depth : float
        Release location, stored in ``cdr_lat``, ``cdr_lon`` and ``cdr_dep``.
    times : list of datetime
        Time points of the new release's time series.
    tracer_concentrations : dict
        Mapping from tracer name to a scalar or a list aligned with ``times``.
        Must contain every name in ``self.ds.tracer_name``.
    volume_fluxes : float or list of float
        Scalar, or list aligned with ``times``.
    """
    # Convert times to datetime64[ns] and to model-relative days
    times = np.array(times, dtype="datetime64[ns]")
    ref = np.datetime64(self.model_reference_date, "ns")
    one_day = np.timedelta64(1, "D")
    rel_times = (times - ref) / one_day

    # Merge with existing time dimension
    existing_times = (
        self.ds["time"].values
        if len(self.ds["time"]) > 0
        else np.array([], dtype="datetime64[ns]")
    )
    union_times = np.union1d(existing_times, times)
    # Derive relative times from the merged absolute times so that "time" and
    # "cdr_time" are guaranteed to share length and ordering. Two independent
    # unions could diverge if distinct timestamps round to the same float
    # day offset.
    union_rel_times = (union_times - ref) / one_day

    # Initialize a fresh dataset to accommodate the new release.
    # xarray does not handle dynamic resizing of dimensions well (e.g., increasing 'ncdr' by 1),
    # so we recreate the dataset with the updated size.
    ds = xr.Dataset()
    ds["time"] = ("time", union_times)
    ds["cdr_time"] = ("time", union_rel_times)
    ds = add_tracer_metadata_to_ds(ds)

    release_names = np.concatenate([self.ds.release_name.values, [name]])
    ds = ds.assign_coords({"release_name": (["ncdr"], release_names)})

    location_vars = ["cdr_lon", "cdr_lat", "cdr_dep", "cdr_hsc", "cdr_vsc"]
    for var_name in location_vars:
        ds[var_name] = xr.zeros_like(ds.ncdr, dtype=np.float64)

    ds["cdr_volume"] = xr.zeros_like(ds.cdr_time * ds.ncdr, dtype=np.float64)
    ds["cdr_tracer"] = xr.zeros_like(
        ds.cdr_time * ds.ntracers * ds.ncdr, dtype=np.float64
    )

    n_old = len(self.ds["ncdr"])  # number of releases added previously
    new_idx = ds.sizes["ncdr"] - 1  # position of the new release

    # Retain previous experiment locations
    for i in range(n_old):
        for var_name in location_vars:
            ds[var_name].loc[{"ncdr": i}] = self.ds[var_name].isel(ncdr=i)

    # Add the new experiment location (horizontal/vertical scales are zero)
    for var_name, value in zip(location_vars, [lon, lat, depth, 0.0, 0.0]):
        ds[var_name].loc[{"ncdr": new_idx}] = np.float64(value)

    # Interpolate and retain previous experiment volume fluxes and tracer concentrations
    old_rel_times = self.ds["cdr_time"].values
    n_tracers = len(self.ds.ntracers)
    for i in range(n_old):
        ds["cdr_volume"].loc[{"ncdr": i}] = np.interp(
            union_rel_times,
            old_rel_times,
            self.ds["cdr_volume"].isel(ncdr=i).values,
        )
        for n in range(n_tracers):
            ds["cdr_tracer"].loc[{"ntracers": n, "ncdr": i}] = np.interp(
                union_rel_times,
                old_rel_times,
                self.ds["cdr_tracer"].isel(ntracers=n, ncdr=i).values,
            )

    # Handle new experiment volume fluxes (list -> interpolate, scalar -> broadcast)
    if isinstance(volume_fluxes, list):
        interpolated = np.interp(union_rel_times, rel_times, volume_fluxes)
    else:
        interpolated = np.full(len(union_rel_times), volume_fluxes)
    ds["cdr_volume"].loc[{"ncdr": new_idx}] = interpolated

    # Handle new experiment tracer concentrations
    for n in range(n_tracers):
        tracer_name = ds.tracer_name[n].item()
        values = tracer_concentrations[tracer_name]
        if isinstance(values, list):
            interpolated = np.interp(union_rel_times, rel_times, values)
        else:
            interpolated = np.full(len(union_rel_times), values)
        ds["cdr_tracer"].loc[{"ntracers": n, "ncdr": new_idx}] = interpolated

    self.ds = ds
370
+
371
+ def _add_release_to_dict(self, name: str, **params):
372
+ """Add the release data for a specific 'name' to the releases dictionary.
373
+
374
+ Parameters
375
+ ----------
376
+ name : str
377
+ The unique name for the release to be added to the dictionary.
378
+ **params : keyword arguments
379
+ Parameters to be added for the specific release (e.g., location, volume fluxes, etc.).
380
+ """
381
+ # Add the parameters to the dictionary under the given name
382
+ if name not in self.releases:
383
+ self.releases[name] = {}
384
+ self.releases[name].update(params)
385
+
386
def plot_volume_flux(self, start=None, end=None, releases="all"):
    """Plot the volume flux for each specified release within the given time range.

    Parameters
    ----------
    start : datetime or None
        Start datetime for the plot. If None, defaults to `self.start_time`.
    end : datetime or None
        End datetime for the plot. If None, defaults to `self.end_time`.
    releases : str, list of str, or "all", optional
        A single release name or a list of release names to plot.
        If "all", the method will plot all releases.
        The default is "all".
    """
    start = start or self.start_time
    end = end or self.end_time

    # Handle "all" releases case
    if releases == "all":
        releases = [k for k in self.releases if k != "_tracer_metadata"]
    # Validate input for release names
    self._validate_release_input(releases)
    # A single release name is allowed; wrap it so the plotting loop iterates
    # over release names rather than over the characters of the string.
    if isinstance(releases, str):
        releases = [releases]

    self._plot_line(
        self.ds["cdr_volume"],
        releases,
        start,
        end,
        title="Volume flux of release(s)",
        ylabel=r"m$^3$/s",
    )
420
+
421
def plot_tracer_concentration(
    self, name: str, start=None, end=None, releases="all"
):
    """Plot the concentration of a given tracer for each specified release within
    the given time range.

    Parameters
    ----------
    name : str
        Name of the tracer to plot, e.g., "ALK", "DIC", etc.
    start : datetime or None
        Start datetime for the plot. If None, defaults to `self.start_time`.
    end : datetime or None
        End datetime for the plot. If None, defaults to `self.end_time`.
    releases : str, list of str, or "all", optional
        A single release name or a list of release names to plot.
        If "all", the method will plot all releases.
        The default is "all".

    Raises
    ------
    ValueError
        If `name` is not one of the tracers in the dataset.
    """
    start = start or self.start_time
    end = end or self.end_time

    # Handle "all" releases case
    if releases == "all":
        releases = [k for k in self.releases if k != "_tracer_metadata"]
    # Validate input for release names
    self._validate_release_input(releases)
    # A single release name is allowed; wrap it so the plotting loop iterates
    # over release names rather than over the characters of the string.
    if isinstance(releases, str):
        releases = [releases]

    tracer_names = list(self.ds["tracer_name"].values)
    if name not in tracer_names:
        raise ValueError(
            f"Tracer '{name}' not found. Available: {', '.join(tracer_names)}"
        )

    tracer_index = tracer_names.index(name)
    data = self.ds["cdr_tracer"].isel(ntracers=tracer_index)

    if name == "temp":
        title = "Temperature of release water"
    elif name == "salt":
        title = "Salinity of release water"
    else:
        title = f"{name} concentration of release(s)"

    self._plot_line(
        data,
        releases,
        start,
        end,
        title=title,
        ylabel=f"{self.ds['tracer_unit'].isel(ntracers=tracer_index).values.item()}",
    )
473
+
474
def _plot_line(self, data, releases, start, end, title="", ylabel=""):
    """Draw one line per release from ``data`` on a new figure.

    Each release is located along ``ncdr`` by matching its name against the
    ``release_name`` coordinate and drawn with the colour assigned by
    ``_get_release_colors``. The x-axis is limited to ``[start, end]``.
    """
    release_colors = self._get_release_colors()

    _, ax = plt.subplots(1, 1, figsize=(7, 4))
    for release in releases:
        matches = np.where(self.ds["release_name"].values == release)[0]
        series = data.isel(ncdr=matches[0])
        series.plot(
            ax=ax,
            linewidth=2,
            label=release,
            color=release_colors[release],
            marker="x",
        )

    # Only request a legend when at least one labelled line was drawn.
    if len(releases) > 0:
        ax.legend()

    ax.set(title=title, ylabel=ylabel)
    ax.set_xlim([start, end])
494
+
495
def plot_location_top_view(self, releases="all"):
    """Plot the top-down view of release locations.

    Parameters
    ----------
    releases : list of str or str, optional
        A single release name (string) or a list of release names (strings) to plot.
        Default is 'all', which will plot all releases.

    Raises
    ------
    ValueError
        If `self.grid` is not set.
        If `releases` is not a string or list of strings.
        If any of the specified releases do not exist in `self.releases`.
    """
    # Ensure that the grid is provided
    if self.grid is None:
        raise ValueError(
            "A grid must be provided for plotting. Please pass a valid `Grid` object."
        )

    # Handle "all" releases case
    if releases == "all":
        releases = [k for k in self.releases if k != "_tracer_metadata"]

    # Validate input for release names
    self._validate_release_input(releases)
    # A single release name is allowed; wrap it so the loop below iterates
    # over release names rather than over the characters of the string.
    if isinstance(releases, str):
        releases = [releases]

    # Background: the grid's mask_rho field. Named ``mask`` so it does not
    # shadow ``dataclasses.field`` imported at module level.
    mask = self.grid.ds.mask_rho
    lon_deg = self.grid.ds.lon_rho
    lat_deg = self.grid.ds.lat_rho
    if self.grid.straddle:
        lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
    mask = mask.assign_coords({"lon": lon_deg, "lat": lat_deg})

    kwargs = {"vmax": 6, "vmin": 0, "cmap": plt.colormaps.get_cmap("Blues")}

    trans = _get_projection(lon_deg, lat_deg)

    _, ax = plt.subplots(1, 1, figsize=(13, 7), subplot_kw={"projection": trans})

    _plot(mask, kwargs=kwargs, ax=ax, c=None, add_colorbar=False)

    proj = ccrs.PlateCarree()
    colors = self._get_release_colors()

    for name in releases:
        # transform coordinates to projected space
        transformed_lon, transformed_lat = trans.transform_point(
            self.releases[name]["lon"],
            self.releases[name]["lat"],
            proj,
        )

        ax.plot(
            transformed_lon,
            transformed_lat,
            marker="x",
            markersize=8,
            markeredgewidth=2,
            label=name,
            color=colors[name],
        )

    ax.set_title("Release locations")
    # Avoid an empty legend (and matplotlib's warning) when nothing is plotted.
    if len(releases) > 0:
        ax.legend(loc="center left", bbox_to_anchor=(1.1, 0.5))
567
+
568
def plot_location_side_view(self, release: Optional[str] = None):
    """Plot the release location from a side view, showing bathymetry sections along
    both fixed longitude and latitude.

    This method creates two plots:

    - A bathymetry section along a fixed longitude (latitudinal view),
      with the release location marked by an "x".
    - A bathymetry section along a fixed latitude (longitudinal view),
      with the release location also marked by an "x".

    Parameters
    ----------
    release : str, optional
        Name of the release to plot. If only one release is available,
        it is used by default. If multiple releases are available, this must be specified.

    Raises
    ------
    ValueError

        If `self.grid` is not set.
        If no `release` is provided and no releases have been added.
        If no `release` is provided when multiple releases are available.
        If the specified `release` does not exist in `self.releases`.
    """
    if self.grid is None:
        raise ValueError(
            "A grid must be provided for plotting. Please pass a valid `Grid` object."
        )

    valid_releases = [r for r in self.releases if r != "_tracer_metadata"]
    if release is None:
        # Distinguish "nothing to plot" from "ambiguous choice" so the error
        # message is accurate in both cases.
        if not valid_releases:
            raise ValueError("No releases available to plot.")
        if len(valid_releases) > 1:
            raise ValueError(
                f"Multiple releases found: {valid_releases}. Please specify a single release to plot."
            )
        release = valid_releases[0]

    self._validate_release_input(release, list_allowed=False)

    def _plot_bathymetry_section(
        ax, h, dim, fixed_val, coord_deg, resolution, title
    ):
        """Plots a bathymetry section along a fixed latitude or longitude.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            The axis on which the plot will be drawn.

        h : xarray.DataArray
            The bathymetry data to plot.

        dim : str
            The dimension held fixed for the section, either "lat" or "lon".

        fixed_val : float
            The fixed value of latitude or longitude for the section.

        coord_deg : xarray.DataArray
            The array of latitude or longitude coordinates spanned by the section.

        resolution : float
            The resolution at which to generate the coordinate range.

        title : str
            The title for the plot.

        Returns
        -------
        None
            The function does not return anything. It directly plots the bathymetry section on the provided axis.
        """
        # Determine coordinate names and build target range
        var_range = _generate_coordinate_range(
            coord_deg.min().values, coord_deg.max().values, resolution
        )
        var_name = "lat" if dim == "lon" else "lon"
        range_da = xr.DataArray(
            var_range,
            dims=[var_name],
            attrs={"units": "°N" if var_name == "lat" else "°E"},
        )

        # Construct target coordinates for regridding
        target_coords = {dim: [fixed_val], var_name: range_da}
        regridder = LateralRegridFromROMS(h, target_coords)
        section = regridder.apply(h)
        section, _ = _remove_edge_nans(section, var_name)

        # Plot the bathymetry section (depth increases downwards)
        section.plot(ax=ax, color="k")
        ax.fill_between(section[var_name], section.squeeze(), y2=0, color="#deebf7")
        ax.invert_yaxis()
        ax.set_xlabel("Latitude [°N]" if var_name == "lat" else "Longitude [°E]")
        ax.set_ylabel("Depth [m]")
        ax.set_title(title)

    # Prepare grid coordinates
    lon_deg = self.grid.ds.lon_rho
    lat_deg = self.grid.ds.lat_rho
    if self.grid.straddle:
        lon_deg = xr.where(lon_deg > 180, lon_deg - 360, lon_deg)
    # NOTE(review): the release longitude below is used as stored, while the
    # grid longitudes may have been shifted to [-180, 180] above — confirm the
    # two conventions always agree for straddling grids.

    resolution = self.grid._infer_nominal_horizontal_resolution()
    h = self.grid.ds.h.assign_coords({"lon": lon_deg, "lat": lat_deg})

    release_params = self.releases[release]
    colors = self._get_release_colors()

    # Set up plot
    fig, axs = plt.subplots(2, 1, figsize=(7, 8))

    # Plot along fixed longitude
    _plot_bathymetry_section(
        ax=axs[0],
        h=h,
        dim="lon",
        fixed_val=release_params["lon"],
        coord_deg=lat_deg,
        resolution=resolution,
        title=f"Longitude: {release_params['lon']}°E",
    )
    axs[0].plot(
        release_params["lat"],
        release_params["depth"],
        color=colors[release],
        marker="x",
        markersize=8,
        markeredgewidth=2,
    )

    # Plot along fixed latitude
    _plot_bathymetry_section(
        ax=axs[1],
        h=h,
        dim="lat",
        fixed_val=release_params["lat"],
        coord_deg=lon_deg,
        resolution=resolution,
        title=f"Latitude: {release_params['lat']}°N",
    )
    axs[1].plot(
        release_params["lon"],
        release_params["depth"],
        color=colors[release],
        marker="x",
        markersize=8,
        markeredgewidth=2,
    )

    # Adjust layout and title
    fig.subplots_adjust(hspace=0.4)
    fig.suptitle(f"Release location for: {release}")
723
+
724
def save(
    self,
    filepath: Union[str, Path],
) -> List[Path]:
    """Save the volume source with tracers to netCDF4 file.

    A trailing ".nc" suffix on `filepath` is stripped before the base name is
    passed to `save_datasets`. The suffix is presumably re-appended there —
    confirm in `save_datasets`.

    Parameters
    ----------
    filepath : Union[str, Path]
        The base path and filename for the output files.

    Returns
    -------
    List[Path]
        A list of `Path` objects for the saved files. Each element in the list corresponds to a file that was saved.
    """
    # Ensure filepath is a Path object
    filepath = Path(filepath)

    # Remove ".nc" suffix if present
    if filepath.suffix == ".nc":
        filepath = filepath.with_suffix("")

    return save_datasets([self.ds], [str(filepath)])
754
+
755
def to_yaml(self, filepath: Union[str, Path]) -> None:
    """Write this object's parameters to a YAML file.

    Serialization is delegated to the shared `_to_yaml` helper. Per the
    original documentation, the output also records the roms-tools version.

    Parameters
    ----------
    filepath : Union[str, Path]
        Destination of the YAML file.
    """
    _to_yaml(self, filepath)
766
+
767
@classmethod
def from_yaml(cls, filepath: Union[str, Path]) -> "CDRVolumePointSource":
    """Build a CDRVolumePointSource from a YAML file.

    The grid and the class parameters are both read from the same file: the
    grid via `Grid.from_yaml`, the remaining parameters via `_from_yaml`.

    Parameters
    ----------
    filepath : Union[str, Path]
        Location of the YAML file to read.

    Returns
    -------
    CDRVolumePointSource
        A new instance constructed from the stored grid and parameters.
    """
    path = Path(filepath)
    # Keyword arguments are evaluated left to right, so the grid is read
    # before the remaining parameters.
    return cls(grid=Grid.from_yaml(path), **_from_yaml(cls, path))
787
+
788
def _input_checks(
    self,
    name,
    lat,
    lon,
    depth,
    times,
    volume_fluxes,
    tracer_concentrations,
):
    """Perform various input checks on release parameters.

    - Checks that latitude is between -90 and 90.
    - Checks that depth is a non-negative number (NaN is rejected).
    - Ensures 'times' contains only datetime objects and is strictly increasing.
    - Verifies that times lie within [self.start_time, self.end_time].
    - Ensures volume fluxes is either a float/int or a list of floats/ints.
    - Ensures each tracer concentration is either a float/int or a list of floats/ints.
    - Ensures list-valued 'volume_fluxes' / 'tracer_concentrations' match the length of
      'times', and that no list is given at all when 'times' is empty.
    - Ensures all volume fluxes and all tracer concentrations except 'temp' are
      non-negative (NaN is rejected for scalars and lists alike).

    Parameters
    ----------
    name : str
        Release name. Not validated by this method.
    lat : float
        Latitude of the release.
    lon : float
        Longitude of the release. Not validated by this method.
    depth : float
        Release depth; must be >= 0.
    times : sequence of datetime
        Release times. Assumes the caller has already turned a missing value
        into an empty list (TODO confirm against caller).
    volume_fluxes : float, int or list of float/int
        Constant volume flux or one value per entry of `times`.
    tracer_concentrations : dict
        Maps tracer name to a float/int or to a list with one value per entry of `times`.

    Raises
    ------
    ValueError
        If any of the checks above fails.
    """

    def _is_scalar(value):
        return isinstance(value, (float, int))

    def _is_scalar_or_scalar_list(value):
        return _is_scalar(value) or (
            isinstance(value, list) and all(_is_scalar(v) for v in value)
        )

    # Check that lat is valid (a NaN latitude also fails this chained comparison)
    if not (-90 <= lat <= 90):
        raise ValueError(
            f"Invalid latitude {lat}. Latitude must be between -90 and 90."
        )

    # `not (x >= 0)` rather than `x < 0` so that NaN is rejected too
    if not (depth >= 0):
        raise ValueError(
            f"Invalid depth {depth}. Depth must be a non-negative number."
        )

    # Ensure that times only contains datetimes
    if not all(isinstance(t, datetime) for t in times):
        raise ValueError(
            f"If 'times' is provided, all entries must be datetime objects. Got: {[type(t) for t in times]}"
        )

    if len(times) > 0:
        # Strictly increasing; trivially true for a single entry
        if not all(t1 < t2 for t1, t2 in zip(times, times[1:])):
            raise ValueError(
                f"The 'times' list must be strictly monotonically increasing. Got: {list(times)}"
            )

        if times[0] < self.start_time:
            raise ValueError(
                f"First entry in `times` ({times[0]}) cannot be before `self.start_time` ({self.start_time})."
            )

        if times[-1] > self.end_time:
            raise ValueError(
                f"Last entry in `times` ({times[-1]}) cannot be after `self.end_time` ({self.end_time})."
            )

    # Ensure volume fluxes is either a list of floats/ints or a single float/int
    if not _is_scalar_or_scalar_list(volume_fluxes):
        raise ValueError(
            "Invalid 'volume_fluxes' input: must be a float/int or a list of floats/ints."
        )

    # Ensure each tracer concentration is either a float/int or a list of floats/ints
    for key, val in tracer_concentrations.items():
        if not _is_scalar_or_scalar_list(val):
            raise ValueError(
                f"Invalid tracer concentration for '{key}': must be a float/int or a list of floats/ints."
            )

    num_times = len(times)

    # List-valued inputs must provide one value per time
    if isinstance(volume_fluxes, list) and len(volume_fluxes) != num_times:
        raise ValueError(
            f"The length of `volume_fluxes` ({len(volume_fluxes)}) does not match the length of `times` ({num_times})."
        )

    for key, tracer_values in tracer_concentrations.items():
        if isinstance(tracer_values, list) and len(tracer_values) != num_times:
            raise ValueError(
                f"The length of tracer '{key}' ({len(tracer_values)}) does not match the length of `times` ({num_times})."
            )

    # With no times, the release is later stretched over [start_time, end_time]
    # (two time entries), so an (empty) list cannot supply matching values.
    if num_times == 0:
        if isinstance(volume_fluxes, list):
            raise ValueError(
                "When `times` is empty, `volume_fluxes` must be a single float/int, not a list."
            )
        for key, tracer_values in tracer_concentrations.items():
            if isinstance(tracer_values, list):
                raise ValueError(
                    f"When `times` is empty, the concentration of tracer '{key}' must be a single float/int, not a list."
                )

    # Non-negativity; `not (v >= 0)` also rejects NaN consistently
    if isinstance(volume_fluxes, list):
        if not all(v >= 0 for v in volume_fluxes):
            raise ValueError(
                f"All entries in `volume_fluxes` must be non-negative. Got: {volume_fluxes}"
            )
    elif not (volume_fluxes >= 0):
        raise ValueError(f"Volume flux must be non-negative. Got: {volume_fluxes}")

    for key, tracer_values in tracer_concentrations.items():
        # 'temp' is exempt, presumably because temperatures may be below zero
        if key == "temp":
            continue
        if isinstance(tracer_values, list):
            if not all(c >= 0 for c in tracer_values):
                raise ValueError(
                    f"All entries in `tracer_concentrations['{key}']` must be non-negative. Got: {tracer_values}"
                )
        elif not (tracer_values >= 0):
            raise ValueError(
                f"The concentration of tracer '{key}' must be non-negative. Got: {tracer_values}"
            )
901
+
902
def _handle_simulation_endpoints(self, times, volume_fluxes, tracer_concentrations):
    """Ensure that the release time series starts at self.start_time and ends at
    self.end_time.

    If the first/last entry of `times` differs from `self.start_time` /
    `self.end_time`, that endpoint is added. A list-valued `volume_fluxes` gets a
    zero flux at each added endpoint. List-valued tracer concentrations are
    extended by duplicating their first/last value. If `times` is empty, it
    becomes `[self.start_time, self.end_time]` and the values are left as given.

    The caller's lists and dict are not modified; copies are returned.

    Parameters
    ----------
    times : sequence of datetime
        Release times (assumed validated by `_input_checks`).
    volume_fluxes : float, int or list
        Constant flux or one value per time.
    tracer_concentrations : dict
        Tracer name -> constant or list with one value per time.

    Returns
    -------
    tuple
        `(times, volume_fluxes, tracer_concentrations)` with endpoints added.
    """
    # Work on copies so user-supplied lists are never changed behind their back.
    times = list(times)
    if isinstance(volume_fluxes, list):
        volume_fluxes = list(volume_fluxes)
    tracer_concentrations = {
        key: list(vals) if isinstance(vals, list) else vals
        for key, vals in tracer_concentrations.items()
    }

    if not times:
        # No schedule given: span the whole simulation window.
        return [self.start_time, self.end_time], volume_fluxes, tracer_concentrations

    # Handle start_time
    if times[0] != self.start_time:
        times.insert(0, self.start_time)
        if isinstance(volume_fluxes, list):
            volume_fluxes.insert(0, 0.0)
        for vals in tracer_concentrations.values():
            if isinstance(vals, list):
                vals.insert(0, vals[0])

    # Handle end_time
    if times[-1] != self.end_time:
        times.append(self.end_time)
        if isinstance(volume_fluxes, list):
            volume_fluxes.append(0.0)
        for vals in tracer_concentrations.values():
            if isinstance(vals, list):
                vals.append(vals[-1])

    return times, volume_fluxes, tracer_concentrations
937
+
938
def _validate_release_location(self, name, lat, lon, depth):
    """Check that a release site maps onto a usable ocean cell of the model grid.

    The nearest rho-point to (lat, lon) is located via `gc_dist` (great-circle
    distance). The site is rejected when it is farther from every rho-point than
    the largest half cell diagonal, when the nearest point lies on the outermost
    row/column, when that point is masked as land, or when the bathymetry there
    is shallower than `depth`. Without a grid, only a warning is logged.

    Parameters
    ----------
    name : str
        A unique identifier for the release location (used in messages).
    lat : float
        Latitude of the release location.
    lon : float
        Longitude of the release location.
    depth : float
        Depth (positive, in meters) of the release location.

    Raises
    ------
    ValueError
        If the location is outside the grid, on the grid boundary
        (eta_rho/xi_rho equal to 0 or their maximum), on land (`mask_rho == 0`),
        or below the ocean bottom (`h < depth`).
    """
    grid = self.grid
    if not grid:
        logging.warning(
            "Grid not provided: cannot verify whether the specified lat/lon/depth location is within the domain or on land. "
            "Please check manually or provide a grid when instantiating the class."
        )
        return

    ds = grid.ds

    # Match the grid's longitude convention: dateline-straddling grids map lon > 180
    # to negative values, other grids map negative lon into the positive range.
    if grid.straddle:
        lon = xr.where(lon > 180, lon - 360, lon)
    else:
        lon = xr.where(lon < 0, lon + 360, lon)

    # Half of each cell's diagonal (pm/pn are inverse grid spacings)
    half_diagonal = np.sqrt((1 / ds.pm) ** 2 + (1 / ds.pn) ** 2) / 2

    distance = gc_dist(ds.lon_rho, ds.lat_rho, lon, lat)
    nearest = distance.min(dim=["eta_rho", "xi_rho"])

    if (nearest > half_diagonal).all():
        raise ValueError(
            f"Release site '{name}' is outside of the grid domain. "
            "Ensure the provided (lat, lon) falls within the model grid extent."
        )

    # First grid cell attaining the minimum distance
    eta_hits, xi_hits = np.where(distance == nearest)
    eta_rho, xi_rho = eta_hits[0], xi_hits[0]

    last_eta = ds.sizes["eta_rho"] - 1
    last_xi = ds.sizes["xi_rho"] - 1
    if eta_rho in (0, last_eta) or xi_rho in (0, last_xi):
        raise ValueError(
            f"Release site '{name}' is located too close to the grid boundary. "
            "Place release location (lat, lon) away from grid boundaries."
        )

    if ds.mask_rho[eta_rho, xi_rho].values == 0:
        raise ValueError(
            f"Release site '{name}' is on land. "
            "Please provide coordinates (lat, lon) over ocean."
        )

    seafloor = ds.h[eta_rho, xi_rho].values
    if seafloor < depth:
        raise ValueError(
            f"Release site '{name}' lies below the seafloor. "
            f"Seafloor depth is {seafloor:.2f} m, "
            f"but requested depth is {depth:.2f} m. Adjust depth or location (lat, lon)."
        )
1023
+
1024
def _validate_release_input(self, releases, list_allowed=True):
    """Check that `releases` names one or more existing releases.

    `releases` may be a single release name or (if `list_allowed`) a list of
    names. Every name must be a key of `self.releases` other than the internal
    "_tracer_metadata" entry.

    Parameters
    ----------
    releases : str or list of str
        Release name(s) to validate.
    list_allowed : bool, optional
        When False, only a single string is accepted. Default is True.

    Raises
    ------
    ValueError
        If a list is given while `list_allowed` is False, if `releases` is neither
        a string nor a list of strings, or if any name is not a known release
        (all unknown names are listed in the message).
    """
    single_name = isinstance(releases, str)

    if not (list_allowed or single_name):
        raise ValueError(
            f"Only a single release name (string) is allowed. Got: {releases}"
        )

    if single_name:
        requested = [releases]
    elif isinstance(releases, list):
        if any(not isinstance(entry, str) for entry in releases):
            raise ValueError("All elements in `releases` list must be strings.")
        requested = releases
    else:
        raise ValueError(
            "`releases` should be a string (single release name) or a list of strings (release names)."
        )

    # "_tracer_metadata" lives alongside the releases but is not one of them
    known = {key for key in self.releases if key != "_tracer_metadata"}
    unknown = [entry for entry in requested if entry not in known]
    if unknown:
        raise ValueError(f"Invalid releases: {', '.join(unknown)}")
1078
+
1079
def _get_release_colors(self):
    """Returns a dictionary of colors for the valid releases, based on a consistent
    colormap.

    Parameters
    ----------
    None

    Returns
    -------
    dict
        A dictionary where the keys are release names and the values are their
        corresponding RGBA colors, assigned based on the order of releases in the
        valid releases list.

    Raises
    ------
    ValueError
        If the number of valid releases exceeds the available palette capacity
        (60 colors).

    Notes
    -----
    The palette is chosen dynamically based on the number of valid releases:

    - If there are 10 or fewer releases, the "tab10" colormap is used.
    - If there are more than 10 but fewer than or equal to 20 releases, the "tab20" colormap is used.
    - For more than 20 releases, the colors of "tab20", "tab20b" and "tab20c" are
      concatenated (60 colors). Previously "tab20b" alone was used, which only holds
      20 colors, so any count above 20 raised.
    """
    # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; prefer the colormap
    # registry (Matplotlib >= 3.5) and keep the legacy accessor as a fallback.
    try:
        from matplotlib import colormaps
    except ImportError:
        colormaps = None

    def _lookup_cmap(cmap_name):
        if colormaps is not None:
            return colormaps[cmap_name]
        return cm.get_cmap(cmap_name)

    valid_releases = [k for k in self.releases if k != "_tracer_metadata"]
    num_releases = len(valid_releases)

    if num_releases <= 10:
        cmap_names = ("tab10",)
    elif num_releases <= 20:
        cmap_names = ("tab20",)
    else:
        cmap_names = ("tab20", "tab20b", "tab20c")

    # Integer lookups on these listed colormaps return their discrete RGBA entries
    palette = []
    for cmap_name in cmap_names:
        color_map = _lookup_cmap(cmap_name)
        palette.extend(color_map(i) for i in range(color_map.N))

    # Ensure the number of releases doesn't exceed the available palette capacity
    if num_releases > len(palette):
        raise ValueError(
            f"Too many releases. The selected colormap supports up to {len(palette)} releases."
        )

    return {name: palette[i] for i, name in enumerate(valid_releases)}