roms-tools 3.1.2__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. roms_tools/__init__.py +3 -0
  2. roms_tools/analysis/cdr_analysis.py +203 -0
  3. roms_tools/analysis/cdr_ensemble.py +198 -0
  4. roms_tools/analysis/roms_output.py +80 -46
  5. roms_tools/data/grids/GLORYS_global_grid.nc +0 -0
  6. roms_tools/download.py +4 -0
  7. roms_tools/plot.py +75 -21
  8. roms_tools/setup/boundary_forcing.py +44 -19
  9. roms_tools/setup/cdr_forcing.py +122 -8
  10. roms_tools/setup/cdr_release.py +161 -8
  11. roms_tools/setup/datasets.py +626 -340
  12. roms_tools/setup/grid.py +138 -137
  13. roms_tools/setup/initial_conditions.py +113 -48
  14. roms_tools/setup/mask.py +63 -7
  15. roms_tools/setup/nesting.py +67 -42
  16. roms_tools/setup/river_forcing.py +45 -19
  17. roms_tools/setup/surface_forcing.py +4 -6
  18. roms_tools/setup/tides.py +1 -2
  19. roms_tools/setup/topography.py +4 -4
  20. roms_tools/setup/utils.py +134 -22
  21. roms_tools/tests/test_analysis/test_cdr_analysis.py +144 -0
  22. roms_tools/tests/test_analysis/test_cdr_ensemble.py +202 -0
  23. roms_tools/tests/test_analysis/test_roms_output.py +61 -3
  24. roms_tools/tests/test_setup/test_boundary_forcing.py +54 -52
  25. roms_tools/tests/test_setup/test_cdr_forcing.py +54 -0
  26. roms_tools/tests/test_setup/test_cdr_release.py +118 -1
  27. roms_tools/tests/test_setup/test_datasets.py +392 -44
  28. roms_tools/tests/test_setup/test_grid.py +222 -115
  29. roms_tools/tests/test_setup/test_initial_conditions.py +94 -41
  30. roms_tools/tests/test_setup/test_surface_forcing.py +2 -1
  31. roms_tools/tests/test_setup/test_utils.py +91 -1
  32. roms_tools/tests/test_setup/utils.py +71 -0
  33. roms_tools/tests/test_tiling/test_join.py +241 -0
  34. roms_tools/tests/test_utils.py +139 -17
  35. roms_tools/tiling/join.py +189 -0
  36. roms_tools/utils.py +131 -99
  37. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/METADATA +12 -2
  38. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/RECORD +41 -33
  39. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/WHEEL +0 -0
  40. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/licenses/LICENSE +0 -0
  41. {roms_tools-3.1.2.dist-info → roms_tools-3.2.0.dist-info}/top_level.txt +0 -0
@@ -3,6 +3,7 @@ from collections import defaultdict
3
3
  from dataclasses import dataclass, field
4
4
  from datetime import datetime
5
5
  from pathlib import Path
6
+ from typing import TypeAlias
6
7
 
7
8
  import matplotlib.pyplot as plt
8
9
  import numpy as np
@@ -18,6 +19,7 @@ from roms_tools.plot import (
18
19
  )
19
20
  from roms_tools.setup.datasets import (
20
21
  DaiRiverDataset,
22
+ RawDataSource,
21
23
  get_indices_of_nearest_grid_cell_for_rivers,
22
24
  )
23
25
  from roms_tools.setup.utils import (
@@ -38,6 +40,8 @@ from roms_tools.utils import save_datasets
38
40
  INCLUDE_ALL_RIVER_NAMES = "all"
39
41
  MAX_RIVERS_TO_PLOT = 20 # must be <= MAX_DISTINCT_COLORS
40
42
 
43
+ TRiverIndex: TypeAlias = dict[tuple[int, int], list[str]]
44
+
41
45
 
42
46
  @dataclass(kw_only=True)
43
47
  class RiverForcing:
@@ -51,7 +55,7 @@ class RiverForcing:
51
55
  Start time of the desired river forcing data.
52
56
  end_time : datetime
53
57
  End time of the desired river forcing data.
54
- source : Dict[str, Union[str, Path, List[Union[str, Path]]], bool], optional
58
+ source : RawDataSource, optional
55
59
  Dictionary specifying the source of the river forcing data. Keys include:
56
60
 
57
61
  - "name" (str): Name of the data source (e.g., "DAI").
@@ -75,7 +79,7 @@ class RiverForcing:
75
79
  Whether to include BGC tracers. Defaults to `False`.
76
80
  model_reference_date : datetime, optional
77
81
  Reference date for the ROMS simulation. Default is January 1, 2000.
78
- indices : dict[str, list[tuple]], optional
82
+ indices : dict[str, list[tuple[int, int]]], optional
79
83
  A dictionary specifying the river indices for each river to be included in the river forcing. This parameter is optional. If not provided,
80
84
  the river indices will be automatically determined based on the grid and the source dataset. If provided, it allows for explicit specification
81
85
  of river locations. The dictionary structure consists of river names as keys, and each value is a list of tuples. Each tuple represents
@@ -101,7 +105,7 @@ class RiverForcing:
101
105
  """Start time of the desired river forcing data."""
102
106
  end_time: datetime
103
107
  """End time of the desired river forcing data."""
104
- source: dict[str, str | Path | list[str | Path]] = None
108
+ source: RawDataSource | None = None
105
109
  """Dictionary specifying the source of the river forcing data."""
106
110
  convert_to_climatology: str = "if_any_missing"
107
111
  """Determines when to compute climatology for river forcing."""
@@ -110,7 +114,7 @@ class RiverForcing:
110
114
  model_reference_date: datetime = datetime(2000, 1, 1)
111
115
  """Reference date for the ROMS simulation."""
112
116
 
113
- indices: dict[str, dict[str, int | list[int]]] | None = None
117
+ indices: dict[str, list[tuple[int, int]]] | None = None
114
118
  """A dictionary of river indices.
115
119
 
116
120
  If not provided during initialization, it will be automatically determined based on
@@ -462,7 +466,7 @@ class RiverForcing:
462
466
 
463
467
  return ds_updated
464
468
 
465
- def _get_overlapping_rivers(self) -> dict[tuple[int, int], list[str]]:
469
+ def _get_overlapping_rivers(self) -> TRiverIndex:
466
470
  """Identify grid cells shared by multiple rivers.
467
471
 
468
472
  Scans through the river indices and finds all grid cell indices
@@ -474,7 +478,10 @@ class RiverForcing:
474
478
  A dictionary mapping grid cell indices (eta_rho, xi_rho) to a list
475
479
  of river names that overlap at that grid cell.
476
480
  """
477
- index_to_rivers = defaultdict(list)
481
+ if self.indices is None:
482
+ return {}
483
+
484
+ index_to_rivers: TRiverIndex = defaultdict(list)
478
485
 
479
486
  # Collect all index pairs used by multiple rivers
480
487
  for river_name, index_list in self.indices.items():
@@ -520,6 +527,9 @@ class RiverForcing:
520
527
  The volume-weighted tracer concentration at the overlapping grid cell,
521
528
  as a new 1-entry DataArray with updated coordinates.
522
529
  """
530
+ if self.indices is None:
531
+ self.indices = {}
532
+
523
533
  new_name = f"overlap_{i}"
524
534
  self.indices[new_name] = [idx_pair]
525
535
 
@@ -578,7 +588,7 @@ class RiverForcing:
578
588
  return combined_river_volume, combined_river_tracer
579
589
 
580
590
  def _reduce_river_volumes(
581
- self, ds: xr.Dataset, overlapping_rivers: dict[tuple[int, int], list[str]]
591
+ self, ds: xr.Dataset, overlapping_rivers: TRiverIndex
582
592
  ) -> xr.Dataset:
583
593
  """Reduce river volumes for rivers contributing to overlapping grid cells.
584
594
 
@@ -595,8 +605,14 @@ class RiverForcing:
595
605
  ds : xr.Dataset
596
606
  Updated dataset with reduced river volumes.
597
607
  """
608
+ if self.indices is None:
609
+ raise ValueError(
610
+ "`self.indices` must be set before calling _reduce_river_volumes"
611
+ )
612
+
598
613
  # Count number of overlaps for each river
599
- river_overlap_count = defaultdict(int)
614
+ river_overlap_count: dict[str, int] = defaultdict(int)
615
+
600
616
  for rivers in overlapping_rivers.values():
601
617
  for name in rivers:
602
618
  river_overlap_count[name] += 1
@@ -671,13 +687,15 @@ class RiverForcing:
671
687
  Warning
672
688
  If NaN values are found in any of the dataset variables, a warning message is logged.
673
689
  """
674
- for var_name in ds.data_vars:
675
- da = ds[var_name]
676
- if da.isnull().any().values:
690
+ var_name = "river_volume"
691
+ da = ds[var_name]
692
+ if da.isnull().any().values:
693
+ logging.warning(
694
+ f"NaNs detected in '{var_name}' and set to zero. This may indicate missing river data and affect model accuracy. "
695
+ )
696
+ if not self.climatology:
677
697
  logging.warning(
678
- f"NaN values detected in the '{var_name}' field. These values are being set to zero. "
679
- "This may indicate missing river data, which could affect model accuracy. Consider setting "
680
- "`convert_to_climatology = 'if_any_missing'` to automatically fill missing values with climatological data."
698
+ "Consider `convert_to_climatology='if_any_missing'` to fill missing values with climatological data."
681
699
  )
682
700
 
683
701
  def plot_locations(self, river_names: list[str] | str = INCLUDE_ALL_RIVER_NAMES):
@@ -691,7 +709,8 @@ class RiverForcing:
691
709
  Defaults to "all".
692
710
 
693
711
  """
694
- valid_river_names = list(self.indices.keys())
712
+ valid_river_names = list(self.indices or [])
713
+
695
714
  river_names = _validate_river_names(river_names, valid_river_names)
696
715
  if len(valid_river_names) > MAX_DISTINCT_COLORS:
697
716
  colors = assign_category_colors(river_names)
@@ -806,7 +825,8 @@ class RiverForcing:
806
825
  Defaults to "all".
807
826
 
808
827
  """
809
- valid_river_names = list(self.indices.keys())
828
+ valid_river_names = list(self.indices or [])
829
+
810
830
  river_names = _validate_river_names(river_names, valid_river_names)
811
831
  if len(valid_river_names) > MAX_DISTINCT_COLORS:
812
832
  colors = assign_category_colors(river_names)
@@ -900,11 +920,17 @@ class RiverForcing:
900
920
  """
901
921
  forcing_dict = to_dict(self, exclude=["climatology"])
902
922
 
903
- # Convert indices format
904
- indices_data = forcing_dict["RiverForcing"]["indices"]
905
- serialized_indices = {"_convention": "eta_rho, xi_rho"}
923
+ indices_data = forcing_dict.get("RiverForcing", {}).get("indices")
924
+ if not indices_data:
925
+ # If no indices, just write the dict as is
926
+ write_to_yaml(forcing_dict, filepath)
927
+ return
928
+
929
+ # Convert tuple indices to string format for YAML
930
+ serialized_indices: dict[str, str | list[str]] = {}
906
931
  for key, value in indices_data.items():
907
932
  serialized_indices[key] = [f"{tup[0]}, {tup[1]}" for tup in value]
933
+ serialized_indices["_convention"] = "eta_rho, xi_rho"
908
934
 
909
935
  # Remove keys starting with "overlap_"
910
936
  filtered_indices = {
@@ -16,6 +16,7 @@ from roms_tools.setup.datasets import (
16
16
  ERA5ARCODataset,
17
17
  ERA5Correction,
18
18
  ERA5Dataset,
19
+ RawDataSource,
19
20
  UnifiedBGCSurfaceDataset,
20
21
  )
21
22
  from roms_tools.setup.utils import (
@@ -56,7 +57,7 @@ class SurfaceForcing:
56
57
  The end time of the desired surface forcing data. This time is used to filter the dataset
57
58
  to include only records on or before this time, with a single record at or after this time.
58
59
  If no time filtering is desired, set it to None. Default is None.
59
- source : Dict[str, Union[str, Path, List[Union[str, Path]]], bool]
60
+ source : RawDataSource
60
61
  Dictionary specifying the source of the surface forcing data. Keys include:
61
62
 
62
63
  - "name" (str): Name of the data source. Currently supported: "ERA5"
@@ -116,7 +117,7 @@ class SurfaceForcing:
116
117
  """The start time of the desired surface forcing data."""
117
118
  end_time: datetime | None = None
118
119
  """The end time of the desired surface forcing data."""
119
- source: dict[str, str | Path | list[str | Path]]
120
+ source: RawDataSource
120
121
  """Dictionary specifying the source of the surface forcing data."""
121
122
  type: str = "physics"
122
123
  """Specifies the type of forcing data ("physics", "bgc")."""
@@ -169,7 +170,6 @@ class SurfaceForcing:
169
170
 
170
171
  data.choose_subdomain(
171
172
  target_coords,
172
- buffer_points=20, # lateral fill needs some buffer from data margin
173
173
  )
174
174
  # Enforce double precision to ensure reproducibility
175
175
  data.convert_to_float64()
@@ -447,9 +447,7 @@ class SurfaceForcing:
447
447
  "lat": data.ds[data.dim_names["latitude"]],
448
448
  "lon": data.ds[data.dim_names["longitude"]],
449
449
  }
450
- correction_data.choose_subdomain(
451
- coords_correction, straddle=self.target_coords["straddle"]
452
- )
450
+ correction_data.match_subdomain(coords_correction)
453
451
  correction_data.ds["mask"] = data.ds["mask"] # use mask from ERA5 data
454
452
  correction_data.ds["time"] = correction_data.ds["time"].dt.days
455
453
 
roms_tools/setup/tides.py CHANGED
@@ -80,7 +80,7 @@ class TidalForcing:
80
80
 
81
81
  grid: Grid
82
82
  """Object representing the grid information."""
83
- source: dict[str, str | Path | list[str | Path]]
83
+ source: dict[str, str | Path | dict[str, str | Path]]
84
84
  """Dictionary specifying the source of the tidal data."""
85
85
  ntides: int = 10
86
86
  """Number of constituents to consider."""
@@ -105,7 +105,6 @@ class TidalForcing:
105
105
  if key != "omega":
106
106
  data.choose_subdomain(
107
107
  target_coords,
108
- buffer_points=20,
109
108
  )
110
109
  # Enforce double precision to ensure reproducibility
111
110
  data.convert_to_float64()
@@ -12,7 +12,7 @@ from roms_tools.setup.datasets import ETOPO5Dataset, SRTM15Dataset
12
12
  from roms_tools.setup.utils import handle_boundaries
13
13
 
14
14
 
15
- def _add_topography(
15
+ def add_topography(
16
16
  ds,
17
17
  target_coords,
18
18
  topography_source,
@@ -241,7 +241,7 @@ def _smooth_topography_locally(h, hmin=5, rmax=0.2):
241
241
  rmax_log = 0.0
242
242
 
243
243
  # Apply hmin threshold
244
- h = _clip_depth(h, hmin)
244
+ h = clip_depth(h, hmin)
245
245
 
246
246
  # Perform logarithmic transformation of the height field
247
247
  h_log = np.log(h / hmin)
@@ -324,7 +324,7 @@ def _smooth_topography_locally(h, hmin=5, rmax=0.2):
324
324
  h = hmin * np.exp(h_log)
325
325
 
326
326
  # Apply hmin threshold again
327
- h = _clip_depth(h, hmin)
327
+ h = clip_depth(h, hmin)
328
328
 
329
329
  # Compute maximum slope parameter r
330
330
  r_eta, r_xi = _compute_rfactor(h)
@@ -335,7 +335,7 @@ def _smooth_topography_locally(h, hmin=5, rmax=0.2):
335
335
  return h
336
336
 
337
337
 
338
- def _clip_depth(h: xr.DataArray, hmin: float) -> xr.DataArray:
338
+ def clip_depth(h: xr.DataArray, hmin: float) -> xr.DataArray:
339
339
  """Ensures that depth values do not fall below a minimum threshold.
340
340
 
341
341
  This function replaces all depth values in `h` that are less than `hmin` with `hmin`,
roms_tools/setup/utils.py CHANGED
@@ -1,11 +1,13 @@
1
1
  import importlib.metadata
2
2
  import logging
3
+ import time
4
+ import typing
3
5
  from collections.abc import Sequence
4
6
  from dataclasses import asdict, fields, is_dataclass
5
7
  from datetime import datetime
6
8
  from enum import StrEnum
7
9
  from pathlib import Path
8
- from typing import Any
10
+ from typing import Any, Literal
9
11
 
10
12
  import cftime
11
13
  import numba as nb
@@ -18,11 +20,53 @@ from pydantic import BaseModel
18
20
  from roms_tools.constants import R_EARTH
19
21
  from roms_tools.utils import interpolate_from_rho_to_u, interpolate_from_rho_to_v
20
22
 
23
+ if typing.TYPE_CHECKING:
24
+ from roms_tools.setup.grid import Grid
25
+
21
26
  yaml.SafeDumper.add_multi_representer(
22
27
  StrEnum,
23
28
  yaml.representer.SafeRepresenter.represent_str,
24
29
  )
25
30
 
31
# Width (in characters) and fill character of the separator line that is
# logged after each timed block.
HEADER_SZ = 96
HEADER_CHAR = "="


def log_the_separator() -> None:
    """Log a separator line using HEADER_CHAR repeated HEADER_SZ times."""
    logging.info(HEADER_CHAR * HEADER_SZ)


class Timed:
    """Context manager to time a block and log messages.

    When ``verbose`` is True, entering the block logs ``message`` (if it is
    non-empty) and leaving it logs the elapsed time followed by a separator
    line. When ``verbose`` is False nothing is logged and nothing is timed.

    Exceptions raised inside the block are never suppressed.
    """

    def __init__(self, message: str = "", verbose: bool = True) -> None:
        """
        Initialize the context manager.

        Parameters
        ----------
        message : str, optional
            A log message printed at the start of the block (default: "").
        verbose : bool, optional
            Whether to log timing information (default: True).
        """
        self.message = message
        self.verbose = verbose
        # Reading of the performance counter taken on entry; stays None
        # until the block is entered with ``verbose=True``.
        self.start: float | None = None

    def __enter__(self) -> "Timed":
        """Start the timer and log the start message, then return ``self``."""
        if self.verbose:
            # perf_counter() is monotonic and high resolution, so the measured
            # duration cannot be distorted by wall-clock adjustments the way
            # a time.time() difference can.
            self.start = time.perf_counter()
            if self.message:
                logging.info(self.message)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Log the elapsed time and a separator line.

        Returns None, so an exception raised inside the block propagates.
        """
        if self.verbose and self.start is not None:
            elapsed = time.perf_counter() - self.start
            logging.info(f"Total time: {elapsed:.3f} seconds")
            log_the_separator()
69
+
26
70
 
27
71
  def nan_check(field, mask, error_message=None) -> None:
28
72
  """Checks for NaN values at wet points in the field.
@@ -437,157 +481,193 @@ def get_variable_metadata():
437
481
  "long_name": "dissolved inorganic phosphate",
438
482
  "units": "mmol/m^3",
439
483
  "flux_units": "mmol/s",
484
+ "integrated_units": "mmol",
440
485
  },
441
486
  "NO3": {
442
487
  "long_name": "dissolved inorganic nitrate",
443
488
  "units": "mmol/m^3",
444
489
  "flux_units": "mmol/s",
490
+ "integrated_units": "mmol",
445
491
  },
446
492
  "SiO3": {
447
493
  "long_name": "dissolved inorganic silicate",
448
494
  "units": "mmol/m^3",
449
495
  "flux_units": "mmol/s",
496
+ "integrated_units": "mmol",
450
497
  },
451
498
  "NH4": {
452
499
  "long_name": "dissolved ammonia",
453
500
  "units": "mmol/m^3",
454
501
  "flux_units": "mmol/s",
502
+ "integrated_units": "mmol",
455
503
  },
456
504
  "Fe": {
457
505
  "long_name": "dissolved inorganic iron",
458
506
  "units": "mmol/m^3",
459
507
  "flux_units": "mmol/s",
508
+ "integrated_units": "mmol",
460
509
  },
461
510
  "Lig": {
462
511
  "long_name": "iron binding ligand",
463
512
  "units": "mmol/m^3",
464
513
  "flux_units": "mmol/s",
514
+ "integrated_units": "mmol",
465
515
  },
466
516
  "O2": {
467
517
  "long_name": "dissolved oxygen",
468
518
  "units": "mmol/m^3",
469
519
  "flux_units": "mmol/s",
520
+ "integrated_units": "mmol",
470
521
  },
471
522
  "DIC": {
472
523
  "long_name": "dissolved inorganic carbon",
473
524
  "units": "mmol/m^3",
474
525
  "flux_units": "mmol/s",
526
+ "integrated_units": "mmol",
475
527
  },
476
528
  "DIC_ALT_CO2": {
477
529
  "long_name": "dissolved inorganic carbon, alternative CO2",
478
530
  "units": "mmol/m^3",
479
531
  "flux_units": "meq/s",
532
+ "integrated_units": "meq",
533
+ },
534
+ "ALK": {
535
+ "long_name": "alkalinity",
536
+ "units": "meq/m^3",
537
+ "flux_units": "meq/s",
538
+ "integrated_units": "meq",
480
539
  },
481
- "ALK": {"long_name": "alkalinity", "units": "meq/m^3", "flux_units": "meq/s"},
482
540
  "ALK_ALT_CO2": {
483
541
  "long_name": "alkalinity, alternative CO2",
484
542
  "units": "meq/m^3",
485
543
  "flux_units": "meq/s",
544
+ "integrated_units": "meq",
486
545
  },
487
546
  "DOC": {
488
547
  "long_name": "dissolved organic carbon",
489
548
  "units": "mmol/m^3",
490
549
  "flux_units": "mmol/s",
550
+ "integrated_units": "mmol",
491
551
  },
492
552
  "DON": {
493
553
  "long_name": "dissolved organic nitrogen",
494
554
  "units": "mmol/m^3",
495
555
  "flux_units": "mmol/s",
556
+ "integrated_units": "mmol",
496
557
  },
497
558
  "DOP": {
498
559
  "long_name": "dissolved organic phosphorus",
499
560
  "units": "mmol/m^3",
500
561
  "flux_units": "mmol/s",
562
+ "integrated_units": "mmol",
501
563
  },
502
564
  "DOCr": {
503
565
  "long_name": "refractory dissolved organic carbon",
504
566
  "units": "mmol/m^3",
505
567
  "flux_units": "mmol/s",
568
+ "integrated_units": "mmol",
506
569
  },
507
570
  "DONr": {
508
571
  "long_name": "refractory dissolved organic nitrogen",
509
572
  "units": "mmol/m^3",
510
573
  "flux_units": "mmol/s",
574
+ "integrated_units": "mmol",
511
575
  },
512
576
  "DOPr": {
513
577
  "long_name": "refractory dissolved organic phosphorus",
514
578
  "units": "mmol/m^3",
515
579
  "flux_units": "mmol/s",
580
+ "integrated_units": "mmol",
516
581
  },
517
582
  "zooC": {
518
583
  "long_name": "zooplankton carbon",
519
584
  "units": "mmol/m^3",
520
585
  "flux_units": "mmol/s",
586
+ "integrated_units": "mmol",
521
587
  },
522
588
  "spChl": {
523
589
  "long_name": "small phytoplankton chlorophyll",
524
590
  "units": "mg/m^3",
525
591
  "flux_units": "mg/s",
592
+ "integrated_units": "mg",
526
593
  },
527
594
  "spC": {
528
595
  "long_name": "small phytoplankton carbon",
529
596
  "units": "mmol/m^3",
530
597
  "flux_units": "mmol/s",
598
+ "integrated_units": "mmol",
531
599
  },
532
600
  "spP": {
533
601
  "long_name": "small phytoplankton phosphorous",
534
602
  "units": "mmol/m^3",
535
603
  "flux_units": "mmol/s",
604
+ "integrated_units": "mmol",
536
605
  },
537
606
  "spFe": {
538
607
  "long_name": "small phytoplankton iron",
539
608
  "units": "mmol/m^3",
540
609
  "flux_units": "mmol/s",
610
+ "integrated_units": "mmol",
541
611
  },
542
612
  "spCaCO3": {
543
613
  "long_name": "small phytoplankton CaCO3",
544
614
  "units": "mmol/m^3",
545
615
  "flux_units": "mmol/s",
616
+ "integrated_units": "mmol",
546
617
  },
547
618
  "diatChl": {
548
619
  "long_name": "diatom chloropyll",
549
620
  "units": "mg/m^3",
550
621
  "flux_units": "mg/s",
622
+ "integrated_units": "mg",
551
623
  },
552
624
  "diatC": {
553
625
  "long_name": "diatom carbon",
554
626
  "units": "mmol/m^3",
555
627
  "flux_units": "mmol/s",
628
+ "integrated_units": "mmol",
556
629
  },
557
630
  "diatP": {
558
631
  "long_name": "diatom phosphorus",
559
632
  "units": "mmol/m^3",
560
633
  "flux_units": "mmol/s",
634
+ "integrated_units": "mmol",
561
635
  },
562
636
  "diatFe": {
563
637
  "long_name": "diatom iron",
564
638
  "units": "mmol/m^3",
565
639
  "flux_units": "mmol/s",
640
+ "integrated_units": "mmol",
566
641
  },
567
642
  "diatSi": {
568
643
  "long_name": "diatom silicate",
569
644
  "units": "mmol/m^3",
570
645
  "flux_units": "mmol/s",
646
+ "integrated_units": "mmol",
571
647
  },
572
648
  "diazChl": {
573
649
  "long_name": "diazotroph chloropyll",
574
650
  "units": "mg/m^3",
575
651
  "flux_units": "mg/s",
652
+ "integrated_units": "mg",
576
653
  },
577
654
  "diazC": {
578
655
  "long_name": "diazotroph carbon",
579
656
  "units": "mmol/m^3",
580
657
  "flux_units": "mmol/s",
658
+ "integrated_units": "mmol",
581
659
  },
582
660
  "diazP": {
583
661
  "long_name": "diazotroph phosphorus",
584
662
  "units": "mmol/m^3",
585
663
  "flux_units": "mmol/s",
664
+ "integrated_units": "mmol",
586
665
  },
587
666
  "diazFe": {
588
667
  "long_name": "diazotroph iron",
589
668
  "units": "mmol/m^3",
590
669
  "flux_units": "mmol/s",
670
+ "integrated_units": "mmol",
591
671
  },
592
672
  "pco2_air": {"long_name": "atmospheric pCO2", "units": "ppmv"},
593
673
  "pco2_air_alt": {
@@ -720,7 +800,10 @@ def compute_missing_surface_bgc_variables(bgc_data):
720
800
  return bgc_data
721
801
 
722
802
 
723
- def get_tracer_metadata_dict(include_bgc=True, with_flux_units=False):
803
+ def get_tracer_metadata_dict(
804
+ include_bgc: bool = True,
805
+ unit_type: Literal["concentration", "flux", "integrated"] = "concentration",
806
+ ):
724
807
  """Generate a dictionary containing metadata for model tracers.
725
808
 
726
809
  The returned dictionary maps tracer names to their associated units and long names.
@@ -732,9 +815,8 @@ def get_tracer_metadata_dict(include_bgc=True, with_flux_units=False):
732
815
  If True (default), includes biogeochemical tracers in the output.
733
816
  If False, returns only physical tracers (e.g., temperature, salinity).
734
817
 
735
- with_flux_units : bool, optional
736
- If True, uses units appropriate for tracer fluxes (e.g., mmol/s).
737
- If False (default), uses units appropriate for tracer concentrations (e.g., mmol/m³).
818
+ unit_type : str
819
+ One of "concentration" (default), "flux", or "integrated".
738
820
 
739
821
  Returns
740
822
  -------
@@ -784,15 +866,19 @@ def get_tracer_metadata_dict(include_bgc=True, with_flux_units=False):
784
866
 
785
867
  metadata = get_variable_metadata()
786
868
 
787
- tracer_dict = {
788
- tracer: {
789
- "units": metadata[tracer]["flux_units"]
790
- if with_flux_units
791
- else metadata[tracer]["units"],
869
+ tracer_dict = {}
870
+ for tracer in tracer_names:
871
+ if unit_type == "flux":
872
+ unit = metadata[tracer]["flux_units"]
873
+ elif unit_type == "integrated":
874
+ unit = metadata[tracer].get("integrated_units", None)
875
+ else: # default to concentration units
876
+ unit = metadata[tracer]["units"]
877
+
878
+ tracer_dict[tracer] = {
879
+ "units": unit,
792
880
  "long_name": metadata[tracer]["long_name"],
793
881
  }
794
- for tracer in tracer_names
795
- }
796
882
 
797
883
  return tracer_dict
798
884
 
@@ -819,7 +905,8 @@ def add_tracer_metadata_to_ds(ds, include_bgc=True, with_flux_units=False):
819
905
  xarray.Dataset
820
906
  The dataset with added tracer metadata.
821
907
  """
822
- tracer_dict = get_tracer_metadata_dict(include_bgc, with_flux_units)
908
+ unit_type = "flux" if with_flux_units else "concentration"
909
+ tracer_dict = get_tracer_metadata_dict(include_bgc, unit_type=unit_type)
823
910
 
824
911
  tracer_names = list(tracer_dict.keys())
825
912
  tracer_units = [tracer_dict[tracer]["units"] for tracer in tracer_names]
@@ -1045,22 +1132,47 @@ def group_by_year(ds, filepath):
1045
1132
  return dataset_list, output_filenames
1046
1133
 
1047
1134
 
1048
- def get_target_coords(grid, use_coarse_grid=False):
1049
- """Retrieves longitude and latitude coordinates from the grid, adjusting them based
1050
- on longitude range.
1135
+ def get_target_coords(
1136
+ grid: "Grid", use_coarse_grid: bool = False
1137
+ ) -> dict[str, xr.DataArray | bool | None]:
1138
+ """
1139
+ Retrieve longitude, latitude, and auxiliary grid coordinates, adjusting for
1140
+ longitude ranges and coarse grid usage.
1051
1141
 
1052
1142
  Parameters
1053
1143
  ----------
1054
1144
  grid : Grid
1055
- Object representing the grid information used for the model.
1145
+ Grid object.
1056
1146
  use_coarse_grid : bool, optional
1057
- Use coarse grid data if True. Defaults to False.
1147
+ If True, use the coarse grid variables (`lat_coarse`, `lon_coarse`, etc.)
1148
+ instead of the native grid. Defaults to False.
1058
1149
 
1059
1150
  Returns
1060
1151
  -------
1061
- dict
1062
- Dictionary containing the longitude, latitude, angle and mask arrays, along with a boolean indicating
1063
- if the grid straddles the meridian.
1152
+ dict[str, xr.DataArray | bool | None]
1153
+ Dictionary containing the following keys:
1154
+
1155
+ - `"lat"` : xr.DataArray
1156
+ Latitude at rho points.
1157
+ - `"lon"` : xr.DataArray
1158
+ Longitude at rho points, adjusted to -180 to 180 or 0 to 360 range.
1159
+ - `"lat_psi"` : xr.DataArray | None
1160
+ Latitude at psi points, if available.
1161
+ - `"lon_psi"` : xr.DataArray | None
1162
+ Longitude at psi points, if available.
1163
+ - `"angle"` : xr.DataArray
1164
+ Grid rotation angle.
1165
+ - `"mask"` : xr.DataArray | None
1166
+ Land/sea mask at rho points.
1167
+ - `"straddle"` : bool
1168
+ True if the grid crosses the Greenwich meridian, False otherwise.
1169
+
1170
+ Notes
1171
+ -----
1172
+ - If `grid.straddle` is False and the ROMS domain lies more than 5° from
1173
+ the Greenwich meridian, longitudes are adjusted to 0-360 range.
1174
+ - Renaming of coarse grid dimensions is applied to match the rho-point
1175
+ naming convention (`eta_rho`, `xi_rho`) for compatibility with ROMS-Tools.
1064
1176
  """
1065
1177
  # Select grid variables based on whether the coarse grid is used
1066
1178
  if use_coarse_grid: