pycontrails-0.47.3-cp311-cp311-macosx_11_0_arm64.whl → pycontrails-0.48.1-cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.


@@ -18,7 +18,7 @@ from pycontrails.core import coordinates, interpolation
  from pycontrails.core import met as met_module
  from pycontrails.physics import units
  from pycontrails.utils import dependencies
- from pycontrails.utils import json as json_module
+ from pycontrails.utils import json as json_utils

  logger = logging.getLogger(__name__)

@@ -492,7 +492,7 @@ class VectorDataset:
  return self.size

  def _display_attrs(self) -> dict[str, str]:
- """Return properties used in `repr` constructions`.
+ """Return properties used in `repr` constructions.

  Returns
  -------
@@ -515,7 +515,7 @@ class VectorDataset:
  n_keys = len(self.data)
  _repr = f"{class_name} [{n_keys} keys x {self.size} length, {n_attrs} attributes]"

- keys = list(self.data.keys())
+ keys = list(self)
  keys = keys[0:5] + ["..."] + keys[-1:] if len(keys) > 5 else keys
  _repr += f"\n\tKeys: {', '.join(keys)}"

@@ -633,6 +633,8 @@ class VectorDataset:
  8 15 18

  """
+ vectors = [v for v in vectors if v] # remove empty vectors
+
  if not vectors:
  return cls()

@@ -753,7 +755,7 @@ class VectorDataset:
  str
  Unique hash for flight instance (sha1)
  """
- _hash = json.dumps(self.data, cls=json_module.NumpyEncoder)
+ _hash = json.dumps(self.data, cls=json_utils.NumpyEncoder)
  return hashlib.sha1(bytes(_hash, "utf-8")).hexdigest()

  # ------------
@@ -951,11 +953,11 @@ class VectorDataset:
  ignore_keys = (ignore_keys,)

  # Somewhat brittle: Only checking for int or float type
- numeric_attrs = [
+ numeric_attrs = (
  attr
  for attr, val in self.attrs.items()
  if (isinstance(val, (int, float)) and attr not in ignore_keys)
- ]
+ )
  self.broadcast_attrs(numeric_attrs, overwrite)

  # ------------
@@ -982,6 +984,104 @@ class VectorDataset:
  df.attrs = self.attrs
  return df

+ def to_dict(self) -> dict[str, Any]:
+ """Create dictionary with :attr:`data` and :attr:`attrs`.
+
+ If geo-spatial coordinates (e.g. `"latitude"`, `"longitude"`, `"altitude"`)
+ are present, round to a reasonable precision. If a `"time"` variable is present,
+ round to unix seconds. When the instance is a :class:`GeoVectorDataset`,
+ disregard any `"altitude"` or `"level"` coordinate and only include
+ `"altitude_ft"`in the output.
+
+ Returns
+ -------
+ dict[str, Any]
+ Dictionary with :attr:`data` and :attr:`attrs`.
+
+ See Also
+ --------
+ :meth:`from_dict`
+
+ Examples
+ --------
+ >>> import pprint
+ >>> from pycontrails import Flight
+ >>> fl = Flight(
+ ... longitude=[-100, -110],
+ ... latitude=[40, 50],
+ ... level=[200, 200],
+ ... time=[np.datetime64("2020-01-01T09"), np.datetime64("2020-01-01T09:30")],
+ ... aircraft_type="B737",
+ ... )
+ >>> fl = fl.resample_and_fill("5T")
+ >>> pprint.pprint(fl.to_dict())
+ {'aircraft_type': 'B737',
+ 'altitude_ft': [38661.0, 38661.0, 38661.0, 38661.0, 38661.0, 38661.0, 38661.0],
+ 'crs': 'EPSG:4326',
+ 'latitude': [40.0, 41.724, 43.428, 45.111, 46.769, 48.399, 50.0],
+ 'longitude': [-100.0,
+ -101.441,
+ -102.959,
+ -104.563,
+ -106.267,
+ -108.076,
+ -110.0],
+ 'time': [1577869200,
+ 1577869500,
+ 1577869800,
+ 1577870100,
+ 1577870400,
+ 1577870700,
+ 1577871000]}
+ """
+ np_encoder = json_utils.NumpyEncoder()
+
+ # round latitude, longitude, and altitude
+ precision = {"longitude": 3, "latitude": 3, "altitude_ft": 0}
+
+ def encode(key: str, obj: Any) -> Any:
+ # Try to handle some pandas objects
+ if hasattr(obj, "to_numpy"):
+ obj = obj.to_numpy()
+
+ # Convert numpy objects to python objects
+ if isinstance(obj, (np.ndarray, np.generic)):
+
+ # round time to unix seconds
+ if key == "time":
+ return np_encoder.default(obj.astype("datetime64[s]").astype(int))
+
+ # round specific keys in precision
+ try:
+ d = precision[key]
+ except KeyError:
+ return np_encoder.default(obj)
+
+ return np_encoder.default(obj.astype(float).round(d))
+
+ # Pass through everything else
+ return obj
+
+ data = {k: encode(k, v) for k, v in self.data.items()}
+ attrs = {k: encode(k, v) for k, v in self.attrs.items()}
+
+ # Only include one of the vertical coordinate keys
+ if isinstance(self, GeoVectorDataset):
+ data.pop("altitude", None)
+ data.pop("level", None)
+ if "altitude_ft" not in data:
+ data["altitude_ft"] = self.altitude_ft.round(precision["altitude_ft"]).tolist()
+
+ # Issue warning if any keys are duplicated
+ common_keys = data.keys() & attrs.keys()
+ if common_keys:
+ warnings.warn(
+ f"Found duplicate keys in data and attrs: {common_keys}. "
+ "Data keys will overwrite attrs keys in returned dictionary."
+ )
+
+ return {**attrs, **data}
+
  @classmethod
  def create_empty(
  cls: Type[VectorDatasetType],
@@ -1010,6 +1110,42 @@ class VectorDataset:
  """
  return cls(data=_empty_vector_dict(keys or set()), attrs=attrs, copy=False, **attrs_kwargs)

+ @classmethod
+ def from_dict(
+ cls: Type[VectorDatasetType], obj: dict[str, Any], copy: bool = True, **obj_kwargs: Any
+ ) -> VectorDatasetType:
+ """Create instance from dict representation containing data and attrs.
+
+ Parameters
+ ----------
+ obj : dict[str, Any]
+ Dict representation of VectorDataset (e.g. :meth:`to_dict`)
+ copy : bool, optional
+ Passed to VectorDataset constructor.
+ Defaults to True.
+ **obj_kwargs : Any
+ Additional properties passed as keyword arguments.
+
+ Returns
+ -------
+ VectorDatasetType
+ VectorDataset instance.
+
+ See Also
+ --------
+ :meth:`to_dict`
+ """
+ data = {}
+ attrs = {}
+
+ for k, v in {**obj, **obj_kwargs}.items():
+ if isinstance(v, (list, np.ndarray)):
+ data[k] = v
+ else:
+ attrs[k] = v
+
+ return cls(data=data, attrs=attrs, copy=copy)
+
  def generate_splits(
  self: VectorDatasetType, n_splits: int, copy: bool = True
  ) -> Generator[VectorDatasetType, None, None]:
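
The two hunks above add a serialization round trip to VectorDataset. A minimal usage sketch, not part of the diff (the "x" and "label" keys are illustrative): to_dict emits array-like values as lists and scalars as attributes, and from_dict routes list or array values back into data while everything else becomes attrs.

    import numpy as np
    from pycontrails import VectorDataset

    # One array key and one scalar attribute
    vds = VectorDataset({"x": np.array([1.0, 2.0, 3.0])}, attrs={"label": "demo"})

    obj = vds.to_dict()                  # {"x": [1.0, 2.0, 3.0], "label": "demo"}
    vds2 = VectorDataset.from_dict(obj)  # "x" -> data, "label" -> attrs
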
@@ -1182,7 +1318,7 @@ class GeoVectorDataset(VectorDataset):
  if not np.issubdtype(time.dtype, np.datetime64):
  warnings.warn("Time data is not np.datetime64. Attempting to coerce.")
  try:
- pd_time = pd.to_datetime(self["time"])
+ pd_time = _handle_time_column(pd.Series(self["time"]))
  except ValueError as e:
  raise ValueError("Could not coerce time data to datetime64.") from e
  np_time = pd_time.to_numpy(dtype="datetime64[ns]")
@@ -1214,12 +1350,17 @@ class GeoVectorDataset(VectorDataset):
  @overrides
  def _display_attrs(self) -> dict[str, str]:
  try:
- time0, time1 = np.nanmin(self["time"]), np.nanmax(self["time"])
- lon0, lon1 = np.nanmin(self["longitude"]), np.nanmax(self["longitude"])
- lat0, lat1 = np.nanmin(self["latitude"]), np.nanmax(self["latitude"])
- alt0, alt1 = np.nanmin(self.altitude), np.nanmax(self.altitude)
+ time0 = pd.Timestamp(np.nanmin(self["time"]))
+ time1 = pd.Timestamp(np.nanmax(self["time"]))
+ lon0 = round(np.nanmin(self["longitude"]), 3)
+ lon1 = round(np.nanmax(self["longitude"]), 3)
+ lat0 = round(np.nanmin(self["latitude"]), 3)
+ lat1 = round(np.nanmax(self["latitude"]), 3)
+ alt0 = round(np.nanmin(self.altitude), 1)
+ alt1 = round(np.nanmax(self.altitude), 1)
+
  attrs = {
- "time": f"[{pd.Timestamp(time0)}, {pd.Timestamp(time1)}]",
+ "time": f"[{time0}, {time1}]",
  "longitude": f"[{lon0}, {lon1}]",
  "latitude": f"[{lat0}, {lat1}]",
  "altitude": f"[{alt0}, {alt1}]",
@@ -1785,7 +1926,7 @@ class GeoVectorDataset(VectorDataset):
  dict[str, Any]
  Python representation of GeoJSON FeatureCollection
  """
- return json_module.dataframe_to_geojson_points(self.dataframe)
+ return json_utils.dataframe_to_geojson_points(self.dataframe)

  def to_pseudo_mercator(self: GeoVectorDatasetType, copy: bool = True) -> GeoVectorDatasetType:
  """Convert data from :attr:`attrs["crs"]` to Pseudo Mercator (EPSG:3857).
@@ -1908,29 +2049,112 @@ def vector_to_lon_lat_grid(


  def _handle_time_column(time: pd.Series) -> pd.Series:
+ """Ensure that pd.Series has compatible Timestamps.
+
+ Parameters
+ ----------
+ time : pd.Series
+ Pandas dataframe column labeled "time".
+
+
+ Returns
+ -------
+ pd.Series
+ Parsed pandas time series.
+
+
+ Raises
+ ------
+ ValueError
+ When time series can't be parsed, or is not timezone naive.
+ """
  if not hasattr(time, "dt"):
- # If the time column is a string, we try to convert it to a datetime
- # If it fails (for example, a unix integer time), we raise an error
- # and let the user figure it out.
- try:
- return pd.to_datetime(time)
- except ValueError as exc:
- raise ValueError(
- "The 'time' field must hold datetime-like values. "
- 'Try data["time"] = pd.to_datetime(data["time"], unit=...) '
- "with the appropriate unit."
- ) from exc
+ time = _parse_pandas_time(time)

+ # Translate all times to UTC and then remove timezone.
  # If the time column contains a timezone, the call to `to_numpy`
- # will convert it to an array of object. We do not want this, so
- # we raise an error in this case. Timezone issues are complicated,
- # and so it is better for the user to handle them rather than try
- # to address them here.
+ # will convert it to an array of object.
+ # Note `.tz_convert(None)` automatically converts to UTC first.
  if time.dt.tz is not None:
- raise ValueError(
- "The 'time' field must be timezone naive. "
- "This can be achieved with: "
- 'data["time"] = data["time"].dt.tz_localize(None)'
- )
+ time = time.dt.tz_convert(None)

  return time
+
+
+ def _parse_pandas_time(time: pd.Series) -> pd.Series:
+ """Parse pandas dataframe column labelled "time".
+
+ Parameters
+ ----------
+ time : pd.Series
+ Time series
+
+ Returns
+ -------
+ pd.Series
+ Parsed time series
+
+ Raises
+ ------
+ ValueError
+ When series values can't be inferred.
+ """
+ try:
+ # If the time series is a string, try to convert it to a datetime
+ if time.dtype == "O":
+ return pd.to_datetime(time)
+
+ # If the time is an int, try to parse it as unix time
+ if np.issubdtype(time.dtype, np.integer):
+ return _parse_unix_time(time)
+
+ raise ValueError("Unsupported time format")
+
+ except ValueError as exc:
+ raise ValueError(
+ "The 'time' field must hold datetime-like values. "
+ 'Try data["time"] = pd.to_datetime(data["time"], unit=...) '
+ "with the appropriate unit."
+ ) from exc
+
+
+ def _parse_unix_time(time: list[int] | npt.NDArray[np.int_] | pd.Series) -> pd.Series:
+ """Parse array of int times as unix epoch timestamps.
+
+ Attempts to parse the time in "s", "ms", "us", "ns"
+
+
+ Parameters
+ ----------
+ time : list[int] | npt.NDArray[np.int_] | pd.Series
+ Sequence of unix timestamps
+
+
+ Returns
+ -------
+ pd.Series
+ Series of timezone naive pandas Timestamps
+
+ Raises
+ ------
+ ValueError
+ When unable to parse time as unix epoch timestamp
+ """
+ units = "s", "ms", "us", "ns"
+ for unit in units:
+ try:
+ out = pd.to_datetime(time, unit=unit, utc=True)
+ except ValueError:
+ continue
+
+ # make timezone naive
+ out = out.dt.tz_convert(None)
+
+ # make sure time is reasonable
+ if (pd.Timestamp("1980-01-01") <= out).all() and (out <= pd.Timestamp("2030-01-01")).all():
+ return out
+
+ raise ValueError(
+ f"Unable to parse time parameter '{time}' as unix epoch timestamp between "
+ "1980-01-01 and 2030-01-01"
+ )
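
The helpers above replace the old hard error on non-datetime time data. A short sketch of the rule they implement, using pandas directly (the values are taken from the to_dict example earlier in this diff): integer times are treated as unix epoch timestamps, tried with units "s", "ms", "us" and "ns", and the first parse landing between 1980-01-01 and 2030-01-01 is converted to UTC and made timezone naive.

    import pandas as pd

    time = pd.Series([1577869200, 1577869500])  # seconds since the unix epoch
    parsed = pd.to_datetime(time, unit="s", utc=True).dt.tz_convert(None)
    print(parsed.iloc[0])  # 2020-01-01 09:00:00, timezone naive
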
@@ -2,48 +2,20 @@

  from __future__ import annotations

- import datetime
  import logging
- from contextlib import ExitStack
+ import os
  from typing import Any

  LOG = logging.getLogger(__name__)

  import numpy as np
+ import pandas as pd
  import xarray as xr
  from overrides import overrides

  from pycontrails.core import datalib, met
- from pycontrails.utils import temp
-
-
- def rad_accumulated_to_average(mds: met.MetDataset, key: str, dt_accumulation: int) -> None:
- """Convert accumulated radiation value to instantaneous average.
-
- Parameters
- ----------
- mds : MetDataset
- MetDataset containing the accumulated value at ``key``
- key : str
- Data variable key
- dt_accumulation : int
- Accumulation time in seconds, [:math:`s`]
- """
- if key in mds.data and not mds.data[key].attrs.get("_pycontrails_modified", False):
- if not np.all(np.diff(mds.data["time"]) == np.timedelta64(dt_accumulation, "s")):
- raise ValueError(
- f"Dataset expected to have time interval of {dt_accumulation} seconds when"
- " converting accumulated parameters"
- )
-
- mds.data[key] = mds.data[key] / dt_accumulation
- mds.data[key].attrs["units"] = "W m**-2"
- mds.data[key].attrs[
- "_pycontrails_modified"
- ] = "Accumulated value converted to average instantaneous value"
-
-
- # TODO: Remove this in favor of functional implementation
+
+
  class ECMWFAPI(datalib.MetDataSource):
  """Abstract class for all ECMWF data accessed remotely through CDS / MARS."""

@@ -58,7 +30,6 @@ class ECMWFAPI(datalib.MetDataSource):
  """
  return [v.ecmwf_id for v in self.variables if v.ecmwf_id is not None]

- # TODO: this could be functional, but there many properties utilized
  def _process_dataset(self, ds: xr.Dataset, **kwargs: Any) -> met.MetDataset:
  """Process the :class:`xr.Dataset` opened from cache or local files.

@@ -88,8 +59,8 @@ class ECMWFAPI(datalib.MetDataSource):
  ds = ds.sel(time=self.timesteps)
  except KeyError:
  # this snippet shows the missing times for convenience
- np_timesteps = [np.datetime64(t, "ns") for t in self.timesteps]
- missing_times = sorted(set(np_timesteps) - set(ds["time"].values))
+ np_timesteps = {np.datetime64(t, "ns") for t in self.timesteps}
+ missing_times = sorted(np_timesteps.difference(ds["time"].values))
  raise KeyError(
  f"Input dataset is missing time coordinates {[str(t) for t in missing_times]}"
  )
@@ -111,22 +82,6 @@ class ECMWFAPI(datalib.MetDataSource):
  # harmonize variable names
  ds = met.standardize_variables(ds, self.variables)

- # modify values
-
- # rescale relative humidity from % -> dimensionless if its in dataset
- if "relative_humidity" in ds and not ds["relative_humidity"].attrs.get(
- "_pycontrails_modified", False
- ):
- ds["relative_humidity"] = ds["relative_humidity"] / 100
- ds["relative_humidity"].attrs["long_name"] = "Relative humidity"
- ds["relative_humidity"].attrs["standard_name"] = "relative_humidity"
- ds["relative_humidity"].attrs["units"] = "[0 - 1]"
- ds["relative_humidity"].attrs[
- "_pycontrails_modified"
- ] = "Relative humidity rescaled to [0 - 1] instead of %"
-
- ds.attrs["met_source"] = type(self).__name__
-
  kwargs.setdefault("cachestore", self.cachestore)
  return met.MetDataset(ds, **kwargs)

@@ -136,18 +91,12 @@ class ECMWFAPI(datalib.MetDataSource):
  LOG.debug("Cache is turned off, skipping")
  return

- with ExitStack() as stack:
- # group by hour and save one dataset for each hour to temp file
- time_group, datasets = zip(*dataset.groupby("time", squeeze=False))
-
- xarray_temp_filenames = [stack.enter_context(temp.temp_file()) for _ in time_group]
- xr.save_mfdataset(datasets, xarray_temp_filenames)
+ for t, ds_t in dataset.groupby("time", squeeze=False):
+ cache_path = self.create_cachepath(pd.Timestamp(t).to_pydatetime())
+ if os.path.exists(cache_path):
+ LOG.debug(f"Overwriting existing cache file {cache_path}")
+ # This may raise a PermissionError if the file is already open
+ # If this is the case, the user should explicitly close the file and try again
+ os.remove(cache_path)

- # put each hourly file into cache
- cache_path = [
- self.create_cachepath(
- datetime.datetime.fromtimestamp(tg.tolist() / 1e9, datetime.timezone.utc)
- )
- for tg in time_group
- ]
- self.cachestore.put_multiple(xarray_temp_filenames, cache_path)
+ ds_t.to_netcdf(cache_path)
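
The hunk above rewrites ECMWF caching: instead of saving temporary files with xr.save_mfdataset and copying them into the cache store, each timestep is written straight to its cache path. A self-contained sketch of the same pattern (cache_by_timestep and the create_cachepath argument are stand-ins for the method and attribute used in the diff):

    import os

    import pandas as pd
    import xarray as xr

    def cache_by_timestep(dataset: xr.Dataset, create_cachepath) -> None:
        # Write one netCDF file per value of the "time" coordinate.
        for t, ds_t in dataset.groupby("time", squeeze=False):
            cache_path = create_cachepath(pd.Timestamp(t).to_pydatetime())
            if os.path.exists(cache_path):
                # May raise PermissionError if the file is open elsewhere.
                os.remove(cache_path)
            ds_t.to_netcdf(cache_path)
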
@@ -21,16 +21,9 @@ from overrides import overrides
  import pycontrails
  from pycontrails.core import cache, datalib
  from pycontrails.core.met import MetDataset, MetVariable
- from pycontrails.datalib.ecmwf.common import ECMWFAPI, rad_accumulated_to_average
- from pycontrails.datalib.ecmwf.variables import (
- PRESSURE_LEVEL_VARIABLES,
- SURFACE_VARIABLES,
- TOAIncidentSolarRadiation,
- TopNetSolarRadiation,
- TopNetThermalRadiation,
- )
- from pycontrails.utils import dependencies
- from pycontrails.utils.temp import temp_file
+ from pycontrails.datalib.ecmwf.common import ECMWFAPI
+ from pycontrails.datalib.ecmwf.variables import PRESSURE_LEVEL_VARIABLES, SURFACE_VARIABLES
+ from pycontrails.utils import dependencies, temp

  if TYPE_CHECKING:
  import cdsapi
@@ -96,7 +89,7 @@ class ERA5(ECMWFAPI):
  ERA5 parameter list:
  https://confluence.ecmwf.int/pages/viewpage.action?pageId=82870405#ERA5:datadocumentation-Parameterlistings

- All accumulated radiative quantities are converted to average instantaneous quantities.
+ All radiative quantities are accumulated.
  See https://www.ecmwf.int/sites/default/files/elibrary/2015/18490-radiation-quantities-ecmwf-model-and-mars.pdf
  for more information.

@@ -176,7 +169,7 @@ class ERA5(ECMWFAPI):
  if time is None and paths is None:
  raise ValueError("The parameter 'time' must be defined if 'paths' is None")

- supported = {"reanalysis", "ensemble_mean", "ensemble_members", "ensemble_spread"}
+ supported = ("reanalysis", "ensemble_mean", "ensemble_members", "ensemble_spread")
  if product_type not in supported:
  raise ValueError(
  f"Unknown product_type {product_type}. "
@@ -388,28 +381,30 @@ class ERA5(ECMWFAPI):
  # run MetDataset constructor
  ds = self.open_dataset(disk_cachepaths, **xr_kwargs)

- # TODO: corner case
  # If any files are already cached, they will not have the version attached
  ds.attrs.setdefault("pycontrails_version", pycontrails.__version__)

  # run the same ECMWF-specific processing on the dataset
  mds = self._process_dataset(ds, **kwargs)

- # convert accumulated radiation values to average instantaneous values
- # set minimum for all values to 0
-
- # accumulations are 3 hours for ensembles, 1 hour for reanalysis
- dt_accumulation = 60 * 60 if self.product_type == "reanalysis" else 3 * 60 * 60
+ self.set_metadata(mds)
+ return mds

- for key in (
- TOAIncidentSolarRadiation.standard_name,
- TopNetSolarRadiation.standard_name,
- TopNetThermalRadiation.standard_name,
- ):
- if key in mds.data:
- rad_accumulated_to_average(mds, key, dt_accumulation)
+ @overrides
+ def set_metadata(self, ds: xr.Dataset | MetDataset) -> None:
+ if self.product_type == "reanalysis":
+ product = "reanalysis"
+ elif self.product_type.startswith("ensemble"):
+ product = "ensemble"
+ else:
+ msg = f"Unknown product type {self.product_type}"
+ raise ValueError(msg)

- return mds
+ ds.attrs.update(
+ provider="ECMWF",
+ dataset="ERA5",
+ product=product,
+ )

  def _open_and_cache(self, xr_kwargs: dict[str, Any]) -> xr.Dataset:
  """Open and cache :class:`xr.Dataset` from :attr:`self.paths`.
@@ -478,7 +473,7 @@ class ERA5(ECMWFAPI):
  # Open ExitStack to control temp_file context manager
  with ExitStack() as stack:
  # hold downloaded file in named temp file
- cds_temp_filename = stack.enter_context(temp_file())
+ cds_temp_filename = stack.enter_context(temp.temp_file())
  LOG.debug(f"Performing CDS request: {request} to dataset {self.dataset}")
  if not hasattr(self, "cds"):
  self._set_cds()