roms-tools 3.1.1__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. roms_tools/__init__.py +8 -1
  2. roms_tools/analysis/cdr_analysis.py +203 -0
  3. roms_tools/analysis/cdr_ensemble.py +198 -0
  4. roms_tools/analysis/roms_output.py +80 -46
  5. roms_tools/data/grids/GLORYS_global_grid.nc +0 -0
  6. roms_tools/download.py +4 -0
  7. roms_tools/plot.py +131 -30
  8. roms_tools/regrid.py +6 -1
  9. roms_tools/setup/boundary_forcing.py +94 -44
  10. roms_tools/setup/cdr_forcing.py +123 -15
  11. roms_tools/setup/cdr_release.py +161 -8
  12. roms_tools/setup/datasets.py +709 -341
  13. roms_tools/setup/grid.py +167 -139
  14. roms_tools/setup/initial_conditions.py +113 -48
  15. roms_tools/setup/mask.py +63 -7
  16. roms_tools/setup/nesting.py +67 -42
  17. roms_tools/setup/river_forcing.py +45 -19
  18. roms_tools/setup/surface_forcing.py +16 -10
  19. roms_tools/setup/tides.py +1 -2
  20. roms_tools/setup/topography.py +4 -4
  21. roms_tools/setup/utils.py +134 -22
  22. roms_tools/tests/test_analysis/test_cdr_analysis.py +144 -0
  23. roms_tools/tests/test_analysis/test_cdr_ensemble.py +202 -0
  24. roms_tools/tests/test_analysis/test_roms_output.py +61 -3
  25. roms_tools/tests/test_setup/test_boundary_forcing.py +111 -52
  26. roms_tools/tests/test_setup/test_cdr_forcing.py +54 -0
  27. roms_tools/tests/test_setup/test_cdr_release.py +118 -1
  28. roms_tools/tests/test_setup/test_datasets.py +458 -34
  29. roms_tools/tests/test_setup/test_grid.py +238 -121
  30. roms_tools/tests/test_setup/test_initial_conditions.py +94 -41
  31. roms_tools/tests/test_setup/test_surface_forcing.py +28 -3
  32. roms_tools/tests/test_setup/test_utils.py +91 -1
  33. roms_tools/tests/test_setup/test_validation.py +21 -15
  34. roms_tools/tests/test_setup/utils.py +71 -0
  35. roms_tools/tests/test_tiling/test_join.py +241 -0
  36. roms_tools/tests/test_tiling/test_partition.py +45 -0
  37. roms_tools/tests/test_utils.py +224 -2
  38. roms_tools/tiling/join.py +189 -0
  39. roms_tools/tiling/partition.py +44 -30
  40. roms_tools/utils.py +488 -161
  41. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/METADATA +15 -4
  42. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/RECORD +45 -37
  43. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/WHEEL +0 -0
  44. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/licenses/LICENSE +0 -0
  45. {roms_tools-3.1.1.dist-info → roms_tools-3.2.0.dist-info}/top_level.txt +0 -0
@@ -2,19 +2,19 @@ import itertools
2
2
  import logging
3
3
  from collections import Counter
4
4
  from collections.abc import Iterator
5
- from datetime import datetime
5
+ from datetime import datetime, timedelta
6
6
  from pathlib import Path
7
7
  from typing import Annotated
8
8
 
9
9
  import matplotlib.gridspec as gridspec
10
10
  import matplotlib.pyplot as plt
11
11
  import numpy as np
12
+ import pandas as pd
12
13
  import xarray as xr
13
14
  from pydantic import (
14
15
  BaseModel,
15
16
  Field,
16
17
  RootModel,
17
- conlist,
18
18
  model_serializer,
19
19
  model_validator,
20
20
  )
@@ -40,6 +40,7 @@ from roms_tools.setup.utils import (
40
40
  from_yaml,
41
41
  gc_dist,
42
42
  get_target_coords,
43
+ get_tracer_metadata_dict,
43
44
  to_dict,
44
45
  validate_names,
45
46
  write_to_yaml,
@@ -103,14 +104,16 @@ class ReleaseSimulationManager(BaseModel):
103
104
  class ReleaseCollector(RootModel):
104
105
  """Collects and validates multiple releases against each other."""
105
106
 
106
- root: conlist(
107
- Annotated[
108
- VolumeRelease | TracerPerturbation, Field(discriminator="release_type")
107
+ root: Annotated[
108
+ list[
109
+ Annotated[
110
+ VolumeRelease | TracerPerturbation, Field(discriminator="release_type")
111
+ ]
109
112
  ],
110
- min_length=1,
111
- ) = Field(alias="releases")
113
+ Field(alias="releases", min_length=1),
114
+ ]
112
115
 
113
- _release_type: ReleaseType = None
116
+ _release_type: ReleaseType | None = None
114
117
 
115
118
  def __iter__(self) -> Iterator[Release]:
116
119
  return iter(self.root)
@@ -126,6 +129,9 @@ class ReleaseCollector(RootModel):
126
129
  else:
127
130
  raise TypeError(f"Invalid key type: {type(item)}. Must be int or str.")
128
131
 
132
def __len__(self):
    """Return the number of releases held by the collector."""
    releases = self.root
    return len(releases)
134
+
129
135
  @model_validator(mode="before")
130
136
  @classmethod
131
137
  def unpack_dict(cls, data):
@@ -774,6 +780,62 @@ class CDRForcing(BaseModel):
774
780
  fig.subplots_adjust(hspace=0.45)
775
781
  fig.suptitle(f"Release distribution for: {release_name}")
776
782
 
783
def compute_total_cdr_source(self, dt: float) -> pd.DataFrame:
    """
    Integrate the tracer source of every release over the ROMS time steps.

    Parameters
    ----------
    dt : float
        Time step in seconds used to rebuild the ROMS time stamps between
        ``self.start_time`` and ``self.end_time``.

    Returns
    -------
    pd.DataFrame
        A table whose first row (index ``"units"``) lists the integrated units
        of each tracer, followed by one row of integrated totals per release,
        indexed by release name. The ``temp`` and ``salt`` columns are left
        out. ``attrs`` carries ``time_step``, ``start_time``, ``end_time`` and
        ``title``.
    """
    # Only the elapsed ROMS times (seconds since the model reference date)
    # are needed; the datetime list is discarded.
    _, seconds_since_reference = _reconstruct_roms_time_stamps(
        self.start_time, self.end_time, dt, self.model_reference_date
    )

    totals = []
    labels = []
    for position, release in enumerate(self.releases, start=1):
        totals.append(
            release._do_accounting(seconds_since_reference, self.model_reference_date)
        )
        # Fall back to a 1-based positional label if a release has no name.
        labels.append(getattr(release, "name", f"release_{position}"))

    # One row per release, one column per tracer.
    per_release = pd.DataFrame(totals, index=labels)

    # temp and salt are excluded from both the units row and the totals.
    kept_columns = [
        column for column in per_release.columns if column not in ("temp", "salt")
    ]

    integrated_meta = get_tracer_metadata_dict(include_bgc=True, unit_type="integrated")
    units = pd.DataFrame(
        [
            {
                column: integrated_meta.get(column, {}).get("units", "")
                for column in kept_columns
            }
        ],
        index=["units"],
    )

    # Units row first, then the per-release totals.
    table = pd.concat([units, per_release[kept_columns]])

    # Record the inputs of the accounting alongside the table.
    table.attrs["time_step"] = dt
    table.attrs["start_time"] = self.start_time
    table.attrs["end_time"] = self.end_time
    table.attrs["title"] = "Integrated tracer releases"

    return table
838
+
777
839
  def save(
778
840
  self,
779
841
  filepath: str | Path,
@@ -1039,21 +1101,67 @@ def _map_3d_gaussian(
1039
1101
  # Stack 2D distribution at that vertical level
1040
1102
  distribution_3d[{"s_rho": vertical_idx}] = distribution_2d
1041
1103
  else:
1042
- # Compute layer thickness
1043
- depth_interface = compute_depth_coordinates(
1044
- grid.ds, zeta=0, depth_type="interface", location="rho"
1045
- )
1046
- dz = depth_interface.diff("s_w").rename({"s_w": "s_rho"})
1047
-
1048
1104
  # Compute vertical Gaussian shape
1049
1105
  exponent = -(((depth - release.depth) / release.vsc) ** 2)
1050
1106
  vertical_profile = np.exp(exponent)
1051
1107
 
1052
1108
  # Apply vertical Gaussian scaling
1053
- distribution_3d = distribution_2d * vertical_profile * dz
1109
+ distribution_3d = distribution_2d * vertical_profile
1054
1110
 
1055
1111
  # Normalize
1056
1112
  distribution_3d /= release.vsc * np.sqrt(np.pi)
1057
1113
  distribution_3d /= distribution_3d.sum()
1058
1114
 
1059
1115
  return distribution_3d
1116
+
1117
+
1118
def _reconstruct_roms_time_stamps(
    start_time: datetime,
    end_time: datetime,
    dt: float,
    model_reference_date: datetime,
) -> tuple[list[datetime], np.ndarray]:
    """
    Reconstruct ROMS time stamps between `start_time` and `end_time` with step `dt`.

    The i-th time stamp is ``start_time + i * dt``. Stamps are generated as
    long as they do not exceed `end_time`.

    Parameters
    ----------
    start_time : datetime
        Beginning of the time series.
    end_time : datetime
        End of the time series (inclusive if it falls exactly on a step).
    dt : float
        Time step in seconds (can be fractional if needed).
    model_reference_date : datetime
        The reference date for ROMS time (elapsed time will be relative to this).

    Returns
    -------
    times : list of datetime
        Sequence ``start_time, start_time + dt, ...``. It ends at `end_time`
        if `end_time` falls exactly on a step, otherwise at the last step before it.
    rel_seconds : np.ndarray
        Elapsed time of each entry of `times`, in **seconds** relative to
        `model_reference_date`.

    Raises
    ------
    ValueError
        If `end_time` is not after `start_time` or if `dt` is not positive.
    """
    if end_time <= start_time:
        raise ValueError("end_time must be after start_time")
    if dt <= 0:
        raise ValueError("dt must be positive")

    # Build each stamp as start_time + i * dt instead of by repeated addition.
    # timedelta has microsecond resolution, so a fractional dt is rounded, and
    # repeated addition would accumulate that rounding error step after step.
    times: list[datetime] = []
    step = 0
    t = start_time
    while t <= end_time:
        times.append(t)
        step += 1
        t = start_time + timedelta(seconds=step * dt)

    # convert_to_relative_days gives days since model_reference_date; the
    # accounting code works in seconds, so scale by seconds per day.
    rel_days = convert_to_relative_days(times, model_reference_date)
    rel_seconds = rel_days * 3600 * 24

    return times, rel_seconds
@@ -5,6 +5,8 @@ from datetime import datetime
5
5
  from enum import StrEnum, auto
6
6
  from typing import Annotated, Literal
7
7
 
8
+ import numpy as np
9
+ import pandas as pd
8
10
  from annotated_types import Ge, Le
9
11
  from pydantic import (
10
12
  BaseModel,
@@ -16,10 +18,17 @@ from pydantic import (
16
18
  )
17
19
  from pydantic_core.core_schema import ValidationInfo
18
20
 
19
- from roms_tools.setup.utils import get_tracer_defaults, get_tracer_metadata_dict
21
+ from roms_tools.setup.utils import (
22
+ convert_to_relative_days,
23
+ get_tracer_defaults,
24
+ get_tracer_metadata_dict,
25
+ )
20
26
 
21
27
  NonNegativeFloat = Annotated[float, Ge(0)]
22
28
 
29
# Show every column when a DataFrame is printed.
# NOTE(review): this changes pandas' global display settings for the whole
# process as soon as this module is imported; a scoped ``pd.option_context``
# at the display site would avoid surprising library users.
pd.options.display.max_columns = None
31
+
23
32
 
24
33
  @dataclass
25
34
  class ValueArray(ABC):
@@ -272,10 +281,68 @@ class Release(BaseModel):
272
281
  if self.times[-1] < end_time:
273
282
  self.times.append(end_time)
274
283
 
275
@classmethod
def get_tracer_metadata(cls):
    """Return tracer metadata for this release type; the base class defines none."""
    return dict()
278
287
 
288
@classmethod
def get_metadata(cls):
    """Return this release type's tracer metadata as a pandas DataFrame."""
    metadata = cls.get_tracer_metadata()
    return pd.DataFrame(metadata)
291
+
292
def _compute_integrated_tracers(
    self,
    roms_time_stamps: np.ndarray,
    model_reference_date: datetime,
    tracer_series_dict: dict[str, np.ndarray],
) -> dict[str, float]:
    """
    Compute time-integrated tracer quantities over ROMS time steps using a left-hold rule.

    Each tracer series is given at the release schedule times (`self.times`).
    It is linearly interpolated onto `roms_time_stamps`. The interpolated
    value at the start of each ROMS interval is then multiplied by that
    interval's duration, and the products are summed.

    Parameters
    ----------
    roms_time_stamps : np.ndarray
        1D array of ROMS model time stamps in seconds since `model_reference_date`.
        Must be strictly increasing and contain at least two entries.
    model_reference_date : datetime
        Reference datetime of the ROMS model calendar. Used to express
        `self.times` in the same units as `roms_time_stamps`.
    tracer_series_dict : dict[str, np.ndarray]
        Mapping of tracer names to flux values at the release schedule times.
        Each value is either a 1D array with one entry per element of
        `self.times`, or a scalar / length-1 array, which is treated as
        constant in time.

    Returns
    -------
    dict[str, float]
        Integrated quantity per tracer. The interpolated value at the last
        ROMS time stamp is not used, because that stamp only closes the final
        interval.

    Raises
    ------
    ValueError
        If `roms_time_stamps` has fewer than two entries, or if a tracer
        series cannot be matched to the length of `self.times`.
    """
    if len(roms_time_stamps) < 2:
        raise ValueError("Need at least two ROMS time stamps to define intervals.")

    roms_time_stamps = np.asarray(roms_time_stamps, dtype=float)
    interval_lengths = np.diff(roms_time_stamps)

    # The release schedule in seconds since the reference date does not
    # depend on the tracer, so compute it once instead of once per tracer.
    release_seconds = (
        np.asarray(
            convert_to_relative_days(self.times, model_reference_date), dtype=float
        )
        * 3600
        * 24
    )

    results = {}
    for tracer, series in tracer_series_dict.items():
        # np.interp requires fp to have the same length as xp, so a constant
        # (scalar) series is broadcast over the schedule. Genuinely mismatched
        # lengths still raise ValueError.
        values = np.broadcast_to(np.asarray(series, dtype=float), release_seconds.shape)
        interp_values = np.interp(roms_time_stamps, release_seconds, values)
        # Left-hold: the value at each interval's start applies to the whole interval.
        results[tracer] = np.sum(interp_values[:-1] * interval_lengths)
    return results
345
+
279
346
 
280
347
  class VolumeRelease(Release):
281
348
  """Represents a CDR release with volume flux and tracer concentrations.
@@ -389,9 +456,11 @@ class VolumeRelease(Release):
389
456
  num_times = len(self.times)
390
457
 
391
458
  for tracer_concentrations in self.tracer_concentrations.values():
392
- tracer_concentrations.check_length(num_times)
459
+ if isinstance(tracer_concentrations, Concentration):
460
+ tracer_concentrations.check_length(num_times)
393
461
 
394
- self.volume_fluxes.check_length(num_times)
462
+ if isinstance(self.volume_fluxes, Flux):
463
+ self.volume_fluxes.check_length(num_times)
395
464
 
396
465
  return self
397
466
 
@@ -410,7 +479,52 @@ class VolumeRelease(Release):
410
479
@staticmethod
def get_tracer_metadata():
    """Return long names and expected concentration units for the tracers, BGC tracers included."""
    metadata = get_tracer_metadata_dict(unit_type="concentration", include_bgc=True)
    return metadata
483
+
484
def _do_accounting(
    self,
    roms_time_stamps: np.ndarray,
    model_reference_date: datetime,
) -> dict[str, float]:
    """
    Integrate the tracer source of this volume release over the ROMS time steps.

    The flux of each tracer is the volume flux multiplied by the tracer
    concentration, both taken at the release schedule times. These flux series
    are handed to `_compute_integrated_tracers`, which interpolates them onto
    the ROMS time stamps. It then applies a "left-hold" rule: the value at t₀
    is held over the whole interval [t₀, t₁).

    Parameters
    ----------
    roms_time_stamps : np.ndarray
        Strictly increasing 1D array of ROMS time stamps, in seconds since
        `model_reference_date`.
    model_reference_date : datetime
        Reference date of the ROMS model calendar.

    Returns
    -------
    dict[str, float]
        Total integrated quantity per tracer over the ROMS time period.
    """

    def _raw_values(obj, value_type):
        # Unwrap value containers (Flux / Concentration); pass other values through.
        return np.asarray(obj.values if isinstance(obj, value_type) else obj)

    volume = _raw_values(self.volume_fluxes, Flux)
    fluxes = {
        name: volume * _raw_values(concentration, Concentration)
        for name, concentration in self.tracer_concentrations.items()
    }
    return self._compute_integrated_tracers(
        roms_time_stamps, model_reference_date, fluxes
    )
414
528
 
415
529
  @model_serializer(mode="wrap")
416
530
  def _simplified_dump(self, pydantic_serializer) -> dict:
@@ -503,7 +617,8 @@ class TracerPerturbation(Release):
503
617
def _check_tracer_flux_lengths(self):
    """Run ``check_length`` against the number of release times for each ``Flux``-typed tracer flux."""
    expected = len(self.times)
    # Values that are not Flux instances are not length-checked here.
    flux_objects = (f for f in self.tracer_fluxes.values() if isinstance(f, Flux))
    for flux in flux_objects:
        flux.check_length(expected)
    return self
508
623
 
509
624
  def _extend_to_endpoints(self, start_time, end_time):
@@ -520,7 +635,45 @@ class TracerPerturbation(Release):
520
635
@staticmethod
def get_tracer_metadata():
    """Return long names and expected flux units for the tracers, BGC tracers included."""
    metadata = get_tracer_metadata_dict(unit_type="flux", include_bgc=True)
    return metadata
639
+
640
def _do_accounting(
    self,
    roms_time_stamps: np.ndarray,
    model_reference_date: datetime,
) -> dict[str, float]:
    """
    Compute time-integrated tracer quantities over ROMS time steps.

    The tracer flux time series from the CDR schedule are interpolated onto
    the provided ROMS time stamps, given in **seconds** since the model
    reference date. A "left-hold" rule is then applied: the interpolated
    value at t₀ is applied across the full interval [t₀, t₁).

    Parameters
    ----------
    roms_time_stamps : np.ndarray
        1D array of ROMS time stamps in seconds since `model_reference_date`.
        This is the unit `_compute_integrated_tracers` expects. Must be
        strictly increasing.
    model_reference_date : datetime
        Reference date of the ROMS model calendar.

    Returns
    -------
    dict[str, float]
        Dictionary mapping tracer names to the total integrated quantity over
        the entire ROMS time period. Each value is the sum of the interpolated
        tracer fluxes multiplied by the corresponding ROMS time step durations.
    """
    # Unwrap Flux containers to their raw values; other values pass through.
    tracer_series_dict = {
        tracer: np.asarray(flux.values if isinstance(flux, Flux) else flux)
        for tracer, flux in self.tracer_fluxes.items()
    }
    return self._compute_integrated_tracers(
        roms_time_stamps, model_reference_date, tracer_series_dict
    )
524
677
 
525
678
  @model_serializer(mode="wrap")
526
679
  def _simplified_dump(self, pydantic_serializer) -> dict: