pycontrails-0.47.3-cp310-cp310-win_amd64.whl → pycontrails-0.48.1-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pycontrails might be problematic. Click here for more details.

pycontrails/core/met.py CHANGED
@@ -96,13 +96,14 @@ class MetBase(ABC, Generic[XArrayType]):
96
96
  for dim in self.dim_order:
97
97
  if dim not in self.data.dims:
98
98
  if dim == "level":
99
- raise ValueError(
99
+ msg = (
100
100
  f"Meteorology data must contain dimension '{dim}'. "
101
101
  "For single level data, set 'level' coordinate to constant -1 "
102
102
  "using `ds = ds.expand_dims({'level': [-1]})`"
103
103
  )
104
104
  else:
105
- raise ValueError(f"Meteorology data must contain dimension '{dim}'.")
105
+ msg = f"Meteorology data must contain dimension '{dim}'."
106
+ raise ValueError(msg)
106
107
 
107
108
  def _validate_longitude(self) -> None:
108
109
  """Check longitude bounds.
@@ -196,15 +197,15 @@ class MetBase(ABC, Generic[XArrayType]):
196
197
 
197
198
  dims_tuple = tuple(self.dim_order)
198
199
 
199
- def _check_da(da: xr.DataArray, key: str | None = None) -> None:
200
+ def _check_da(da: xr.DataArray, key: Hashable | None = None) -> None:
200
201
  if da.dims != dims_tuple:
201
202
  if key is not None:
202
203
  msg = (
203
- "Data dimension not transposed on variable '{key}'. Initiate with"
204
- " `copy=True`."
204
+ f"Data dimension not transposed on variable '{key}'. Initiate with"
205
+ " 'copy=True'."
205
206
  )
206
207
  else:
207
- msg = "Data dimension not transposed. Initiate with `copy=True`."
208
+ msg = "Data dimension not transposed. Initiate with 'copy=True'."
208
209
  raise ValueError(msg)
209
210
 
210
211
  data = self.data
@@ -212,7 +213,7 @@ class MetBase(ABC, Generic[XArrayType]):
212
213
  _check_da(data)
213
214
  return
214
215
 
215
- for key, da in self.data.data_vars.items():
216
+ for key, da in self.data.items():
216
217
  _check_da(da, key)
217
218
 
218
219
  def _validate_dims(self) -> None:
@@ -533,7 +534,7 @@ class MetBase(ABC, Generic[XArrayType]):
533
534
 
534
535
  @property
535
536
  def attrs(self) -> dict[Hashable, Any]:
536
- """Pass through to `self.data.attrs`."""
537
+ """Pass through to :attr:`self.data.attrs`."""
537
538
  return self.data.attrs
538
539
 
539
540
  @abstractmethod
@@ -593,6 +594,16 @@ class MetDataset(MetBase):
593
594
  interval ``[-180, 180]``. Defaults to False.
594
595
  copy : bool, optional
595
596
  Copy data on construction. Defaults to True.
597
+ attrs : dict[str, Any], optional
598
+ Attributes to add to :attr:`data.attrs`. Defaults to None.
599
+ Generally, pycontrails :class:`Models` may use the following attributes:
600
+
601
+ - ``provider``: Name of the data provider (e.g. "ECMWF").
602
+ - ``dataset``: Name of the dataset (e.g. "ERA5").
603
+ - ``product``: Name of the product type (e.g. "reanalysis").
604
+
605
+ **attrs_kwargs : Any
606
+ Keyword arguments to add to :attr:`data.attrs`. Defaults to None.
596
607
 
597
608
  Examples
598
609
  --------
@@ -641,10 +652,13 @@ class MetDataset(MetBase):
641
652
  cachestore: CacheStore | None = None,
642
653
  wrap_longitude: bool = False,
643
654
  copy: bool = True,
644
- ):
645
- # init cache
655
+ attrs: dict[str, Any] | None = None,
656
+ **attrs_kwargs: Any,
657
+ ) -> None:
646
658
  self.cachestore = cachestore
647
659
 
660
+ data.attrs.update(attrs or {}, **attrs_kwargs)
661
+
648
662
  # if input is already a Dataset, copy into data
649
663
  if not isinstance(data, xr.Dataset):
650
664
  raise ValueError("Input `data` must be an xarray Dataset")
@@ -661,7 +675,7 @@ class MetDataset(MetBase):
661
675
  self._validate_dims()
662
676
 
663
677
  def __getitem__(self, key: Hashable) -> MetDataArray:
664
- """Return DataArray of variable `key` cast to a `MetDataArray` object.
678
+ """Return DataArray of variable ``key`` cast to a :class:`MetDataArray` object.
665
679
 
666
680
  Parameters
667
681
  ----------
@@ -676,7 +690,7 @@ class MetDataset(MetBase):
676
690
  Raises
677
691
  ------
678
692
  KeyError
679
- If `key` not found in :attr:`data`
693
+ If ``key`` not found in :attr:`data`
680
694
  """
681
695
  try:
682
696
  da = self.data[key]
@@ -789,7 +803,7 @@ class MetDataset(MetBase):
789
803
  yield str(key)
790
804
 
791
805
  def __contains__(self, key: Hashable) -> bool:
792
- """Check if key `key` is in :attr:`data`.
806
+ """Check if key ``key`` is in :attr:`data`.
793
807
 
794
808
  Parameters
795
809
  ----------
@@ -799,7 +813,7 @@ class MetDataset(MetBase):
799
813
  Returns
800
814
  -------
801
815
  bool
802
- True if `key` is in :attr:`data`, False otherwise
816
+ True if ``key`` is in :attr:`data`, False otherwise
803
817
  """
804
818
  return key in self.data
805
819
 
@@ -857,7 +871,7 @@ class MetDataset(MetBase):
857
871
  Raises when dataset does not contain variable in ``vars``
858
872
  """
859
873
  if isinstance(vars, (MetVariable, str)):
860
- vars = [vars]
874
+ vars = (vars,)
861
875
 
862
876
  met_keys: list[str] = []
863
877
  for variable in vars:
@@ -1001,7 +1015,7 @@ class MetDataset(MetBase):
1001
1015
  time [2022-03-01 00:00:00, 2022-03-01 01:00:00]
1002
1016
  longitude [-180.0, 179.75]
1003
1017
  latitude [-90.0, 90.0]
1004
- altitude [10362.848672411146, 11783.938524404566]
1018
+ altitude [10362.8, 11783.9]
1005
1019
  crs EPSG:4326
1006
1020
 
1007
1021
  """
@@ -1022,6 +1036,98 @@ class MetDataset(MetBase):
1022
1036
  vector.attrs.update({str(k): v for k, v in self.attrs.items()})
1023
1037
  return vector
1024
1038
 
1039
+ def _get_pycontrails_attr_template(
1040
+ self,
1041
+ name: str,
1042
+ supported: tuple[str, ...],
1043
+ examples: dict[str, str],
1044
+ ) -> str:
1045
+ """Look up an attribute with a custom error message."""
1046
+ try:
1047
+ out = self.attrs[name]
1048
+ except KeyError as e:
1049
+ msg = f"Specify '{name}' attribute on underlying dataset."
1050
+
1051
+ for i, (k, v) in enumerate(examples.items()):
1052
+ if i == 0:
1053
+ msg = f"{msg} For example, set attrs['{name}'] = '{k}' for {v}."
1054
+ else:
1055
+ msg = f"{msg} Set attrs['{name}'] = '{k}' for {v}."
1056
+ raise KeyError(msg) from e
1057
+
1058
+ if out not in supported:
1059
+ warnings.warn(
1060
+ f"Unknown {name} '{out}'. Data may not be processed correctly. "
1061
+ f"Known {name}s are {supported}. Contact the pycontrails "
1062
+ "developers if you believe this is an error."
1063
+ )
1064
+
1065
+ return out
1066
+
1067
+ @property
1068
+ def provider_attr(self) -> str:
1069
+ """Look up the 'provider' attribute with a custom error message.
1070
+
1071
+ Returns
1072
+ -------
1073
+ str
1074
+ Provider of the data. If not one of 'ECMWF' or 'NCEP',
1075
+ a warning is issued.
1076
+ """
1077
+ supported = ("ECMWF", "NCEP")
1078
+ examples = {"ECMWF": "data provided by ECMWF", "NCEP": "GFS data"}
1079
+ return self._get_pycontrails_attr_template("provider", supported, examples)
1080
+
1081
+ @property
1082
+ def dataset_attr(self) -> str:
1083
+ """Look up the 'dataset' attribute with a custom error message.
1084
+
1085
+ Returns
1086
+ -------
1087
+ str
1088
+ Dataset of the data. If not one of 'ERA5', 'HRES', 'IFS',
1089
+ or 'GFS', a warning is issued.
1090
+ """
1091
+ supported = ("ERA5", "HRES", "IFS", "GFS")
1092
+ examples = {
1093
+ "ERA5": "ECMWF ERA5 reanalysis data",
1094
+ "HRES": "ECMWF HRES forecast data",
1095
+ "GFS": "NCEP GFS forecast data",
1096
+ }
1097
+ return self._get_pycontrails_attr_template("dataset", supported, examples)
1098
+
1099
+ @property
1100
+ def product_attr(self) -> str:
1101
+ """Look up the 'product' attribute with a custom error message.
1102
+
1103
+ Returns
1104
+ -------
1105
+ str
1106
+ Product of the data. If not one of 'forecast', 'ensemble', or 'reanalysis',
1107
+ a warning is issued.
1108
+
1109
+ """
1110
+ supported = ("reanalysis", "forecast", "ensemble")
1111
+ examples = {
1112
+ "reanalysis": "ECMWF ERA5 reanalysis data",
1113
+ "ensemble": "ECMWF ERA5 ensemble member data",
1114
+ }
1115
+ return self._get_pycontrails_attr_template("product", supported, examples)
1116
+
1117
+ def standardize_variables(self, variables: Iterable[MetVariable]) -> None:
1118
+ """Standardize variables **in-place**.
1119
+
1120
+ Parameters
1121
+ ----------
1122
+ variables : Iterable[MetVariable]
1123
+ Data source variables
1124
+
1125
+ See Also
1126
+ --------
1127
+ :func:`standardize_variables`
1128
+ """
1129
+ standardize_variables(self, variables)
1130
+
1025
1131
  @classmethod
1026
1132
  def from_coords(
1027
1133
  cls,
@@ -1177,9 +1283,10 @@ class MetDataArray(MetBase):
1177
1283
  in the case that `copy=True`. Validation only introduces a very small overhead.
1178
1284
  This parameter should only be set to `False` if working with data derived from an
1179
1285
  existing MetDataset or :class`MetDataArray`. By default `True`.
1286
+ name : Hashable, optional
1287
+ Name of the data variable. If not specified, the name will be set to "met".
1180
1288
  **kwargs
1181
- If `data` input is not a xr.DataArray, `data` will be passed
1182
- passed directly to xr.DataArray constructor with these keyword arguments.
1289
+ To be removed in future versions. Passed directly to xr.DataArray constructor.
1183
1290
 
1184
1291
  Examples
1185
1292
  --------
@@ -1218,6 +1325,7 @@ class MetDataArray(MetBase):
1218
1325
  wrap_longitude: bool = False,
1219
1326
  copy: bool = True,
1220
1327
  validate: bool = True,
1328
+ name: Hashable | None = None,
1221
1329
  **kwargs: Any,
1222
1330
  ) -> None:
1223
1331
  # init cache
@@ -1225,6 +1333,10 @@ class MetDataArray(MetBase):
1225
1333
 
1226
1334
  # try to create DataArray out of input data and **kwargs
1227
1335
  if not isinstance(data, xr.DataArray):
1336
+ DeprecationWarning(
1337
+ "Input 'data' must be an xarray DataArray. "
1338
+ "Passing arbitrary kwargs will be removed in future versions."
1339
+ )
1228
1340
  data = xr.DataArray(data, **kwargs)
1229
1341
 
1230
1342
  if copy:
@@ -1237,13 +1349,9 @@ class MetDataArray(MetBase):
1237
1349
  if validate:
1238
1350
  self._validate_dims()
1239
1351
 
1240
- # if name is specified in kwargs, overwrite any other name in DataArray
1241
- if "name" in kwargs:
1242
- self.data.name = kwargs["name"]
1243
-
1244
- # if at this point now "name" exists on DataArray, set default
1245
- if self.data.name is None:
1246
- self.data.name = "met"
1352
+ # Priority: name > data.name > "met"
1353
+ name = name or self.data.name or "met"
1354
+ self.data.name = name
1247
1355
 
1248
1356
  @property
1249
1357
  def values(self) -> np.ndarray:
@@ -1263,7 +1371,7 @@ class MetDataArray(MetBase):
1263
1371
  """
1264
1372
  if not self.in_memory:
1265
1373
  self._check_memory("Extracting numpy array from")
1266
- self.data = self.data.load()
1374
+ self.data.load()
1267
1375
 
1268
1376
  return self.data.values
1269
1377
 
@@ -1509,8 +1617,20 @@ class MetDataArray(MetBase):
1509
1617
  )
1510
1618
 
1511
1619
  def _check_memory(self, msg_start: str) -> None:
1620
+ """Check the memory usage of the underlying data.
1621
+
1622
+ If the data is larger than 4 GB, a warning is issued. If the data is
1623
+ larger than 32 GB, a RuntimeError is raised.
1624
+ """
1512
1625
  n_bytes = self.data.nbytes
1626
+ mb = round(n_bytes / int(1e6), 2)
1627
+ logger.debug("Loading %s into memory consumes %s MB.", self.name, mb)
1628
+
1513
1629
  n_gb = n_bytes // int(1e9)
1630
+ if n_gb <= 4:
1631
+ return
1632
+
1633
+ # Prevent something stupid
1514
1634
  msg = (
1515
1635
  f"{msg_start} MetDataArray {self.name} requires loading "
1516
1636
  f"at least {n_gb} GB of data into memory. Downselect data if possible. "
@@ -1518,13 +1638,9 @@ class MetDataArray(MetBase):
1518
1638
  "with the method 'downselect_met'."
1519
1639
  )
1520
1640
 
1521
- if n_gb > 32: # Prevent something stupid
1641
+ if n_gb > 32:
1522
1642
  raise RuntimeError(msg)
1523
- if n_gb > 4:
1524
- warnings.warn(msg)
1525
-
1526
- mb = round(n_bytes / int(1e6), 2)
1527
- logger.debug("Loading %s into memory consumes %s MB.", self.name, mb)
1643
+ warnings.warn(msg)
1528
1644
 
1529
1645
  def save(self, **kwargs: Any) -> list[str]:
1530
1646
  """Save intermediate to :attr:`cachestore` as netcdf.
@@ -2119,24 +2235,29 @@ def _is_zarr(ds: xr.Dataset | xr.DataArray) -> bool:
2119
2235
  return dask0.array.array.array.__class__.__name__ == "ZarrArrayWrapper"
2120
2236
 
2121
2237
 
2122
- def shift_longitude(data: XArrayType) -> XArrayType:
2123
- """Shift longitude values from [0, 360) to [-180, 180) domain.
2238
+ def shift_longitude(data: XArrayType, bound: float = -180.0) -> XArrayType:
2239
+ """Shift longitude values from any input domain to [bound, 360 + bound) domain.
2124
2240
 
2125
2241
  Sorts data by ascending longitude values.
2126
2242
 
2243
+
2127
2244
  Parameters
2128
2245
  ----------
2129
2246
  data : XArrayType
2130
2247
  :class:`xr.Dataset` or :class:`xr.DataArray` with longitude dimension
2248
+ bound : float, optional
2249
+ Lower bound of the domain.
2250
+ Output domain will be [bound, 360 + bound).
2251
+ Defaults to -180, which results in longitude domain [-180, 180).
2131
2252
 
2132
2253
 
2133
2254
  Returns
2134
2255
  -------
2135
2256
  XArrayType
2136
- :class:`xr.Dataset` or :class:`xr.DataArray` with longitude values on [-180, 180).
2257
+ :class:`xr.Dataset` or :class:`xr.DataArray` with longitude values on [a, 360 + a).
2137
2258
  """
2138
2259
  return data.assign_coords(
2139
- longitude=((data["longitude"].values + 180.0) % 360.0) - 180.0
2260
+ longitude=((data["longitude"].values - bound) % 360.0) + bound
2140
2261
  ).sortby("longitude", ascending=True)
2141
2262
 
2142
2263
 
@@ -2389,9 +2510,12 @@ def originates_from_ecmwf(met: MetDataset | MetDataArray) -> bool:
2389
2510
  - :class:`HRES`
2390
2511
 
2391
2512
  """
2392
- return met.attrs.get("met_source") in ("ERA5", "HRES") or "ecmwf" in met.attrs.get(
2393
- "history", ""
2394
- )
2513
+ if isinstance(met, MetDataset):
2514
+ try:
2515
+ return met.provider_attr == "ECMWF"
2516
+ except KeyError:
2517
+ pass
2518
+ return "ecmwf" in met.attrs.get("history", "")
2395
2519
 
2396
2520
 
2397
2521
  def _load(hash: str, cachestore: CacheStore, chunks: dict[str, int]) -> xr.Dataset:
@@ -2,24 +2,24 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from dataclasses import dataclass
5
+ import dataclasses
6
6
 
7
7
 
8
- @dataclass
8
+ @dataclasses.dataclass(frozen=True)
9
9
  class MetVariable:
10
10
  """Met variable defined using CF, ECMWF, and WMO conventions.
11
11
 
12
12
  When there is a conflict between CF, ECMWF, and WMO conventions,
13
- CF takes precendence, then WMO, then ECMWF.
13
+ CF takes precedence, then WMO, then ECMWF.
14
14
 
15
15
  References
16
16
  ----------
17
17
  - `CF Standard Names, version 77
18
18
  <https://cfconventions.org/Data/cf-standard-names/77/build/cf-standard-name-table.html>`_
19
19
  - `ECMWF Parameter Database <https://apps.ecmwf.int/codes/grib/param-db>`_
20
- - `NCEP Grib v1 Code Table <https://www.nco.ncep.noaa.gov/pmb/docs/on388/table2.html>`
20
+ - `NCEP Grib v1 Code Table <https://www.nco.ncep.noaa.gov/pmb/docs/on388/table2.html>`_
21
21
  - `WMO Codes Registry, Grib Edition 2 <https://codes.wmo.int/_grib2>`_
22
- - `NCEP Grib v2 Code Table <https://www.nco.ncep.noaa.gov/pmb/docs/grib2/grib2_doc/grib2_table4-2.shtml>`
22
+ - `NCEP Grib v2 Code Table <https://www.nco.ncep.noaa.gov/pmb/docs/grib2/grib2_doc/grib2_table4-2.shtml>`_
23
23
 
24
24
  Used for defining support parameters in a grib-like fashion.
25
25
  """ # noqa: E501
@@ -70,7 +70,7 @@ class MetVariable:
70
70
  ValueError
71
71
  If any of the inputs have an unknown :attr:`level_type`.
72
72
  """
73
- level_types = ["surface", "isobaricInPa", "isobaricInhPa", "nominalTop"]
73
+ level_types = ("surface", "isobaricInPa", "isobaricInhPa", "nominalTop")
74
74
  if self.level_type is not None and self.level_type not in level_types:
75
75
  raise ValueError(f"`level_type` must be one of {level_types}")
76
76
 
@@ -102,8 +102,8 @@ class MetVariable:
102
102
  """
103
103
 
104
104
  # return only these keys if they are not None
105
- keys = ["short_name", "standard_name", "long_name", "units"]
106
- return {k: getattr(self, k) for k in keys if getattr(self, k, None) is not None}
105
+ keys = ("short_name", "standard_name", "long_name", "units")
106
+ return {k: v for k in keys if (v := getattr(self, k, None)) is not None}
107
107
 
108
108
 
109
109
  # ----
@@ -135,7 +135,7 @@ Altitude = MetVariable(
135
135
  amip="ta",
136
136
  description=(
137
137
  "Altitude is the (geometric) height above the geoid, which is the "
138
- "reference geopotential surface. The geoid is similar to mean sea level."
138
+ "reference geopotential surface. The geoid is similar to mean sea level."
139
139
  ),
140
140
  )
141
141
 
@@ -692,6 +692,33 @@ class Model(ABC):
692
692
  if self.params["interpolation_use_indices"] and isinstance(self.source, GeoVectorDataset):
693
693
  self.source._invalidate_indices()
694
694
 
695
+ def transfer_met_source_attrs(self, source: SourceType | None = None) -> None:
696
+ """Transfer met source metadata from :attr:`met` to ``source``."""
697
+
698
+ if self.met is None:
699
+ return
700
+
701
+ source = source or self.source
702
+ try:
703
+ source.attrs["met_source_provider"] = self.met.provider_attr
704
+ except KeyError:
705
+ pass
706
+
707
+ try:
708
+ source.attrs["met_source_dataset"] = self.met.dataset_attr
709
+ except KeyError:
710
+ pass
711
+
712
+ try:
713
+ source.attrs["met_source_product"] = self.met.product_attr
714
+ except KeyError:
715
+ pass
716
+
717
+ try:
718
+ source.attrs["met_source_forecast_time"] = self.met.attrs["forecast_time"]
719
+ except KeyError:
720
+ pass
721
+
695
722
 
696
723
  def _interp_grid_to_grid(
697
724
  met_key: str,