pycontrails-0.52.1-cp311-cp311-win_amd64.whl → pycontrails-0.52.3-cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pycontrails/core/met.py CHANGED
@@ -9,7 +9,15 @@ import pathlib
 import typing
 import warnings
 from abc import ABC, abstractmethod
-from collections.abc import Hashable, Iterable, Iterator, Mapping, MutableMapping, Sequence
+from collections.abc import (
+    Generator,
+    Hashable,
+    Iterable,
+    Iterator,
+    Mapping,
+    MutableMapping,
+    Sequence,
+)
 from contextlib import ExitStack
 from datetime import datetime
 from typing import (
@@ -1502,6 +1510,7 @@ class MetDataArray(MetBase):
         bounds_error: bool = ...,
         fill_value: float | np.float64 | None = ...,
         localize: bool = ...,
+        lowmem: bool = ...,
         indices: interpolation.RGIArtifacts | None = ...,
         return_indices: Literal[False] = ...,
     ) -> npt.NDArray[np.float64]: ...
@@ -1518,6 +1527,7 @@ class MetDataArray(MetBase):
         bounds_error: bool = ...,
         fill_value: float | np.float64 | None = ...,
         localize: bool = ...,
+        lowmem: bool = ...,
         indices: interpolation.RGIArtifacts | None = ...,
         return_indices: Literal[True],
     ) -> tuple[npt.NDArray[np.float64], interpolation.RGIArtifacts]: ...
@@ -1533,6 +1543,7 @@ class MetDataArray(MetBase):
         bounds_error: bool = False,
         fill_value: float | np.float64 | None = np.nan,
         localize: bool = False,
+        lowmem: bool = False,
         indices: interpolation.RGIArtifacts | None = None,
         return_indices: bool = False,
     ) -> npt.NDArray[np.float64] | tuple[npt.NDArray[np.float64], interpolation.RGIArtifacts]:
@@ -1540,7 +1551,9 @@ class MetDataArray(MetBase):

         Zero dimensional coordinates are reshaped to 1D arrays.

-        Method automatically loads underlying :attr:`data` into memory.
+        If ``lowmem == False``, method automatically loads underlying :attr:`data` into
+        memory. Otherwise, method iterates through smaller subsets of :attr:`data` and releases
+        subsets from memory once interpolation against each subset is finished.

         If ``method == "nearest"``, the out array will have the same ``dtype`` as
         the underlying :attr:`data`.
@@ -1586,10 +1599,18 @@ class MetDataArray(MetBase):
         localize: bool, optional
             Experimental. If True, downselect gridded data to smallest bounding box containing
             all points. By default False.
+        lowmem: bool, optional
+            Experimental. If True, iterate through points binned by the time coordinate of the
+            gridded data, and downselect gridded data to the smallest bounding box containing
+            each binned set of points *before loading into memory*. This can significantly reduce
+            memory consumption with large numbers of points at the cost of increased runtime.
+            By default False.
         indices: tuple | None, optional
             Experimental. See :func:`interpolation.interp`. None by default.
         return_indices: bool, optional
             Experimental. See :func:`interpolation.interp`. False by default.
+            Note that values returned differ when ``lowmem=True`` and ``lowmem=False``,
+            so output should only be re-used in calls with the same ``lowmem`` value.

         Returns
         -------
@@ -1632,10 +1653,29 @@ class MetDataArray(MetBase):
         >>> level = np.linspace(200, 300, 10)
         >>> time = pd.date_range("2022-03-01T14", periods=10, freq="5min")
         >>> mda.interpolate(longitude, latitude, level, time)
+        array([220.44347694, 223.08900738, 225.74338924, 228.41642088,
+               231.10858599, 233.54857391, 235.71504913, 237.86478872,
+               239.99274623, 242.10792167])
+
+        >>> # Can easily switch to alternative low-memory implementation
+        >>> mda.interpolate(longitude, latitude, level, time, lowmem=True)
         array([220.44347694, 223.08900738, 225.74338924, 228.41642088,
                231.10858599, 233.54857391, 235.71504913, 237.86478872,
                239.99274623, 242.10792167])
         """
+        if lowmem:
+            return self._interp_lowmem(
+                longitude,
+                latitude,
+                level,
+                time,
+                method=method,
+                bounds_error=bounds_error,
+                fill_value=fill_value,
+                indices=indices,
+                return_indices=return_indices,
+            )
+
         # Load if necessary
         if not self.in_memory:
             self._check_memory("Interpolation over")
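
The ``return_indices`` caveat in the docstring above matters in practice: artifacts returned by a ``lowmem=True`` call may only be fed back into another ``lowmem=True`` call. A sketch in the style of the doctest above, reusing ``mda``, ``longitude``, ``latitude``, ``level``, and ``time`` as defined there:

>>> out, artifacts = mda.interpolate(
...     longitude, latitude, level, time, lowmem=True, return_indices=True
... )
>>> # Reuse the artifacts only with the same lowmem setting
>>> out2 = mda.interpolate(
...     longitude, latitude, level, time, lowmem=True, indices=artifacts
... )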
@@ -1660,6 +1700,100 @@ class MetDataArray(MetBase):
             return_indices=return_indices,
         )

+    def _interp_lowmem(
+        self,
+        longitude: float | npt.NDArray[np.float64],
+        latitude: float | npt.NDArray[np.float64],
+        level: float | npt.NDArray[np.float64],
+        time: np.datetime64 | npt.NDArray[np.datetime64],
+        *,
+        method: str = "linear",
+        bounds_error: bool = False,
+        fill_value: float | np.float64 | None = np.nan,
+        minimize_memory: bool = False,
+        indices: interpolation.RGIArtifacts | None = None,
+        return_indices: bool = False,
+    ) -> npt.NDArray[np.float64] | tuple[npt.NDArray[np.float64], interpolation.RGIArtifacts]:
+        """Interpolate values against underlying DataArray.
+
+        This method is used by :meth:`interpolate` when ``lowmem=True``.
+        Parameters and return types are identical to :meth:`interpolate`, except
+        that the ``localize`` keyword argument is omitted.
+        """
+        # Convert all inputs to 1d arrays
+        # Not validating against ndim >= 2
+        longitude, latitude, level, time = np.atleast_1d(longitude, latitude, level, time)
+
+        if bounds_error:
+            _lowmem_boundscheck(time, self.data)
+
+        # Create buffers for holding interpolation output
+        # Use np.full rather than np.empty so points not covered
+        # by masks are filled with correct out-of-bounds values.
+        out = np.full(longitude.shape, fill_value, dtype=self.data.dtype)
+        if return_indices:
+            rgi_artifacts = interpolation.RGIArtifacts(
+                xi_indices=np.full((4, longitude.size), -1, dtype=np.int64),
+                norm_distances=np.full((4, longitude.size), np.nan, dtype=np.float64),
+                out_of_bounds=np.full((longitude.size,), True, dtype=np.bool_),
+            )
+
+        # Iterate over portions of points between adjacent time steps in gridded data
+        for mask in _lowmem_masks(time, self.data["time"].values):
+            if mask is None or not np.any(mask):
+                continue
+
+            lon_sl = longitude[mask]
+            lat_sl = latitude[mask]
+            lev_sl = level[mask]
+            t_sl = time[mask]
+            if indices is not None:
+                indices_sl = interpolation.RGIArtifacts(
+                    xi_indices=indices.xi_indices[:, mask],
+                    norm_distances=indices.norm_distances[:, mask],
+                    out_of_bounds=indices.out_of_bounds[mask],
+                )
+            else:
+                indices_sl = None
+
+            coords = {"longitude": lon_sl, "latitude": lat_sl, "level": lev_sl, "time": t_sl}
+            if any(np.all(np.isnan(coord)) for coord in coords.values()):
+                continue
+            da = interpolation._localize(self.data, coords)
+            if not da._in_memory:
+                logger.debug(
+                    "Loading %s MB subset of %s into memory.",
+                    round(da.nbytes / 1_000_000, 2),
+                    da.name,
+                )
+                da.load()
+
+            tmp = interpolation.interp(
+                longitude=lon_sl,
+                latitude=lat_sl,
+                level=lev_sl,
+                time=t_sl,
+                da=da,
+                method=method,
+                bounds_error=bounds_error,
+                fill_value=fill_value,
+                localize=False,  # would be no-op; da is localized already
+                indices=indices_sl,
+                return_indices=return_indices,
+            )
+
+            if return_indices:
+                out[mask], rgi_sl = tmp
+                rgi_artifacts.xi_indices[:, mask] = rgi_sl.xi_indices
+                rgi_artifacts.norm_distances[:, mask] = rgi_sl.norm_distances
+                rgi_artifacts.out_of_bounds[mask] = rgi_sl.out_of_bounds
+            else:
+                out[mask] = tmp
+
+        if return_indices:
+            return out, rgi_artifacts
+        return out
+
     def _check_memory(self, msg_start: str) -> None:
         """Check the memory usage of the underlying data.
@@ -2656,3 +2790,46 @@ def _add_vertical_coords(data: XArrayType) -> XArrayType:
     data.coords["altitude"] = data.coords["altitude"].astype(dtype, copy=False)

     return data
+
+
+def _lowmem_boundscheck(time: npt.NDArray[np.datetime64], da: xr.DataArray) -> None:
+    """Extra bounds check required with low-memory interpolation strategy.
+
+    Because the main loop in `_interp_lowmem` processes points between time steps
+    in gridded data, it will never encounter points that are out-of-bounds in time
+    and may fail to produce requested out-of-bounds errors.
+    """
+    da_time = da["time"].to_numpy()
+    if not np.all((time >= da_time.min()) & (time <= da_time.max())):
+        axis = da.get_axis_num("time")
+        msg = f"One of the requested xi is out of bounds in dimension {axis}"
+        raise ValueError(msg)
+
+
+def _lowmem_masks(
+    time: npt.NDArray[np.datetime64], t_met: npt.NDArray[np.datetime64]
+) -> Generator[npt.NDArray[np.bool_], None, None]:
+    """Generate sequence of masks for low-memory interpolation."""
+    t_met_max = t_met.max()
+    t_met_min = t_met.min()
+    inbounds = (time >= t_met_min) & (time <= t_met_max)
+    if not np.any(inbounds):
+        return
+
+    earliest = np.nanmin(time)
+    istart = 0 if earliest < t_met_min else np.flatnonzero(t_met <= earliest).max()
+    latest = np.nanmax(time)
+    iend = t_met.size - 1 if latest > t_met_max else np.flatnonzero(t_met >= latest).min()
+    if istart == iend:
+        yield inbounds
+        return
+
+    # Sequence of masks covers elements in time in the interval [t_met[istart], t_met[iend]].
+    # The first iteration masks elements in the interval [t_met[istart], t_met[istart+1]]
+    # (inclusive of both endpoints).
+    # Subsequent iterations mask elements in the interval (t_met[i], t_met[i+1]]
+    # (inclusive of right endpoint only).
+    for i in range(istart, iend):
+        mask = ((time >= t_met[i]) if i == istart else (time > t_met[i])) & (time <= t_met[i + 1])
+        if np.any(mask):
+            yield mask
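
Worked by hand from the mask logic above, a small example of how ``_lowmem_masks`` bins points between adjacent met time steps (the helper is private, so calling it directly is purely illustrative):

>>> import numpy as np
>>> t_met = np.array(["2022-03-01T00", "2022-03-01T01", "2022-03-01T02"], dtype="datetime64[ns]")
>>> time = np.array(["2022-03-01T00:30", "2022-03-01T01:30"], dtype="datetime64[ns]")
>>> [m.tolist() for m in _lowmem_masks(time, t_met)]
[[True, False], [False, True]]

The first mask covers points in ``[t_met[0], t_met[1]]``; the second covers the half-open interval ``(t_met[1], t_met[2]]``, matching the inline comments.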
pycontrails/core/models.py CHANGED
@@ -362,6 +362,8 @@ class Model(ABC):
     def interp_kwargs(self) -> dict[str, Any]:
         """Shortcut to create interpolation arguments from :attr:`params`.

+        The output of this property is useful for passing to :func:`interpolate_met`.
+
         Returns
         -------
         dict[str, Any]
@@ -376,13 +378,14 @@ class Model(ABC):

             as determined by :attr:`params`.
         """
+        params = self.params
         return {
-            "method": self.params["interpolation_method"],
-            "bounds_error": self.params["interpolation_bounds_error"],
-            "fill_value": self.params["interpolation_fill_value"],
-            "localize": self.params["interpolation_localize"],
-            "use_indices": self.params["interpolation_use_indices"],
-            "q_method": self.params["interpolation_q_method"],
+            "method": params["interpolation_method"],
+            "bounds_error": params["interpolation_bounds_error"],
+            "fill_value": params["interpolation_fill_value"],
+            "localize": params["interpolation_localize"],
+            "use_indices": params["interpolation_use_indices"],
+            "q_method": params["interpolation_q_method"],
         }

     def require_met(self) -> MetDataset:
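
As the new docstring line notes, the returned dictionary is designed to be splatted into :func:`interpolate_met`. A sketch, where ``model``, ``met``, and ``vector`` are hypothetical stand-ins for a concrete :class:`Model`, :class:`MetDataset`, and :class:`GeoVectorDataset`:

>>> interp_kwargs = model.interp_kwargs
>>> interpolate_met(met, vector, "air_temperature", **interp_kwargs)

This mirrors how :class:`Cocip` consumes :attr:`interp_kwargs` in the hunks below.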
@@ -585,16 +588,7 @@ class Model(ABC):
         KeyError
             Variable not found in :attr:`source` or :attr:`met`.
         """
-        variables: Sequence[MetVariable | tuple[MetVariable, ...]]
-        if variable is None:
-            if optional:
-                variables = (*self.met_variables, *self.optional_met_variables)
-            else:
-                variables = self.met_variables
-        elif isinstance(variable, MetVariable):
-            variables = (variable,)
-        else:
-            variables = variable
+        variables = self._determine_relevant_variables(optional, variable)

         q_method = self.params["interpolation_q_method"]
@@ -640,6 +634,20 @@ class Model(ABC):
             met_key, da, self.source, self.params, q_method
         )

+    def _determine_relevant_variables(
+        self,
+        optional: bool,
+        variable: MetVariable | Sequence[MetVariable] | None,
+    ) -> Sequence[MetVariable | tuple[MetVariable, ...]]:
+        """Determine the relevant variables used in :meth:`set_source_met`."""
+        if variable is None:
+            if optional:
+                return (*self.met_variables, *self.optional_met_variables)
+            return self.met_variables
+        if isinstance(variable, MetVariable):
+            return (variable,)
+        return variable
+
     # Following python implementation
     # https://github.com/python/cpython/blob/618b7a8260bb40290d6551f24885931077309590/Lib/collections/__init__.py#L231
     __marker = object()
@@ -814,6 +822,7 @@ def interpolate_met(
     vector: GeoVectorDataset,
     met_key: str,
     vector_key: str | None = None,
+    *,
     q_method: str | None = None,
     **interp_kwargs: Any,
 ) -> npt.NDArray[np.float64]:
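
The bare ``*`` added above makes ``q_method`` (and everything passed through ``**interp_kwargs``) keyword-only; only ``vector_key`` may still be passed positionally. A sketch of the change for callers, with the same hypothetical stand-ins as above:

>>> interpolate_met(met, vector, "eastward_wind", "u_wind", q_method=None)  # still valid
>>> interpolate_met(met, vector, "eastward_wind", "u_wind", None)  # now raises TypeError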
pycontrails/core/vector.py CHANGED
@@ -657,7 +657,7 @@ class VectorDataset:
         8   15  18

        """
-        vectors = [v for v in vectors if v]  # remove empty vectors
+        vectors = [v for v in vectors if v is not None]  # remove None values

        if not vectors:
            return cls()
@@ -707,36 +707,33 @@ class VectorDataset:
         bool
             True if both instances have identical :attr:`data` and :attr:`attrs`.
         """
-        if isinstance(other, VectorDataset):
-            # assert attrs equal
-            for key in self.attrs:
-                if isinstance(self.attrs[key], np.ndarray):
-                    # equal_nan not supported for non-numeric data
-                    equal_nan = not np.issubdtype(self.attrs[key].dtype, "O")
-                    try:
-                        eq = np.array_equal(self.attrs[key], other.attrs[key], equal_nan=equal_nan)
-                    except KeyError:
-                        return False
-                else:
-                    eq = self.attrs[key] == other.attrs[key]
-
-                if not eq:
-                    return False
+        if not isinstance(other, VectorDataset):
+            return False

-            # assert data equal
-            for key in self:
-                # equal_nan not supported for non-numeric data (e.g. strings)
-                equal_nan = not np.issubdtype(self[key].dtype, "O")
-                try:
-                    eq = np.array_equal(self[key], other[key], equal_nan=equal_nan)
-                except KeyError:
-                    return False
+        # Check attrs
+        if self.attrs.keys() != other.attrs.keys():
+            return False

-                if not eq:
+        for key, val in self.attrs.items():
+            if isinstance(val, np.ndarray):
+                # equal_nan not supported for non-numeric data
+                equal_nan = not np.issubdtype(val.dtype, "O")
+                if not np.array_equal(val, other.attrs[key], equal_nan=equal_nan):
                     return False
+            elif val != other.attrs[key]:
+                return False
+
+        # Check data
+        if self.data.keys() != other.data.keys():
+            return False

-            return True
-        return False
+        for key, val in self.data.items():
+            # equal_nan not supported for non-numeric data (e.g. strings)
+            equal_nan = not np.issubdtype(val.dtype, "O")
+            if not np.array_equal(val, other[key], equal_nan=equal_nan):
+                return False
+
+        return True

     @property
     def size(self) -> int:
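
A sketch of the tightened ``__eq__`` semantics, assuming the usual dict-of-arrays constructor. The up-front ``keys()`` comparisons now catch a dataset whose keys are a strict superset of the other's, which the old per-key loop could miss:

>>> a = VectorDataset({"x": np.array([1.0, np.nan])})
>>> b = VectorDataset({"x": np.array([1.0, np.nan])})
>>> a == b  # NaN values compare equal via np.array_equal(..., equal_nan=True)
True
>>> c = VectorDataset({"x": np.array([1.0, np.nan]), "y": np.array([0.0, 0.0])})
>>> a == c  # the extra key in c is now detected
False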
@@ -986,7 +983,7 @@ class VectorDataset:
         numeric_attrs = (
             attr
             for attr, val in self.attrs.items()
-            if (isinstance(val, (int, float)) and attr not in ignore_keys)
+            if (isinstance(val, (int, float, np.number)) and attr not in ignore_keys)
         )
         self.broadcast_attrs(numeric_attrs, overwrite)
pycontrails/models/cocip/cocip.py CHANGED
@@ -148,7 +148,6 @@ class Cocip(Model):

     This implementation is regression tested against
     results from :cite:`teohAviationContrailClimate2022`.
-    See `tests/benchmark/north-atlantic-study/validate.py`.

     **Outputs**
@@ -549,6 +548,8 @@ class Cocip(Model):
         verbose_outputs = self.params["verbose_outputs"]

         interp_kwargs = self.interp_kwargs
+        if self.params["preprocess_lowmem"]:
+            interp_kwargs["lowmem"] = True
         interpolate_met(met, self.source, "air_temperature", **interp_kwargs)
         interpolate_met(met, self.source, "specific_humidity", **interp_kwargs)
         interpolate_met(met, self.source, "eastward_wind", "u_wind", **interp_kwargs)
@@ -750,6 +751,8 @@ class Cocip(Model):

         # get full met grid or flight data interpolated to the pressure level `p_dz`
         interp_kwargs = self.interp_kwargs
+        if self.params["preprocess_lowmem"]:
+            interp_kwargs["lowmem"] = True
         air_temperature_lower = interpolate_met(
             met,
             self._sac_flight,
@@ -861,6 +864,8 @@ class Cocip(Model):

         # get met post wake vortex along initial contrail
         interp_kwargs = self.interp_kwargs
+        if self.params["preprocess_lowmem"]:
+            interp_kwargs["lowmem"] = True
         air_temperature_1 = interpolate_met(met, contrail_1, "air_temperature", **interp_kwargs)
         interpolate_met(met, contrail_1, "specific_humidity", **interp_kwargs)
@@ -952,11 +957,14 @@ class Cocip(Model):
             )
             logger.debug("None are filtered out!")

-    def _simulate_contrail_evolution(self) -> None:
-        """Simulate contrail evolution."""
-        # Calculate all properties for "downwash_contrail" which is
-        # a contrail representation of the waypoints of the downwash flight.
-        # The downwash_contrail has already been filtered for initial persistent waypoints.
+    def _process_downwash_flight(self) -> tuple[MetDataset | None, MetDataset | None]:
+        """Create and calculate properties of contrails created by the downwash vortex.
+
+        ``_downwash_contrail`` is a contrail representation of the waypoints of
+        ``_downwash_flight``, which has already been filtered for initial persistent waypoints.
+
+        Returns MetDatasets for subsequent use if ``preprocess_lowmem=False``.
+        """
         self._downwash_contrail = self._create_downwash_contrail()
         buffers = {
             f"{coord}_buffer": self.params[f"met_{coord}_buffer"]
@@ -971,6 +979,8 @@ class Cocip(Model):
         calc_timestep_geometry(self._downwash_contrail)

         interp_kwargs = self.interp_kwargs
+        if self.params["preprocess_lowmem"]:
+            interp_kwargs["lowmem"] = True
         calc_timestep_meteorology(self._downwash_contrail, met, self.params, **interp_kwargs)
         calc_shortwave_radiation(rad, self._downwash_contrail, **interp_kwargs)
         calc_outgoing_longwave_radiation(rad, self._downwash_contrail, **interp_kwargs)
@@ -985,6 +995,16 @@ class Cocip(Model):
         # Intersect with rad dataset
         calc_radiative_properties(self._downwash_contrail, self.params)

+        if self.params["preprocess_lowmem"]:
+            return None, None
+        return met, rad
+
+    def _simulate_contrail_evolution(self) -> None:
+        """Simulate contrail evolution."""
+
+        met, rad = self._process_downwash_flight()
+        interp_kwargs = self.interp_kwargs
+
         contrail_contrail_overlapping = self.params["contrail_contrail_overlapping"]
         if contrail_contrail_overlapping and not isinstance(self.source, Fleet):
             warnings.warn("Contrail-Contrail Overlapping is only valid for Fleet mode.")
@@ -1022,22 +1042,7 @@ class Cocip(Model):
                 continue

             # Update met, rad slices as needed
-            # We need to both interpolate latest_contrail, as well as the "contrail_2"
-            # created by calc_timestep_contrail_evolution. This "contrail_2" object
-            # has constant time at "time_end", hence the buffer we apply below.
-            # After the downwash_contrails is all used up, these updates are intended
-            # to happen once each hour
-            buffers["time_buffer"] = (
-                np.timedelta64(0, "ns"),
-                time_end - latest_contrail["time"].max(),
-            )
-            if time_end > met.indexes["time"].to_numpy()[-1]:
-                logger.debug("Downselect met at time_end %s within Cocip evolution", time_end)
-                met = latest_contrail.downselect_met(self.met, **buffers, copy=False)
-                met = add_tau_cirrus(met)
-            if time_end > rad.indexes["time"].to_numpy()[-1]:
-                logger.debug("Downselect rad at time_end %s within Cocip evolution", time_end)
-                rad = latest_contrail.downselect_met(self.rad, **buffers, copy=False)
+            met, rad = self._maybe_downselect_met_rad(met, rad, latest_contrail, time_end)

             # Recalculate latest_contrail with new values
             # NOTE: We are doing a substantial amount of redundant computation here
@@ -1075,6 +1080,75 @@ class Cocip(Model):

         self.contrail_list.append(final_contrail)

+    def _maybe_downselect_met_rad(
+        self,
+        met: MetDataset | None,
+        rad: MetDataset | None,
+        latest_contrail: GeoVectorDataset,
+        time_end: np.datetime64,
+    ) -> tuple[MetDataset, MetDataset]:
+        """Downselect ``self.met`` and ``self.rad`` if necessary to cover ``time_end``.
+
+        If the current ``met`` and ``rad`` slices do not include ``time_end``, new slices
+        are selected from ``self.met`` and ``self.rad``. Downselection in space will cover
+
+        - locations of current contrails (``latest_contrail``),
+        - locations of additional contrails that will be loaded from ``self._downwash_flight``
+          before the new slices expire,
+
+        plus a user-defined buffer.
+        """
+        if met is None or time_end > met.indexes["time"].to_numpy()[-1]:
+            logger.debug("Downselect met at time_end %s within Cocip evolution", time_end)
+            met = self._definitely_downselect_met_or_rad(self.met, latest_contrail, time_end)
+            met = add_tau_cirrus(met)
+
+        if rad is None or time_end > rad.indexes["time"].to_numpy()[-1]:
+            logger.debug("Downselect rad at time_end %s within Cocip evolution", time_end)
+            rad = self._definitely_downselect_met_or_rad(self.rad, latest_contrail, time_end)
+
+        return met, rad
+
+    def _definitely_downselect_met_or_rad(
+        self, met: MetDataset, latest_contrail: GeoVectorDataset, time_end: np.datetime64
+    ) -> MetDataset:
+        """Perform downselection when required by :meth:`_maybe_downselect_met_rad`.
+
+        Downselects ``met`` (which should be one of ``self.met`` or ``self.rad``)
+        to cover ``time_end``. Downselection in space covers
+
+        - locations of current contrails (``latest_contrail``),
+        - locations of additional contrails that will be loaded from ``self._downwash_flight``
+          before the new slices expire,
+
+        plus a user-defined buffer, as described in :meth:`_maybe_downselect_met_rad`.
+        """
+        # compute lookahead for future contrails from downwash_flight
+        met_time = met.indexes["time"].to_numpy()
+        mask = met_time >= time_end
+        lookahead = np.min(met_time[mask]) if np.any(mask) else time_end
+
+        # create vector for downselection based on current + future contrails
+        future_contrails = self._downwash_flight.filter(
+            (self._downwash_flight["time"] >= time_end)
+            & (self._downwash_flight["time"] <= lookahead),
+            copy=False,
+        )
+        vector = GeoVectorDataset(
+            {
+                key: np.concat((latest_contrail[key], future_contrails[key]))
+                for key in ("longitude", "latitude", "level", "time")
+            }
+        )
+
+        # compute time buffer to ensure downselection extends to time_end
+        buffers = {
+            f"{coord}_buffer": self.params[f"met_{coord}_buffer"]
+            for coord in ("longitude", "latitude", "level")
+        }
+        buffers["time_buffer"] = (
+            np.timedelta64(0, "ns"),
+            max(np.timedelta64(0, "ns"), time_end - vector["time"].max()),
+        )
+
+        return vector.downselect_met(met, **buffers, copy=False)
+
     def _create_downwash_contrail(self) -> GeoVectorDataset:
         """Get Contrail representation of downwash flight."""
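
The lookahead computation above is the subtle piece: it snaps ``time_end`` forward to the next met time step so that future downwash waypoints arriving before that step are included in the downselection. A minimal sketch of the same arithmetic in isolation, with illustrative values:

>>> met_time = np.array(["2022-03-01T00", "2022-03-01T01"], dtype="datetime64[ns]")
>>> time_end = np.datetime64("2022-03-01T00:30")
>>> mask = met_time >= time_end
>>> lookahead = np.min(met_time[mask]) if np.any(mask) else time_end
>>> str(lookahead)
'2022-03-01T01:00:00.000000000'

The ``max(np.timedelta64(0, "ns"), ...)`` guard on the time buffer then keeps the buffer non-negative even when ``vector["time"].max()`` already exceeds ``time_end``.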
 
@@ -1166,49 +1240,43 @@ class Cocip(Model):
         # ---
         # Create contrail dataframe (self.contrail)
         # ---
-        dfs = [contrail.dataframe for contrail in self.contrail_list]
-        dfs = [df.assign(timestep=t_idx) for t_idx, df in enumerate(dfs)]
-        self.contrail = pd.concat(dfs)
+        self.contrail = GeoVectorDataset.sum(self.contrail_list).dataframe
+        self.contrail["timestep"] = np.concatenate(
+            [np.full(c.size, i) for i, c in enumerate(self.contrail_list)]
+        )

         # add age in hours to the contrail waypoint outputs
         age_hours = np.empty_like(self.contrail["ef"])
         np.divide(self.contrail["age"], np.timedelta64(1, "h"), out=age_hours)
         self.contrail["age_hours"] = age_hours

-        if self.params["verbose_outputs"]:
+        verbose_outputs = self.params["verbose_outputs"]
+        if verbose_outputs:
             # Compute dt_integration -- logic is somewhat complicated, but
             # we're simply addressing that the first dt_integration
             # is different from the rest

-            # We call reset_index twice. The first call introduces an `index`
-            # column, and the second introduces a `level_0` column. This `level_0`
-            # is a RangeIndex, which we use in the `groupby` to identify the
+            # We call reset_index to introduce an `index` RangeIndex column,
+            # which we use in the `groupby` to identify the
             # index of the first evolution step at each waypoint.
-            # The `level_0` is used to insert back into the `seq_index` dataframe,
-            # then it is dropped in replace of the original `index`.
-            seq_index = self.contrail.reset_index().reset_index()
-            cols = ["formation_time", "time", "level_0"]
-            first_form_time = seq_index.groupby("waypoint")[cols].first()
+            tmp = self.contrail.reset_index()
+            cols = ["formation_time", "time", "index"]
+            first_form_time = tmp.groupby("waypoint")[cols].first()
             first_dt = first_form_time["time"] - first_form_time["formation_time"]
-            first_dt.index = first_form_time["level_0"]
+            first_dt = first_dt.set_axis(first_form_time["index"])

-            seq_index = seq_index.set_index("level_0")
-            seq_index["dt_integration"] = first_dt
-            seq_index.fillna({"dt_integration": self.params["dt_integration"]}, inplace=True)
-
-            self.contrail = seq_index.set_index("index")
+            self.contrail = tmp.set_index("index")
+            self.contrail["dt_integration"] = first_dt
+            self.contrail.fillna({"dt_integration": self.params["dt_integration"]}, inplace=True)

         # ---
         # Create contrail xr.Dataset (self.contrail_dataset)
         # ---
         if isinstance(self.source, Fleet):
-            self.contrail_dataset = xr.Dataset.from_dataframe(
-                self.contrail.set_index(["flight_id", "timestep", "waypoint"])
-            )
+            keys = ["flight_id", "timestep", "waypoint"]
         else:
-            self.contrail_dataset = xr.Dataset.from_dataframe(
-                self.contrail.set_index(["timestep", "waypoint"])
-            )
+            keys = ["timestep", "waypoint"]
+        self.contrail_dataset = xr.Dataset.from_dataframe(self.contrail.set_index(keys))

         # ---
         # Create output Flight / Fleet (self.source)
@@ -1229,7 +1297,7 @@ class Cocip(Model):
         ]

         # add additional columns
-        if self.params["verbose_outputs"]:
+        if verbose_outputs:
             sac_cols += ["dT_dz", "ds_dz", "dz_max"]

         downwash_cols = ["rho_air_1", "iwc_1", "n_ice_per_m_1"]
@@ -1253,7 +1321,7 @@ class Cocip(Model):

         rad_keys = ["sdr", "rsr", "olr", "rf_sw", "rf_lw", "rf_net"]
         for key in rad_keys:
-            if self.params["verbose_outputs"]:
+            if verbose_outputs:
                 agg_dict[key] = ["mean", "min", "max"]
             else:
                 agg_dict[key] = ["mean"]