pycontrails 0.51.1__cp312-cp312-win_amd64.whl → 0.52.0__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pycontrails might be problematic. Click here for more details.

Files changed (40):
  1. pycontrails/__init__.py +1 -1
  2. pycontrails/_version.py +2 -2
  3. pycontrails/core/__init__.py +1 -1
  4. pycontrails/core/cache.py +1 -1
  5. pycontrails/core/flight.py +32 -28
  6. pycontrails/core/polygon.py +1 -1
  7. pycontrails/core/rgi_cython.cp312-win_amd64.pyd +0 -0
  8. pycontrails/datalib/__init__.py +4 -1
  9. pycontrails/datalib/_leo_utils/search.py +250 -0
  10. pycontrails/datalib/_leo_utils/static/bq_roi_query.sql +6 -0
  11. pycontrails/datalib/_leo_utils/vis.py +60 -0
  12. pycontrails/{core/datalib.py → datalib/_met_utils/metsource.py} +1 -1
  13. pycontrails/datalib/ecmwf/arco_era5.py +8 -7
  14. pycontrails/datalib/ecmwf/common.py +3 -2
  15. pycontrails/datalib/ecmwf/era5.py +12 -11
  16. pycontrails/datalib/ecmwf/era5_model_level.py +12 -11
  17. pycontrails/datalib/ecmwf/hres.py +14 -13
  18. pycontrails/datalib/ecmwf/hres_model_level.py +15 -14
  19. pycontrails/datalib/ecmwf/ifs.py +14 -13
  20. pycontrails/datalib/gfs/gfs.py +15 -14
  21. pycontrails/datalib/goes.py +2 -2
  22. pycontrails/datalib/landsat.py +567 -0
  23. pycontrails/datalib/sentinel.py +512 -0
  24. pycontrails/models/apcemm/__init__.py +8 -0
  25. pycontrails/models/apcemm/apcemm.py +983 -0
  26. pycontrails/models/apcemm/inputs.py +226 -0
  27. pycontrails/models/apcemm/static/apcemm_yaml_template.yaml +183 -0
  28. pycontrails/models/apcemm/utils.py +437 -0
  29. pycontrails/models/cocip/__init__.py +2 -0
  30. pycontrails/models/cocip/output_formats.py +165 -0
  31. pycontrails/models/cocipgrid/cocip_grid.py +7 -6
  32. pycontrails/models/dry_advection.py +14 -5
  33. {pycontrails-0.51.1.dist-info → pycontrails-0.52.0.dist-info}/METADATA +20 -11
  34. {pycontrails-0.51.1.dist-info → pycontrails-0.52.0.dist-info}/RECORD +39 -30
  35. pycontrails/datalib/spire/__init__.py +0 -19
  36. /pycontrails/datalib/{spire/spire.py → spire.py} +0 -0
  37. {pycontrails-0.51.1.dist-info → pycontrails-0.52.0.dist-info}/LICENSE +0 -0
  38. {pycontrails-0.51.1.dist-info → pycontrails-0.52.0.dist-info}/NOTICE +0 -0
  39. {pycontrails-0.51.1.dist-info → pycontrails-0.52.0.dist-info}/WHEEL +0 -0
  40. {pycontrails-0.51.1.dist-info → pycontrails-0.52.0.dist-info}/top_level.txt +0 -0
pycontrails/__init__.py CHANGED
@@ -30,7 +30,6 @@ with contextlib.suppress(ImportError):
30
30
  import netCDF4 # noqa: F401
31
31
 
32
32
  from pycontrails.core.cache import DiskCacheStore, GCPCacheStore
33
- from pycontrails.core.datalib import MetDataSource
34
33
  from pycontrails.core.fleet import Fleet
35
34
  from pycontrails.core.flight import Flight, FlightPhase
36
35
  from pycontrails.core.fuel import Fuel, HydrogenFuel, JetA, SAFBlend
@@ -38,6 +37,7 @@ from pycontrails.core.met import MetDataArray, MetDataset
38
37
  from pycontrails.core.met_var import MetVariable
39
38
  from pycontrails.core.models import Model, ModelParams
40
39
  from pycontrails.core.vector import GeoVectorDataset, VectorDataset
40
+ from pycontrails.datalib._met_utils.metsource import MetDataSource
41
41
 
42
42
  __version__ = metadata.version("pycontrails")
43
43
  __license__ = "Apache-2.0"
pycontrails/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.51.1'
16
- __version_tuple__ = version_tuple = (0, 51, 1)
15
+ __version__ = version = '0.52.0'
16
+ __version_tuple__ = version_tuple = (0, 52, 0)
@@ -1,7 +1,6 @@
1
1
  """Core data structures and methods."""
2
2
 
3
3
  from pycontrails.core.cache import DiskCacheStore, GCPCacheStore
4
- from pycontrails.core.datalib import MetDataSource
5
4
  from pycontrails.core.fleet import Fleet
6
5
  from pycontrails.core.flight import Flight
7
6
  from pycontrails.core.fuel import Fuel, HydrogenFuel, JetA, SAFBlend
@@ -9,6 +8,7 @@ from pycontrails.core.met import MetDataArray, MetDataset
9
8
  from pycontrails.core.met_var import MetVariable
10
9
  from pycontrails.core.models import Model, ModelParams
11
10
  from pycontrails.core.vector import GeoVectorDataset, VectorDataset
11
+ from pycontrails.datalib._met_utils.metsource import MetDataSource
12
12
 
13
13
  __all__ = [
14
14
  "DiskCacheStore",
pycontrails/core/cache.py CHANGED
@@ -20,7 +20,7 @@ from pycontrails.utils import dependencies
20
20
 
21
21
  # optional imports
22
22
  if TYPE_CHECKING:
23
- import google
23
+ import google.cloud.storage
24
24
 
25
25
 
26
26
  @functools.cache
@@ -1374,21 +1374,20 @@ class Flight(GeoVectorDataset):
1374
1374
  return {"type": "FeatureCollection", "features": [linestring]}
1375
1375
 
1376
1376
  def to_geojson_multilinestring(
1377
- self, key: str, split_antimeridian: bool = True
1377
+ self, key: str | None = None, split_antimeridian: bool = True
1378
1378
  ) -> dict[str, Any]:
1379
1379
  """Return trajectory as GeoJSON FeatureCollection of MultiLineStrings.
1380
1380
 
1381
- Flight :attr:`data` is grouped according to values of ``key``. Each group gives rise to a
1382
- Feature containing a MultiLineString geometry. LineStrings can be split over the
1383
- antimeridian.
1381
+ If `key` is provided, Flight :attr:`data` is grouped according to values of ``key``.
1382
+ Each group gives rise to a Feature containing a MultiLineString geometry.
1383
+ Each MultiLineString can optionally be split over the antimeridian.
1384
1384
 
1385
1385
  Parameters
1386
1386
  ----------
1387
- key : str
1388
- Name of :attr:`data` column to group by
1387
+ key : str, optional
1388
+ If provided, name of :attr:`data` column to group by.
1389
1389
  split_antimeridian : bool, optional
1390
- Split linestrings that cross the antimeridian.
1391
- Defaults to True
1390
+ Split linestrings that cross the antimeridian. Defaults to True.
1392
1391
 
1393
1392
  Returns
1394
1393
  -------
@@ -1398,31 +1397,41 @@ class Flight(GeoVectorDataset):
1398
1397
  Raises
1399
1398
  ------
1400
1399
  KeyError
1401
- :attr:`data` does not contain column ``key``
1400
+ ``key`` is provided but :attr:`data` does not contain column ``key``
1402
1401
  """
1403
- if key not in self.dataframe.columns:
1402
+ if key is not None and key not in self.dataframe.columns:
1404
1403
  raise KeyError(f"Column {key} does not exist in data.")
1405
1404
 
1406
- jump_index = _antimeridian_index(pd.Series(self["longitude"]), self.attrs["crs"])
1405
+ jump_indices = _antimeridian_index(pd.Series(self["longitude"]), self.attrs["crs"])
1407
1406
 
1408
1407
  def _group_to_feature(group: pd.DataFrame) -> dict[str, str | dict[str, Any]]:
1408
+ # assigns a different value to each group of consecutive indices
1409
1409
  subgrouping = group.index.to_series().diff().ne(1).cumsum()
1410
- # additional splitting at antimeridian
1411
- if jump_index in subgrouping and split_antimeridian:
1412
- subgrouping.loc[jump_index:] += 1
1410
+
1411
+ # increments values after antimeridian crossings
1412
+ if split_antimeridian:
1413
+ for jump_index in jump_indices:
1414
+ if jump_index in subgrouping:
1415
+ subgrouping.loc[jump_index:] += 1
1416
+
1417
+ # creates separate linestrings for sets of points
1418
+ # - with non-consecutive indices
1419
+ # - before and after antimeridian crossings
1413
1420
  multi_ls = [_return_linestring(g) for _, g in group.groupby(subgrouping)]
1414
1421
  geometry = {"type": "MultiLineString", "coordinates": multi_ls}
1415
1422
 
1416
1423
  # adding in static properties
1417
- properties: dict[str, Any] = {key: group.name}
1424
+ properties: dict[str, Any] = {key: group.name} if key is not None else {}
1418
1425
  properties.update(self.constants)
1419
1426
  return {"type": "Feature", "geometry": geometry, "properties": properties}
1420
1427
 
1421
- features = (
1422
- self.dataframe.groupby(key)
1423
- .apply(_group_to_feature, include_groups=False)
1424
- .values.tolist()
1425
- )
1428
+ if key is not None:
1429
+ groups = self.dataframe.groupby(key)
1430
+ else:
1431
+ # create a single group containing all rows of dataframe
1432
+ groups = self.dataframe.groupby(lambda _: 0)
1433
+
1434
+ features = groups.apply(_group_to_feature, include_groups=False).values.tolist()
1426
1435
  return {"type": "FeatureCollection", "features": features}
1427
1436
 
1428
1437
  def to_traffic(self) -> traffic.core.Flight:
@@ -1609,8 +1618,8 @@ def _return_linestring(data: dict[str, npt.NDArray[np.float64]]) -> list[list[fl
1609
1618
  return [list(p) for p in points]
1610
1619
 
1611
1620
 
1612
- def _antimeridian_index(longitude: pd.Series, crs: str = "EPSG:4326") -> int:
1613
- """Return index after flight crosses antimeridian, or -1 if flight does not cross.
1621
+ def _antimeridian_index(longitude: pd.Series, crs: str = "EPSG:4326") -> list[int]:
1622
+ """Return indices after flight crosses antimeridian, or an empty list if flight does not cross.
1614
1623
 
1615
1624
  Parameters
1616
1625
  ----------
@@ -1658,12 +1667,7 @@ def _antimeridian_index(longitude: pd.Series, crs: str = "EPSG:4326") -> int:
1658
1667
  jump21 = longitude[s1.shift() & s2]
1659
1668
  jump_index = pd.concat([jump12, jump21]).index.to_list()
1660
1669
 
1661
- if len(jump_index) > 1:
1662
- raise ValueError("Only implemented for trajectories jumping the antimeridian at most once.")
1663
- if len(jump_index) == 1:
1664
- return jump_index[0]
1665
-
1666
- return -1
1670
+ return jump_index
1667
1671
 
1668
1672
 
1669
1673
  def _sg_filter(
@@ -354,7 +354,7 @@ def find_multipolygon(
354
354
  return shapely.MultiPolygon()
355
355
 
356
356
  assert len(hierarchy) == 1
357
- hierarchy = hierarchy[0]
357
+ hierarchy = hierarchy[0] # type: ignore[index]
358
358
 
359
359
  polygons = _contours_to_polygons(
360
360
  contours, # type: ignore[arg-type]
@@ -1,9 +1,12 @@
1
1
  """
2
- Meteorology and Air Traffic data source wrappers.
2
+ Meteorology, Air Traffic, and Observation data source wrappers.
3
3
 
4
4
  See individual modules for met variables and additional exports.
5
5
 
6
6
  - :module:`pycontrails.datalib.ecmwf`
7
7
  - :module:`pycontrails.datalib.gfs`
8
8
  - :module:`pycontrails.datalib.spire`
9
+ - :module:`pycontrails.datalib.goes`
10
+ - :module:`pycontrails.datalib.landsat`
11
+ - :module:`pycontrails.datalib.sentinel`
9
12
  """
@@ -0,0 +1,250 @@
1
+ """Tools for searching for low Earth orbit satellite imagery."""
2
+
3
+ import dataclasses
4
+ import pathlib
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+
9
+ from pycontrails.core import Flight
10
+ from pycontrails.utils import dependencies
11
+
12
+ try:
13
+ import geojson
14
+ except ModuleNotFoundError as exc:
15
+ dependencies.raise_module_not_found_error(
16
+ name="datalib._leo_utils module",
17
+ package_name="geojson",
18
+ module_not_found_error=exc,
19
+ pycontrails_optional_package="sat",
20
+ )
21
+
22
#: Directory holding static resources (SQL templates) shipped next to this module.
_path_to_static = pathlib.Path(__file__).parent.joinpath("static")

#: BigQuery SQL template read and formatted by :func:`query`.
ROI_QUERY_FILENAME = _path_to_static.joinpath("bq_roi_query.sql")


#: GeoJSON polygon that covers the entire globe.
GLOBAL_EXTENT = geojson.dumps(
    geojson.Polygon(
        [
            [
                (-180, -90),
                (180, -90),
                (180, 90),
                (-180, 90),
                (-180, -90),
            ]
        ]
    )
)
30
+
31
+
32
@dataclasses.dataclass
class ROI:
    """Spatiotemporal region of interest.

    Construction validates that the time window is ordered and that ``extent``
    decodes to a valid GeoJSON geometry.
    """

    #: Start time
    start_time: np.datetime64

    #: End time
    end_time: np.datetime64

    #: GeoJSON representation of spatial ROI.
    extent: str

    def __post_init__(self) -> None:
        """Validate region of interest.

        Raises
        ------
        ValueError
            If ``start_time`` is later than ``end_time``, if ``extent`` cannot be
            decoded into a GeoJSON structure, or if the decoded geometry is not valid.
        """
        if self.start_time > self.end_time:
            raise ValueError("start_time must be before end_time")

        try:
            geometry = geojson.loads(self.extent)
            feature = geojson.Feature(geometry=geometry)
        except Exception as exc:
            raise ValueError("extent cannot be converted to GeoJSON structure") from exc

        if not feature.is_valid:
            raise ValueError("extent is not valid GeoJSON")
59
+
60
+
61
def track_to_geojson(flight: Flight) -> str:
    """Encode a flight's ground track as a GeoJSON MultiLineString string.

    The track is split wherever it crosses the antimeridian. Altitude is dropped, so
    each coordinate is a ``[longitude, latitude]`` pair. Segments adjacent to a
    crossing are padded with a linearly interpolated point lying exactly on
    longitude +/-180, so consecutive segments terminate and restart at the antimeridian.

    Parameters
    ----------
    flight : Flight
        Flight with ground track to convert to GeoJSON string.

    Returns
    -------
    str
        String encoding of a GeoJSON MultiLineString containing ground track split at
        antimeridian crossings.

    Raises
    ------
    ValueError
        If any flight longitude lies outside [-180, 180].

    See Also
    --------
    :meth:`Flight.to_geojson_multilinestring`
    """
    # The padding arithmetic below assumes longitudes in [-180, 180].
    if np.abs(flight["longitude"]).max() > 180.0:
        raise ValueError("Flight longitudes must be between -180 and 180.")

    # Single-feature collection whose geometry is the track split at crossings.
    collection = flight.to_geojson_multilinestring(split_antimeridian=True)
    multilinestring = collection["features"][0]["geometry"]

    # Keep only the first two components (longitude, latitude) of every point.
    segments = [
        [[point[0], point[1]] for point in segment]
        for segment in multilinestring["coordinates"]
    ]

    # Each adjacent pair of segments is separated by an antimeridian crossing.
    # With a single segment this loop does nothing and no padding happens.
    # ``segments[1:]`` shares its inner lists with ``segments``, so in-place padding
    # is visible to the following iteration and to the final encoding.
    for before, after in zip(segments, segments[1:]):
        lon_end = before[-1][0]
        lon_start = after[0][0]
        if abs(lon_end) == 180.0 and abs(lon_start) == 180.0:
            continue  # both sides already sit on the antimeridian

        lat_end = before[-1][1]
        lat_start = after[0][1]
        edge_before = 180.0 * np.sign(lon_end)
        edge_after = 180.0 * np.sign(lon_start)

        # Interpolate the crossing latitude: each endpoint's latitude is weighted by
        # the *other* endpoint's distance from the antimeridian.
        weight_end = np.abs(edge_after - lon_start)
        weight_start = np.abs(edge_before - lon_end)
        lat_cross = (weight_end * lat_end + weight_start * lat_start) / (weight_end + weight_start)

        if abs(lon_end) < 180.0:
            before.append([edge_before, lat_cross])
        if abs(lon_start) < 180.0:
            after.insert(0, [edge_after, lat_cross])

    return geojson.dumps(geojson.MultiLineString(segments))
124
+
125
+
126
def query(table: str, roi: ROI, columns: list[str], extra_filters: str = "") -> pd.DataFrame:
    """Find satellite imagery within region of interest.

    This function requires access to the
    `Google BigQuery API <https://cloud.google.com/bigquery?hl=en>`__
    and uses the `BigQuery python library <https://cloud.google.com/python/docs/reference/bigquery/latest/index.html>`__.

    Parameters
    ----------
    table : str
        Name of BigQuery table to query
    roi : ROI
        Region of interest
    columns : list[str]
        Columns to return from Google
        `BigQuery table <https://console.cloud.google.com/bigquery?p=bigquery-public-data&d=cloud_storage_geo_index&t=landsat_index&page=table&_ga=2.90807450.1051800793.1716904050-255800408.1705955196>`__.
    extra_filters : str, optional
        Additional selection filters, injected verbatim into constructed query.

    Returns
    -------
    pd.DataFrame
        Query results in pandas DataFrame

    Raises
    ------
    ValueError
        If ``columns`` is empty.

    Notes
    -----
    ``table``, ``columns`` and ``extra_filters`` are substituted into the SQL text
    without any escaping. Never pass untrusted input for these arguments.
    """
    # Validate caller input before touching optional dependencies or credentials.
    if not columns:
        msg = "At least one column must be provided."
        raise ValueError(msg)

    try:
        from google.cloud import bigquery
    except ModuleNotFoundError as exc:
        # Same optional-extra name ("sat") as the geojson import of this module.
        dependencies.raise_module_not_found_error(
            name="datalib._leo_utils module",
            package_name="google-cloud-bigquery",
            module_not_found_error=exc,
            pycontrails_optional_package="sat",
        )

    start_time = pd.Timestamp(roi.start_time).strftime("%Y-%m-%d %H:%M:%S")
    end_time = pd.Timestamp(roi.end_time).strftime("%Y-%m-%d %H:%M:%S")
    # The template wraps the GeoJSON in a double-quoted SQL string literal,
    # so double quotes inside the GeoJSON are swapped for single quotes.
    extent = roi.extent.replace('"', "'")

    query_str = ROI_QUERY_FILENAME.read_text(encoding="utf-8").format(
        table=table,
        columns=",".join(columns),
        start_time=start_time,
        end_time=end_time,
        geojson_str=extent,
        extra_filters=extra_filters,
    )

    client = bigquery.Client()
    result = client.query(query_str)
    return result.to_dataframe()
181
+
182
+
183
def intersect(
    table: str, flight: Flight, columns: list[str], extra_filters: str = ""
) -> pd.DataFrame:
    """Find satellite imagery intersecting with flight track.

    This function will return all scenes with a bounding box that includes flight waypoints
    both before and after the sensing time.

    This function requires access to the
    `Google BigQuery API <https://cloud.google.com/bigquery?hl=en>`__
    and uses the `BigQuery python library <https://cloud.google.com/python/docs/reference/bigquery/latest/index.html>`__.

    Parameters
    ----------
    table : str
        Name of BigQuery table to query
    flight : Flight
        Flight for intersection
    columns : list[str]
        Columns to return from Google
        `BigQuery table <https://console.cloud.google.com/bigquery?p=bigquery-public-data&d=cloud_storage_geo_index&t=landsat_index&page=table&_ga=2.90807450.1051800793.1716904050-255800408.1705955196>`__.
    extra_filters : str, optional
        Additional selection filters, injected verbatim into constructed query.

    Returns
    -------
    pd.DataFrame
        Query results in pandas DataFrame, restricted to ``columns``.
    """

    # create ROI with time span between flight start and end
    # and spatial extent set to flight track
    extent = track_to_geojson(flight)
    roi = ROI(start_time=flight["time"].min(), end_time=flight["time"].max(), extent=extent)

    # first pass: query for intersections with ROI
    # requires additional columns for final intersection with flight.
    # dict.fromkeys de-duplicates while keeping a deterministic column order.
    required_columns = ["sensing_time", "west_lon", "east_lon", "south_lat", "north_lat"]
    queried_columns = list(dict.fromkeys([*required_columns, *columns]))
    candidates = query(table, roi, queried_columns, extra_filters)

    if len(candidates) == 0:  # already know there are no intersections
        return candidates[columns]

    # second pass: keep images where flight waypoints
    # bounding sensing time are both within bounding box
    flight_data = flight.dataframe
    longitude = flight_data["longitude"]
    latitude = flight_data["latitude"]
    times = flight_data["time"]

    def intersects(scene: pd.Series) -> bool:
        if scene["west_lon"] <= scene["east_lon"]:  # scene does not span antimeridian
            in_lon = longitude.between(scene["west_lon"], scene["east_lon"])
        else:  # scene spans antimeridian: east of west edge OR west of east edge
            # Each comparison must be parenthesized: ``|`` binds tighter than ``>``/``<``.
            in_lon = (longitude > scene["west_lon"]) | (longitude < scene["east_lon"])
        in_lat = latitude.between(scene["south_lat"], scene["north_lat"])

        bbox_times = times[in_lon & in_lat]
        if bbox_times.empty:  # no waypoints inside the scene bounding box
            return False

        sensing_time = pd.Timestamp(scene["sensing_time"]).tz_localize(None)
        return bbox_times.min() <= sensing_time <= bbox_times.max()

    mask = candidates.apply(intersects, axis="columns")
    return candidates[columns][mask]
@@ -0,0 +1,6 @@
1
-- Region-of-interest search template, filled in with Python str.format by query().
-- Every brace-delimited placeholder is substituted verbatim (no escaping).
-- extra_filters is appended directly after the last WHERE predicate, so a
-- non-empty value presumably has to start with a conjunction such as AND.
SELECT {columns}
FROM `{table}`
WHERE
    sensing_time >= "{start_time}"
    AND sensing_time <= "{end_time}"
    AND ST_INTERSECTSBOX(
        ST_GEOGFROMGEOJSON("{geojson_str}"), west_lon, south_lat, east_lon, north_lat
    )
    {extra_filters}
ORDER BY sensing_time ASC
@@ -0,0 +1,60 @@
1
+ """Utilities for visualization of low-Earth orbit satellite imagery."""
2
+
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+
7
+ from pycontrails.utils import dependencies
8
+
9
+ try:
10
+ import skimage as ski
11
+ except ModuleNotFoundError as exc:
12
+ dependencies.raise_module_not_found_error(
13
+ name="landsat module",
14
+ package_name="scikit-image",
15
+ module_not_found_error=exc,
16
+ pycontrails_optional_package="sat",
17
+ )
18
+
19
+
20
def normalize(channel: np.ndarray) -> np.ndarray:
    """Normalize channel values to range [0, 1], preserving ``np.nan`` in output.

    Parameters
    ----------
    channel : np.ndarray
        Array of channel values for normalization.

    Returns
    -------
    np.ndarray
        Normalized channel values. ``np.nan`` will be preserved wherever present in input.
        If all non-NaN values are equal, they are mapped to 0.
    """
    lo = np.nanmin(channel)
    span = np.nanmax(channel) - lo
    if span == 0:
        # Constant channel: dividing would give 0/0 and turn every value into NaN.
        return np.where(np.isnan(channel), np.nan, 0.0)
    return (channel - lo) / span
34
+
35
+
36
def equalize(channel: np.ndarray, **equalize_kwargs: Any) -> np.ndarray:
    """Contrast-equalize a channel with :py:func:`ski.exposure.equalize_adapthist`.

    ``np.nan`` entries in the input remain ``np.nan`` in the output.

    Parameters
    ----------
    channel : np.ndarray
        Array of channel values for equalization.
    **equalize_kwargs : Any
        Keyword arguments forwarded to :py:func:`ski.exposure.equalize_adapthist`.

    Returns
    -------
    np.ndarray
        Equalized channel values, with ``np.nan`` wherever the input was ``np.nan``.

    Notes
    -----
    NaN values are replaced by 0 before equalization, so they may influence the
    equalized values in the neighborhood where they occur.
    """
    missing = np.isnan(channel)
    filled = np.nan_to_num(channel, nan=0)
    equalized = ski.exposure.equalize_adapthist(filled, **equalize_kwargs)
    return np.where(missing, np.nan, equalized)
@@ -1,4 +1,4 @@
1
- """Datalib utilities."""
1
+ """Met datalib definitions and utilities."""
2
2
 
3
3
  from __future__ import annotations
4
4
 
@@ -32,8 +32,9 @@ from typing import Any
32
32
  import xarray as xr
33
33
  from overrides import overrides
34
34
 
35
- from pycontrails.core import cache, datalib, met_var
35
+ from pycontrails.core import cache, met_var
36
36
  from pycontrails.core.met import MetDataset
37
+ from pycontrails.datalib._met_utils import metsource
37
38
  from pycontrails.datalib.ecmwf import common as ecmwf_common
38
39
  from pycontrails.datalib.ecmwf import variables as ecmwf_variables
39
40
  from pycontrails.datalib.ecmwf.model_levels import pressure_levels_at_model_levels
@@ -374,23 +375,23 @@ class ARCOERA5(ecmwf_common.ECMWFAPI):
374
375
 
375
376
  def __init__(
376
377
  self,
377
- time: datalib.TimeInput,
378
- variables: datalib.VariableInput,
379
- pressure_levels: datalib.PressureLevelInput | None = None,
378
+ time: metsource.TimeInput,
379
+ variables: metsource.VariableInput,
380
+ pressure_levels: metsource.PressureLevelInput | None = None,
380
381
  grid: float = 0.25,
381
382
  cachestore: cache.CacheStore | None = __marker, # type: ignore[assignment]
382
383
  n_jobs: int = 1,
383
384
  cleanup_metview_tempfiles: bool = True,
384
385
  ) -> None:
385
- self.timesteps = datalib.parse_timesteps(time)
386
+ self.timesteps = metsource.parse_timesteps(time)
386
387
 
387
388
  if pressure_levels is None:
388
389
  self.pressure_levels = pressure_levels_at_model_levels(20_000.0, 50_000.0)
389
390
  else:
390
- self.pressure_levels = datalib.parse_pressure_levels(pressure_levels)
391
+ self.pressure_levels = metsource.parse_pressure_levels(pressure_levels)
391
392
 
392
393
  self.paths = None
393
- self.variables = datalib.parse_variables(variables, self.supported_variables)
394
+ self.variables = metsource.parse_variables(variables, self.supported_variables)
394
395
  self.grid = grid
395
396
  self.cachestore = cache.DiskCacheStore() if cachestore is self.__marker else cachestore
396
397
  self.n_jobs = max(1, n_jobs)
@@ -13,10 +13,11 @@ import pandas as pd
13
13
  import xarray as xr
14
14
  from overrides import overrides
15
15
 
16
- from pycontrails.core import datalib, met
16
+ from pycontrails.core import met
17
+ from pycontrails.datalib._met_utils import metsource
17
18
 
18
19
 
19
- class ECMWFAPI(datalib.MetDataSource):
20
+ class ECMWFAPI(metsource.MetDataSource):
20
21
  """Abstract class for all ECMWF data accessed remotely through CDS / MARS."""
21
22
 
22
23
  @property
@@ -19,8 +19,9 @@ import xarray as xr
19
19
  from overrides import overrides
20
20
 
21
21
  import pycontrails
22
- from pycontrails.core import cache, datalib
22
+ from pycontrails.core import cache
23
23
  from pycontrails.core.met import MetDataset, MetVariable
24
+ from pycontrails.datalib._met_utils import metsource
24
25
  from pycontrails.datalib.ecmwf.common import ECMWFAPI, CDSCredentialsNotFound
25
26
  from pycontrails.datalib.ecmwf.variables import PRESSURE_LEVEL_VARIABLES, SURFACE_VARIABLES
26
27
  from pycontrails.utils import dependencies, temp
@@ -49,16 +50,16 @@ class ERA5(ECMWFAPI):
49
50
 
50
51
  Parameters
51
52
  ----------
52
- time : datalib.TimeInput | None
53
+ time : metsource.TimeInput | None
53
54
  The time range for data retrieval, either a single datetime or (start, end) datetime range.
54
55
  Input must be datetime-like or tuple of datetime-like
55
56
  (`datetime`, :class:`pd.Timestamp`, :class:`np.datetime64`)
56
57
  specifying the (start, end) of the date range, inclusive.
57
58
  Datafiles will be downloaded from CDS for each day to reduce requests.
58
59
  If None, ``paths`` must be defined and all time coordinates will be loaded from files.
59
- variables : datalib.VariableInput
60
+ variables : metsource.VariableInput
60
61
  Variable name (i.e. "t", "air_temperature", ["air_temperature, relative_humidity"])
61
- pressure_levels : datalib.PressureLevelInput, optional
62
+ pressure_levels : metsource.PressureLevelInput, optional
62
63
  Pressure levels for data, in hPa (mbar)
63
64
  Set to -1 for to download surface level parameters.
64
65
  Defaults to -1.
@@ -145,9 +146,9 @@ class ERA5(ECMWFAPI):
145
146
 
146
147
  def __init__(
147
148
  self,
148
- time: datalib.TimeInput | None,
149
- variables: datalib.VariableInput,
150
- pressure_levels: datalib.PressureLevelInput = -1,
149
+ time: metsource.TimeInput | None,
150
+ variables: metsource.VariableInput,
151
+ pressure_levels: metsource.PressureLevelInput = -1,
151
152
  paths: str | list[str] | pathlib.Path | list[pathlib.Path] | None = None,
152
153
  timestep_freq: str | None = None,
153
154
  product_type: str = "reanalysis",
@@ -193,11 +194,11 @@ class ERA5(ECMWFAPI):
193
194
  if timestep_freq is None:
194
195
  timestep_freq = "1h" if product_type == "reanalysis" else "3h"
195
196
 
196
- self.timesteps = datalib.parse_timesteps(time, freq=timestep_freq)
197
- self.pressure_levels = datalib.parse_pressure_levels(
197
+ self.timesteps = metsource.parse_timesteps(time, freq=timestep_freq)
198
+ self.pressure_levels = metsource.parse_pressure_levels(
198
199
  pressure_levels, self.supported_pressure_levels
199
200
  )
200
- self.variables = datalib.parse_variables(variables, self.supported_variables)
201
+ self.variables = metsource.parse_variables(variables, self.supported_variables)
201
202
 
202
203
  # ensemble_mean, etc - time is only available on the 0, 3, 6, etc
203
204
  if product_type.startswith("ensemble") and any(t.hour % 3 for t in self.timesteps):
@@ -482,7 +483,7 @@ class ERA5(ECMWFAPI):
482
483
 
483
484
  # open file, edit, and save for each hourly time step
484
485
  ds = stack.enter_context(
485
- xr.open_dataset(cds_temp_filename, engine=datalib.NETCDF_ENGINE)
486
+ xr.open_dataset(cds_temp_filename, engine=metsource.NETCDF_ENGINE)
486
487
  )
487
488
 
488
489
  # run preprocessing before cache