pycontrails 0.52.2__cp310-cp310-win_amd64.whl → 0.53.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pycontrails might be problematic; see the package registry's release page for more details.
- pycontrails/_version.py +2 -2
- pycontrails/core/cache.py +1 -1
- pycontrails/core/flight.py +8 -5
- pycontrails/core/flightplan.py +1 -1
- pycontrails/core/interpolation.py +3 -1
- pycontrails/core/met.py +190 -15
- pycontrails/core/met_var.py +1 -1
- pycontrails/core/models.py +5 -5
- pycontrails/core/rgi_cython.cp310-win_amd64.pyd +0 -0
- pycontrails/core/vector.py +5 -5
- pycontrails/datalib/_leo_utils/vis.py +10 -11
- pycontrails/datalib/_met_utils/metsource.py +13 -11
- pycontrails/datalib/ecmwf/era5.py +1 -1
- pycontrails/datalib/ecmwf/era5_model_level.py +1 -1
- pycontrails/datalib/ecmwf/hres_model_level.py +3 -3
- pycontrails/datalib/ecmwf/variables.py +3 -3
- pycontrails/datalib/gfs/gfs.py +4 -3
- pycontrails/datalib/landsat.py +10 -9
- pycontrails/ext/synthetic_flight.py +1 -1
- pycontrails/models/accf.py +1 -1
- pycontrails/models/apcemm/apcemm.py +5 -5
- pycontrails/models/cocip/cocip.py +98 -24
- pycontrails/models/cocip/cocip_params.py +21 -0
- pycontrails/models/cocip/output_formats.py +13 -4
- pycontrails/models/cocip/radiative_forcing.py +3 -3
- pycontrails/models/cocipgrid/cocip_grid.py +4 -4
- pycontrails/models/ps_model/ps_model.py +4 -4
- pycontrails/models/sac.py +2 -2
- pycontrails/physics/thermo.py +1 -1
- pycontrails/utils/json.py +16 -18
- pycontrails/utils/types.py +7 -6
- {pycontrails-0.52.2.dist-info → pycontrails-0.53.0.dist-info}/METADATA +78 -78
- {pycontrails-0.52.2.dist-info → pycontrails-0.53.0.dist-info}/RECORD +37 -37
- {pycontrails-0.52.2.dist-info → pycontrails-0.53.0.dist-info}/WHEEL +1 -1
- {pycontrails-0.52.2.dist-info → pycontrails-0.53.0.dist-info}/top_level.txt +0 -1
- {pycontrails-0.52.2.dist-info → pycontrails-0.53.0.dist-info}/LICENSE +0 -0
- {pycontrails-0.52.2.dist-info → pycontrails-0.53.0.dist-info}/NOTICE +0 -0
pycontrails/_version.py
CHANGED
pycontrails/core/cache.py
CHANGED
|
@@ -146,7 +146,7 @@ class CacheStore(ABC):
|
|
|
146
146
|
"""
|
|
147
147
|
|
|
148
148
|
# TODO: run in parallel?
|
|
149
|
-
return [self.put(d, cp) for d, cp in zip(data_path, cache_path)]
|
|
149
|
+
return [self.put(d, cp) for d, cp in zip(data_path, cache_path, strict=True)]
|
|
150
150
|
|
|
151
151
|
# In the three methods below, child classes have a complete docstring.
|
|
152
152
|
|
pycontrails/core/flight.py
CHANGED
|
@@ -1356,7 +1356,7 @@ class Flight(GeoVectorDataset):
|
|
|
1356
1356
|
# NOTE: geod.npts does not return the initial or terminal points
|
|
1357
1357
|
lonlats: list[tuple[float, float]] = geod.npts(lon0, lat0, lon1, lat1, n_steps)
|
|
1358
1358
|
|
|
1359
|
-
lons, lats = zip(*lonlats)
|
|
1359
|
+
lons, lats = zip(*lonlats, strict=True)
|
|
1360
1360
|
longitudes.extend(lons)
|
|
1361
1361
|
latitudes.extend(lats)
|
|
1362
1362
|
|
|
@@ -1657,10 +1657,11 @@ def _return_linestring(data: dict[str, npt.NDArray[np.float64]]) -> list[list[fl
|
|
|
1657
1657
|
The list of coordinates
|
|
1658
1658
|
"""
|
|
1659
1659
|
# rounding to reduce the size of resultant json arrays
|
|
1660
|
-
points = zip(
|
|
1660
|
+
points = zip(
|
|
1661
1661
|
np.round(data["longitude"], decimals=4),
|
|
1662
1662
|
np.round(data["latitude"], decimals=4),
|
|
1663
1663
|
np.round(data["altitude"], decimals=4),
|
|
1664
|
+
strict=True,
|
|
1664
1665
|
)
|
|
1665
1666
|
return [list(p) for p in points]
|
|
1666
1667
|
|
|
@@ -1949,7 +1950,9 @@ def _altitude_interpolation_climb_descend_middle(
|
|
|
1949
1950
|
# Form array of cumulative altitude values if the flight were to climb
|
|
1950
1951
|
# at nominal_rocd over each group of nan
|
|
1951
1952
|
cumalt_list = []
|
|
1952
|
-
for start_na_idx, end_na_idx, size in zip(
|
|
1953
|
+
for start_na_idx, end_na_idx, size in zip(
|
|
1954
|
+
start_na_idxs, end_na_idxs, na_group_size, strict=True
|
|
1955
|
+
):
|
|
1953
1956
|
if s[start_na_idx] <= s[end_na_idx]:
|
|
1954
1957
|
cumalt_list.append(np.arange(1, size, dtype=float))
|
|
1955
1958
|
else:
|
|
@@ -2053,7 +2056,7 @@ def filter_altitude(
|
|
|
2053
2056
|
--------
|
|
2054
2057
|
:meth:`traffic.core.flight.Flight.filter`
|
|
2055
2058
|
:func:`scipy.signal.medfilt`
|
|
2056
|
-
"""
|
|
2059
|
+
"""
|
|
2057
2060
|
if not len(altitude_ft):
|
|
2058
2061
|
raise ValueError("Altitude must have non-zero length to filter")
|
|
2059
2062
|
|
|
@@ -2114,7 +2117,7 @@ def filter_altitude(
|
|
|
2114
2117
|
|
|
2115
2118
|
result = np.copy(altitude_ft)
|
|
2116
2119
|
if np.any(start_idxs):
|
|
2117
|
-
for i0, i1 in zip(start_idxs, end_idxs):
|
|
2120
|
+
for i0, i1 in zip(start_idxs, end_idxs, strict=True):
|
|
2118
2121
|
result[i0:i1] = altitude_filt[i0:i1]
|
|
2119
2122
|
|
|
2120
2123
|
# reapply Savitzky-Golay filter to smooth climb and descent
|
pycontrails/core/flightplan.py
CHANGED
|
@@ -76,6 +76,7 @@ class PycontrailsRegularGridInterpolator(scipy.interpolate.RegularGridInterpolat
|
|
|
76
76
|
self.method = _pick_method(scipy.__version__, method)
|
|
77
77
|
self.bounds_error = bounds_error
|
|
78
78
|
self.fill_value = fill_value
|
|
79
|
+
self._spline = None
|
|
79
80
|
|
|
80
81
|
def _prepare_xi_simple(self, xi: npt.NDArray[np.float64]) -> npt.NDArray[np.bool_]:
|
|
81
82
|
"""Run looser version of :meth:`_prepare_xi`.
|
|
@@ -215,7 +216,8 @@ class PycontrailsRegularGridInterpolator(scipy.interpolate.RegularGridInterpolat
|
|
|
215
216
|
|
|
216
217
|
if ndim == 1:
|
|
217
218
|
# np.interp could be better ... although that may also promote the dtype
|
|
218
|
-
|
|
219
|
+
# 1-d view is required for evaluate_linear_1d
|
|
220
|
+
return rgi_cython.evaluate_linear_1d(values, indices[0, :], norm_distances[0, :], out)
|
|
219
221
|
|
|
220
222
|
msg = f"Invalid number of dimensions: {ndim}"
|
|
221
223
|
raise ValueError(msg)
|
pycontrails/core/met.py
CHANGED
|
@@ -9,7 +9,15 @@ import pathlib
|
|
|
9
9
|
import typing
|
|
10
10
|
import warnings
|
|
11
11
|
from abc import ABC, abstractmethod
|
|
12
|
-
from collections.abc import
|
|
12
|
+
from collections.abc import (
|
|
13
|
+
Generator,
|
|
14
|
+
Hashable,
|
|
15
|
+
Iterable,
|
|
16
|
+
Iterator,
|
|
17
|
+
Mapping,
|
|
18
|
+
MutableMapping,
|
|
19
|
+
Sequence,
|
|
20
|
+
)
|
|
13
21
|
from contextlib import ExitStack
|
|
14
22
|
from datetime import datetime
|
|
15
23
|
from typing import (
|
|
@@ -62,12 +70,12 @@ class MetBase(ABC, Generic[XArrayType]):
|
|
|
62
70
|
cachestore: CacheStore | None
|
|
63
71
|
|
|
64
72
|
#: Default dimension order for DataArray or Dataset (x, y, z, t)
|
|
65
|
-
dim_order:
|
|
73
|
+
dim_order: tuple[Hashable, Hashable, Hashable, Hashable] = (
|
|
66
74
|
"longitude",
|
|
67
75
|
"latitude",
|
|
68
76
|
"level",
|
|
69
77
|
"time",
|
|
70
|
-
|
|
78
|
+
)
|
|
71
79
|
|
|
72
80
|
def __repr__(self) -> str:
|
|
73
81
|
data = getattr(self, "data", None)
|
|
@@ -192,10 +200,8 @@ class MetBase(ABC, Generic[XArrayType]):
|
|
|
192
200
|
def _validate_transpose(self) -> None:
|
|
193
201
|
"""Check that data is transposed according to :attr:`dim_order`."""
|
|
194
202
|
|
|
195
|
-
dims_tuple = tuple(self.dim_order)
|
|
196
|
-
|
|
197
203
|
def _check_da(da: xr.DataArray, key: Hashable | None = None) -> None:
|
|
198
|
-
if da.dims !=
|
|
204
|
+
if da.dims != self.dim_order:
|
|
199
205
|
if key is not None:
|
|
200
206
|
msg = (
|
|
201
207
|
f"Data dimension not transposed on variable '{key}'. Initiate with"
|
|
@@ -263,7 +269,7 @@ class MetBase(ABC, Generic[XArrayType]):
|
|
|
263
269
|
self.data["time"] = self.data["time"].astype("datetime64[ns]", copy=False)
|
|
264
270
|
|
|
265
271
|
# sortby to ensure each coordinate has ascending order
|
|
266
|
-
self.data = self.data.sortby(self.dim_order, ascending=True)
|
|
272
|
+
self.data = self.data.sortby(list(self.dim_order), ascending=True)
|
|
267
273
|
|
|
268
274
|
if not self.is_wrapped:
|
|
269
275
|
# Ensure longitude is contained in interval [-180, 180)
|
|
@@ -285,7 +291,7 @@ class MetBase(ABC, Generic[XArrayType]):
|
|
|
285
291
|
self._validate_latitude()
|
|
286
292
|
|
|
287
293
|
# transpose to have ordering (x, y, z, t, ...)
|
|
288
|
-
dim_order = self.dim_order
|
|
294
|
+
dim_order = [*self.dim_order, *(d for d in self.data.dims if d not in self.dim_order)]
|
|
289
295
|
self.data = self.data.transpose(*dim_order)
|
|
290
296
|
|
|
291
297
|
# single level data
|
|
@@ -481,7 +487,7 @@ class MetBase(ABC, Generic[XArrayType]):
|
|
|
481
487
|
self.cachestore = self.cachestore or DiskCacheStore()
|
|
482
488
|
|
|
483
489
|
# group by hour and save one dataset for each hour to temp file
|
|
484
|
-
times, datasets = zip(*dataset.groupby("time", squeeze=False))
|
|
490
|
+
times, datasets = zip(*dataset.groupby("time", squeeze=False), strict=True)
|
|
485
491
|
|
|
486
492
|
# Open ExitStack to control temp_file context manager
|
|
487
493
|
with ExitStack() as stack:
|
|
@@ -912,7 +918,7 @@ class MetDataset(MetBase):
|
|
|
912
918
|
KeyError
|
|
913
919
|
Raises when dataset does not contain variable in ``vars``
|
|
914
920
|
"""
|
|
915
|
-
if isinstance(vars,
|
|
921
|
+
if isinstance(vars, MetVariable | str):
|
|
916
922
|
vars = (vars,)
|
|
917
923
|
|
|
918
924
|
met_keys: list[str] = []
|
|
@@ -1014,7 +1020,7 @@ class MetDataset(MetBase):
|
|
|
1014
1020
|
|
|
1015
1021
|
@overrides
|
|
1016
1022
|
def broadcast_coords(self, name: str) -> xr.DataArray:
|
|
1017
|
-
da = xr.ones_like(self.data[
|
|
1023
|
+
da = xr.ones_like(self.data[next(iter(self.data.keys()))]) * self.data[name]
|
|
1018
1024
|
da.name = name
|
|
1019
1025
|
|
|
1020
1026
|
return da
|
|
@@ -1066,7 +1072,7 @@ class MetDataset(MetBase):
|
|
|
1066
1072
|
coords_vals = [indexes[key].values for key in coords_keys]
|
|
1067
1073
|
coords_meshes = np.meshgrid(*coords_vals, indexing="ij")
|
|
1068
1074
|
raveled_coords = (mesh.ravel() for mesh in coords_meshes)
|
|
1069
|
-
data = dict(zip(coords_keys, raveled_coords))
|
|
1075
|
+
data = dict(zip(coords_keys, raveled_coords, strict=True))
|
|
1070
1076
|
|
|
1071
1077
|
out = vector_module.GeoVectorDataset(data, copy=False)
|
|
1072
1078
|
for key, da in self.data.items():
|
|
@@ -1502,6 +1508,7 @@ class MetDataArray(MetBase):
|
|
|
1502
1508
|
bounds_error: bool = ...,
|
|
1503
1509
|
fill_value: float | np.float64 | None = ...,
|
|
1504
1510
|
localize: bool = ...,
|
|
1511
|
+
lowmem: bool = ...,
|
|
1505
1512
|
indices: interpolation.RGIArtifacts | None = ...,
|
|
1506
1513
|
return_indices: Literal[False] = ...,
|
|
1507
1514
|
) -> npt.NDArray[np.float64]: ...
|
|
@@ -1518,6 +1525,7 @@ class MetDataArray(MetBase):
|
|
|
1518
1525
|
bounds_error: bool = ...,
|
|
1519
1526
|
fill_value: float | np.float64 | None = ...,
|
|
1520
1527
|
localize: bool = ...,
|
|
1528
|
+
lowmem: bool = ...,
|
|
1521
1529
|
indices: interpolation.RGIArtifacts | None = ...,
|
|
1522
1530
|
return_indices: Literal[True],
|
|
1523
1531
|
) -> tuple[npt.NDArray[np.float64], interpolation.RGIArtifacts]: ...
|
|
@@ -1533,6 +1541,7 @@ class MetDataArray(MetBase):
|
|
|
1533
1541
|
bounds_error: bool = False,
|
|
1534
1542
|
fill_value: float | np.float64 | None = np.nan,
|
|
1535
1543
|
localize: bool = False,
|
|
1544
|
+
lowmem: bool = False,
|
|
1536
1545
|
indices: interpolation.RGIArtifacts | None = None,
|
|
1537
1546
|
return_indices: bool = False,
|
|
1538
1547
|
) -> npt.NDArray[np.float64] | tuple[npt.NDArray[np.float64], interpolation.RGIArtifacts]:
|
|
@@ -1540,7 +1549,9 @@ class MetDataArray(MetBase):
|
|
|
1540
1549
|
|
|
1541
1550
|
Zero dimensional coordinates are reshaped to 1D arrays.
|
|
1542
1551
|
|
|
1543
|
-
|
|
1552
|
+
If ``lowmem == False``, method automatically loads underlying :attr:`data` into
|
|
1553
|
+
memory. Otherwise, method iterates through smaller subsets of :attr:`data` and releases
|
|
1554
|
+
subsets from memory once interpolation against each subset is finished.
|
|
1544
1555
|
|
|
1545
1556
|
If ``method == "nearest"``, the out array will have the same ``dtype`` as
|
|
1546
1557
|
the underlying :attr:`data`.
|
|
@@ -1586,10 +1597,18 @@ class MetDataArray(MetBase):
|
|
|
1586
1597
|
localize: bool, optional
|
|
1587
1598
|
Experimental. If True, downselect gridded data to smallest bounding box containing
|
|
1588
1599
|
all points. By default False.
|
|
1600
|
+
lowmem: bool, optional
|
|
1601
|
+
Experimental. If True, iterate through points binned by the time coordinate of the
|
|
1602
|
+
gridded data, and downselect gridded data to the smallest bounding box containing
|
|
1603
|
+
each binned set of points *before loading into memory*. This can significantly reduce
|
|
1604
|
+
memory consumption with large numbers of points at the cost of increased runtime.
|
|
1605
|
+
By default False.
|
|
1589
1606
|
indices: tuple | None, optional
|
|
1590
1607
|
Experimental. See :func:`interpolation.interp`. None by default.
|
|
1591
1608
|
return_indices: bool, optional
|
|
1592
1609
|
Experimental. See :func:`interpolation.interp`. False by default.
|
|
1610
|
+
Note that values returned differ when ``lowmem=True`` and ``lowmem=False``,
|
|
1611
|
+
so output should only be re-used in calls with the same ``lowmem`` value.
|
|
1593
1612
|
|
|
1594
1613
|
Returns
|
|
1595
1614
|
-------
|
|
@@ -1632,10 +1651,29 @@ class MetDataArray(MetBase):
|
|
|
1632
1651
|
>>> level = np.linspace(200, 300, 10)
|
|
1633
1652
|
>>> time = pd.date_range("2022-03-01T14", periods=10, freq="5min")
|
|
1634
1653
|
>>> mda.interpolate(longitude, latitude, level, time)
|
|
1654
|
+
array([220.44347694, 223.08900738, 225.74338924, 228.41642088,
|
|
1655
|
+
231.10858599, 233.54857391, 235.71504913, 237.86478872,
|
|
1656
|
+
239.99274623, 242.10792167])
|
|
1657
|
+
|
|
1658
|
+
>>> # Can easily switch to alternative low-memory implementation
|
|
1659
|
+
>>> mda.interpolate(longitude, latitude, level, time, lowmem=True)
|
|
1635
1660
|
array([220.44347694, 223.08900738, 225.74338924, 228.41642088,
|
|
1636
1661
|
231.10858599, 233.54857391, 235.71504913, 237.86478872,
|
|
1637
1662
|
239.99274623, 242.10792167])
|
|
1638
1663
|
"""
|
|
1664
|
+
if lowmem:
|
|
1665
|
+
return self._interp_lowmem(
|
|
1666
|
+
longitude,
|
|
1667
|
+
latitude,
|
|
1668
|
+
level,
|
|
1669
|
+
time,
|
|
1670
|
+
method=method,
|
|
1671
|
+
bounds_error=bounds_error,
|
|
1672
|
+
fill_value=fill_value,
|
|
1673
|
+
indices=indices,
|
|
1674
|
+
return_indices=return_indices,
|
|
1675
|
+
)
|
|
1676
|
+
|
|
1639
1677
|
# Load if necessary
|
|
1640
1678
|
if not self.in_memory:
|
|
1641
1679
|
self._check_memory("Interpolation over")
|
|
@@ -1660,6 +1698,100 @@ class MetDataArray(MetBase):
|
|
|
1660
1698
|
return_indices=return_indices,
|
|
1661
1699
|
)
|
|
1662
1700
|
|
|
1701
|
+
def _interp_lowmem(
    self,
    longitude: float | npt.NDArray[np.float64],
    latitude: float | npt.NDArray[np.float64],
    level: float | npt.NDArray[np.float64],
    time: np.datetime64 | npt.NDArray[np.datetime64],
    *,
    method: str = "linear",
    bounds_error: bool = False,
    fill_value: float | np.float64 | None = np.nan,
    minimize_memory: bool = False,
    indices: interpolation.RGIArtifacts | None = None,
    return_indices: bool = False,
) -> npt.NDArray[np.float64] | tuple[npt.NDArray[np.float64], interpolation.RGIArtifacts]:
    """Interpolate values against underlying DataArray, one gridded time interval at a time.

    This method is used by :meth:`interpolate` when ``lowmem=True``. Points are binned by
    the pair of adjacent gridded time steps that bracket them. For each bin:

    1. the gridded data is cropped to the bounding box of the binned points;
    2. the crop is loaded into memory if needed;
    3. the binned points are interpolated.

    Only one cropped subset is therefore held at a time.

    Parameters and return types match :meth:`interpolate`, with two differences:

    - ``localize`` is not accepted, because each subset is always localized.
    - ``minimize_memory`` is accepted but currently has no effect.

    Points whose time lies outside the gridded time range are never interpolated. They
    keep the initial ``fill_value`` in the output. When ``return_indices`` is True, their
    artifacts keep index ``-1``, distance ``nan`` and ``out_of_bounds=True``. If
    ``bounds_error`` is True, such points raise ``ValueError`` before any interpolation.
    """
    # Kept only for signature compatibility; it does not influence the algorithm.
    del minimize_memory

    # Convert all inputs to 1d arrays. Inputs with ndim >= 2 are not validated.
    longitude, latitude, level, time = np.atleast_1d(longitude, latitude, level, time)

    # The per-interval loop below never visits points outside the gridded time range,
    # so the time bounds check must happen up front.
    if bounds_error:
        _lowmem_boundscheck(time, self.data)

    # Pre-fill output buffers (np.full rather than np.empty) so points not covered
    # by any mask keep the correct out-of-bounds values.
    out = np.full(longitude.shape, fill_value, dtype=self.data.dtype)
    if return_indices:
        rgi_artifacts = interpolation.RGIArtifacts(
            xi_indices=np.full((4, longitude.size), -1, dtype=np.int64),
            norm_distances=np.full((4, longitude.size), np.nan, dtype=np.float64),
            out_of_bounds=np.full((longitude.size,), True, dtype=np.bool_),
        )

    # Iterate over portions of points between adjacent time steps in gridded data.
    # _lowmem_masks yields boolean arrays only, so no None check is needed here.
    for mask in _lowmem_masks(time, self.data["time"].values):
        if not np.any(mask):
            continue

        lon_sl = longitude[mask]
        lat_sl = latitude[mask]
        lev_sl = level[mask]
        t_sl = time[mask]

        # Slice any precomputed interpolation artifacts down to the current bin.
        if indices is not None:
            indices_sl = interpolation.RGIArtifacts(
                xi_indices=indices.xi_indices[:, mask],
                norm_distances=indices.norm_distances[:, mask],
                out_of_bounds=indices.out_of_bounds[mask],
            )
        else:
            indices_sl = None

        coords = {"longitude": lon_sl, "latitude": lat_sl, "level": lev_sl, "time": t_sl}
        # Skip bins in which any coordinate is entirely NaN: there is nothing to
        # build a bounding box from, and the output keeps its fill value.
        if any(np.all(np.isnan(coord)) for coord in coords.values()):
            continue

        # Crop the gridded data to this bin *before* loading it into memory.
        da = interpolation._localize(self.data, coords)
        if not da._in_memory:
            logger.debug(
                "Loading %s MB subset of %s into memory.",
                round(da.nbytes / 1_000_000, 2),
                da.name,
            )
            da.load()

        tmp = interpolation.interp(
            longitude=lon_sl,
            latitude=lat_sl,
            level=lev_sl,
            time=t_sl,
            da=da,
            method=method,
            bounds_error=bounds_error,
            fill_value=fill_value,
            localize=False,  # would be no-op; da is localized already
            indices=indices_sl,
            return_indices=return_indices,
        )

        # Scatter the bin's results back into the full-size buffers.
        if return_indices:
            out[mask], rgi_sl = tmp
            rgi_artifacts.xi_indices[:, mask] = rgi_sl.xi_indices
            rgi_artifacts.norm_distances[:, mask] = rgi_sl.norm_distances
            rgi_artifacts.out_of_bounds[mask] = rgi_sl.out_of_bounds
        else:
            out[mask] = tmp

    if return_indices:
        return out, rgi_artifacts
    return out
|
|
1794
|
+
|
|
1663
1795
|
def _check_memory(self, msg_start: str) -> None:
|
|
1664
1796
|
"""Check the memory usage of the underlying data.
|
|
1665
1797
|
|
|
@@ -1731,7 +1863,7 @@ class MetDataArray(MetBase):
|
|
|
1731
1863
|
cachestore = cachestore or DiskCacheStore()
|
|
1732
1864
|
chunks = chunks or {}
|
|
1733
1865
|
data = _load(hash, cachestore, chunks)
|
|
1734
|
-
return cls(data[
|
|
1866
|
+
return cls(data[next(iter(data.data_vars))])
|
|
1735
1867
|
|
|
1736
1868
|
@property
|
|
1737
1869
|
def proportion(self) -> float:
|
|
@@ -2124,7 +2256,7 @@ class MetDataArray(MetBase):
|
|
|
2124
2256
|
-----
|
|
2125
2257
|
Uses the `scikit-image Marching Cubes <https://scikit-image.org/docs/dev/auto_examples/edges/plot_marching_cubes.html>`_
|
|
2126
2258
|
algorithm to reconstruct a surface from the point-cloud like arrays.
|
|
2127
|
-
"""
|
|
2259
|
+
"""
|
|
2128
2260
|
try:
|
|
2129
2261
|
from skimage import measure
|
|
2130
2262
|
except ModuleNotFoundError as e:
|
|
@@ -2656,3 +2788,46 @@ def _add_vertical_coords(data: XArrayType) -> XArrayType:
|
|
|
2656
2788
|
data.coords["altitude"] = data.coords["altitude"].astype(dtype, copy=False)
|
|
2657
2789
|
|
|
2658
2790
|
return data
|
|
2791
|
+
|
|
2792
|
+
|
|
2793
|
+
def _lowmem_boundscheck(time: npt.NDArray[np.datetime64], da: xr.DataArray) -> None:
    """Raise if any requested time falls outside the gridded time range.

    The low-memory interpolation loop only visits points that lie between two gridded
    time steps. Points outside the time range would therefore silently keep the fill
    value instead of triggering the out-of-bounds error that ``bounds_error=True``
    requests. This helper performs that time check up front.

    Parameters
    ----------
    time : npt.NDArray[np.datetime64]
        Times of the requested points.
    da : xr.DataArray
        Gridded data with a ``"time"`` coordinate.

    Raises
    ------
    ValueError
        If any element of ``time`` is earlier than the first or later than the last
        gridded time. ``NaT`` compares False with everything, so it also triggers this.
    """
    grid_time = da["time"].to_numpy()
    t_lo = grid_time.min()
    t_hi = grid_time.max()

    inside = (time >= t_lo) & (time <= t_hi)
    if np.all(inside):
        return

    # Report the axis position of "time" to mirror the scipy-style error message.
    time_axis = da.get_axis_num("time")
    raise ValueError(f"One of the requested xi is out of bounds in dimension {time_axis}")
|
|
2805
|
+
|
|
2806
|
+
|
|
2807
|
+
def _lowmem_masks(
|
|
2808
|
+
time: npt.NDArray[np.datetime64], t_met: npt.NDArray[np.datetime64]
|
|
2809
|
+
) -> Generator[npt.NDArray[np.bool_], None, None]:
|
|
2810
|
+
"""Generate sequence of masks for low-memory interpolation."""
|
|
2811
|
+
t_met_max = t_met.max()
|
|
2812
|
+
t_met_min = t_met.min()
|
|
2813
|
+
inbounds = (time >= t_met_min) & (time <= t_met_max)
|
|
2814
|
+
if not np.any(inbounds):
|
|
2815
|
+
return
|
|
2816
|
+
|
|
2817
|
+
earliest = np.nanmin(time)
|
|
2818
|
+
istart = 0 if earliest < t_met_min else np.flatnonzero(t_met <= earliest).max()
|
|
2819
|
+
latest = np.nanmax(time)
|
|
2820
|
+
iend = t_met.size - 1 if latest > t_met_max else np.flatnonzero(t_met >= latest).min()
|
|
2821
|
+
if istart == iend:
|
|
2822
|
+
yield inbounds
|
|
2823
|
+
return
|
|
2824
|
+
|
|
2825
|
+
# Sequence of masks covers elements in time in the interval [t_met[istart], t_met[iend]].
|
|
2826
|
+
# The first iteration masks elements in the interval [t_met[istart], t_met[istart+1]]
|
|
2827
|
+
# (inclusive of both endpoints).
|
|
2828
|
+
# Subsequent iterations mask elements in the interval (t_met[i], t_met[i+1]]
|
|
2829
|
+
# (inclusive of right endpoint only).
|
|
2830
|
+
for i in range(istart, iend):
|
|
2831
|
+
mask = ((time >= t_met[i]) if i == istart else (time > t_met[i])) & (time <= t_met[i + 1])
|
|
2832
|
+
if np.any(mask):
|
|
2833
|
+
yield mask
|
pycontrails/core/met_var.py
CHANGED
|
@@ -22,7 +22,7 @@ class MetVariable:
|
|
|
22
22
|
- `NCEP Grib v2 Code Table <https://www.nco.ncep.noaa.gov/pmb/docs/grib2/grib2_doc/grib2_table4-2.shtml>`_
|
|
23
23
|
|
|
24
24
|
Used for defining support parameters in a grib-like fashion.
|
|
25
|
-
"""
|
|
25
|
+
"""
|
|
26
26
|
|
|
27
27
|
#: Short variable name.
|
|
28
28
|
#: Chosen for greatest consistency between data sources.
|
pycontrails/core/models.py
CHANGED
|
@@ -11,7 +11,7 @@ import warnings
|
|
|
11
11
|
from abc import ABC, abstractmethod
|
|
12
12
|
from collections.abc import Sequence
|
|
13
13
|
from dataclasses import dataclass, fields
|
|
14
|
-
from typing import Any, NoReturn, TypeVar,
|
|
14
|
+
from typing import Any, NoReturn, TypeVar, overload
|
|
15
15
|
|
|
16
16
|
import numpy as np
|
|
17
17
|
import numpy.typing as npt
|
|
@@ -30,13 +30,13 @@ from pycontrails.utils.types import type_guard
|
|
|
30
30
|
logger = logging.getLogger(__name__)
|
|
31
31
|
|
|
32
32
|
#: Model input source types
|
|
33
|
-
ModelInput =
|
|
33
|
+
ModelInput = MetDataset | GeoVectorDataset | Flight | Sequence[Flight] | None
|
|
34
34
|
|
|
35
35
|
#: Model output source types
|
|
36
|
-
ModelOutput =
|
|
36
|
+
ModelOutput = MetDataArray | MetDataset | GeoVectorDataset | Flight | list[Flight]
|
|
37
37
|
|
|
38
38
|
#: Model attribute source types
|
|
39
|
-
SourceType =
|
|
39
|
+
SourceType = MetDataset | GeoVectorDataset | Flight | Fleet
|
|
40
40
|
|
|
41
41
|
_Source = TypeVar("_Source")
|
|
42
42
|
|
|
@@ -453,7 +453,7 @@ class Model(ABC):
|
|
|
453
453
|
return Fleet.from_seq(source)
|
|
454
454
|
|
|
455
455
|
# Raise error if source is not a MetDataset or GeoVectorDataset
|
|
456
|
-
if not isinstance(source,
|
|
456
|
+
if not isinstance(source, MetDataset | GeoVectorDataset):
|
|
457
457
|
msg = f"Unknown source type: {type(source)}"
|
|
458
458
|
raise TypeError(msg)
|
|
459
459
|
|
|
Binary file
|
pycontrails/core/vector.py
CHANGED
|
@@ -847,7 +847,7 @@ class VectorDataset:
|
|
|
847
847
|
------
|
|
848
848
|
TypeError
|
|
849
849
|
If ``mask`` is not a boolean array.
|
|
850
|
-
"""
|
|
850
|
+
"""
|
|
851
851
|
self.data._validate_array(mask)
|
|
852
852
|
if mask.dtype != bool:
|
|
853
853
|
raise TypeError("Parameter `mask` must be a boolean array.")
|
|
@@ -983,7 +983,7 @@ class VectorDataset:
|
|
|
983
983
|
numeric_attrs = (
|
|
984
984
|
attr
|
|
985
985
|
for attr, val in self.attrs.items()
|
|
986
|
-
if (isinstance(val,
|
|
986
|
+
if (isinstance(val, int | float | np.number) and attr not in ignore_keys)
|
|
987
987
|
)
|
|
988
988
|
self.broadcast_attrs(numeric_attrs, overwrite)
|
|
989
989
|
|
|
@@ -1072,7 +1072,7 @@ class VectorDataset:
|
|
|
1072
1072
|
obj = obj.to_numpy()
|
|
1073
1073
|
|
|
1074
1074
|
# Convert numpy objects to python objects
|
|
1075
|
-
if isinstance(obj,
|
|
1075
|
+
if isinstance(obj, np.ndarray | np.generic):
|
|
1076
1076
|
|
|
1077
1077
|
# round time to unix seconds
|
|
1078
1078
|
if key == "time":
|
|
@@ -1166,7 +1166,7 @@ class VectorDataset:
|
|
|
1166
1166
|
attrs = {}
|
|
1167
1167
|
|
|
1168
1168
|
for k, v in {**obj, **obj_kwargs}.items():
|
|
1169
|
-
if isinstance(v,
|
|
1169
|
+
if isinstance(v, list | np.ndarray):
|
|
1170
1170
|
data[k] = v
|
|
1171
1171
|
else:
|
|
1172
1172
|
attrs[k] = v
|
|
@@ -1194,7 +1194,7 @@ class VectorDataset:
|
|
|
1194
1194
|
See Also
|
|
1195
1195
|
--------
|
|
1196
1196
|
:func:`numpy.array_split`
|
|
1197
|
-
"""
|
|
1197
|
+
"""
|
|
1198
1198
|
full_index = np.arange(self.size)
|
|
1199
1199
|
index_splits = np.array_split(full_index, n_splits)
|
|
1200
1200
|
for index in index_splits:
|
|
@@ -6,16 +6,6 @@ import numpy as np
|
|
|
6
6
|
|
|
7
7
|
from pycontrails.utils import dependencies
|
|
8
8
|
|
|
9
|
-
try:
|
|
10
|
-
import skimage as ski
|
|
11
|
-
except ModuleNotFoundError as exc:
|
|
12
|
-
dependencies.raise_module_not_found_error(
|
|
13
|
-
name="landsat module",
|
|
14
|
-
package_name="scikit-image",
|
|
15
|
-
module_not_found_error=exc,
|
|
16
|
-
pycontrails_optional_package="sat",
|
|
17
|
-
)
|
|
18
|
-
|
|
19
9
|
|
|
20
10
|
def normalize(channel: np.ndarray) -> np.ndarray:
|
|
21
11
|
"""Normalize channel values to range [0, 1], preserving ``np.nan`` in output.
|
|
@@ -53,8 +43,17 @@ def equalize(channel: np.ndarray, **equalize_kwargs: Any) -> np.ndarray:
|
|
|
53
43
|
NaN values are converted to 0 before passing to :py:func:`ski.exposure.equalize_adapthist`
|
|
54
44
|
and may affect equalized values in the neighborhood where they occur.
|
|
55
45
|
"""
|
|
46
|
+
try:
|
|
47
|
+
import skimage.exposure
|
|
48
|
+
except ModuleNotFoundError as exc:
|
|
49
|
+
dependencies.raise_module_not_found_error(
|
|
50
|
+
name="landsat module",
|
|
51
|
+
package_name="scikit-image",
|
|
52
|
+
module_not_found_error=exc,
|
|
53
|
+
pycontrails_optional_package="sat",
|
|
54
|
+
)
|
|
56
55
|
return np.where(
|
|
57
56
|
np.isnan(channel),
|
|
58
57
|
np.nan,
|
|
59
|
-
|
|
58
|
+
skimage.exposure.equalize_adapthist(np.nan_to_num(channel, nan=0.0), **equalize_kwargs),
|
|
60
59
|
)
|
|
@@ -8,7 +8,7 @@ import logging
|
|
|
8
8
|
import pathlib
|
|
9
9
|
from collections.abc import Sequence
|
|
10
10
|
from datetime import datetime
|
|
11
|
-
from typing import Any,
|
|
11
|
+
from typing import Any, TypeAlias
|
|
12
12
|
|
|
13
13
|
import numpy as np
|
|
14
14
|
import pandas as pd
|
|
@@ -20,11 +20,13 @@ from pycontrails.utils.types import DatetimeLike
|
|
|
20
20
|
|
|
21
21
|
logger = logging.getLogger(__name__)
|
|
22
22
|
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
]
|
|
27
|
-
|
|
23
|
+
# https://github.com/python/mypy/issues/14824
|
|
24
|
+
TimeInput: TypeAlias = str | DatetimeLike | Sequence[str | DatetimeLike]
|
|
25
|
+
VariableInput = (
|
|
26
|
+
str | int | MetVariable | np.ndarray | Sequence[str | int | MetVariable | Sequence[MetVariable]]
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
PressureLevelInput = int | float | np.ndarray | Sequence[int | float]
|
|
28
30
|
|
|
29
31
|
#: NetCDF engine to use for parsing netcdf files
|
|
30
32
|
NETCDF_ENGINE: str = "netcdf4"
|
|
@@ -66,13 +68,13 @@ def parse_timesteps(time: TimeInput | None, freq: str | None = "1h") -> list[dat
|
|
|
66
68
|
------
|
|
67
69
|
ValueError
|
|
68
70
|
Raises when the time has len > 2 or when time elements fail to be parsed with pd.to_datetime
|
|
69
|
-
"""
|
|
71
|
+
"""
|
|
70
72
|
|
|
71
73
|
if time is None:
|
|
72
74
|
return []
|
|
73
75
|
|
|
74
76
|
# confirm input is tuple or list-like of length 2
|
|
75
|
-
if isinstance(time,
|
|
77
|
+
if isinstance(time, str | datetime | pd.Timestamp | np.datetime64):
|
|
76
78
|
time = (time, time)
|
|
77
79
|
elif len(time) == 1:
|
|
78
80
|
time = (time[0], time[0])
|
|
@@ -151,7 +153,7 @@ def parse_pressure_levels(
|
|
|
151
153
|
Raises ValueError if pressure level is not supported by ECMWF data source
|
|
152
154
|
"""
|
|
153
155
|
# Ensure pressure_levels is array-like
|
|
154
|
-
if isinstance(pressure_levels,
|
|
156
|
+
if isinstance(pressure_levels, int | float):
|
|
155
157
|
pressure_levels = [pressure_levels]
|
|
156
158
|
|
|
157
159
|
# Cast array-like to int dtype and sort
|
|
@@ -212,7 +214,7 @@ def parse_variables(variables: VariableInput, supported: list[MetVariable]) -> l
|
|
|
212
214
|
met_var_list: list[MetVariable] = []
|
|
213
215
|
|
|
214
216
|
# ensure input variables are a list of str
|
|
215
|
-
if isinstance(variables,
|
|
217
|
+
if isinstance(variables, str | int | MetVariable):
|
|
216
218
|
parsed_variables = [variables]
|
|
217
219
|
elif isinstance(variables, np.ndarray):
|
|
218
220
|
parsed_variables = variables.tolist()
|
|
@@ -257,7 +259,7 @@ def _find_match(
|
|
|
257
259
|
|
|
258
260
|
# list of MetVariable options
|
|
259
261
|
# here we extract the first MetVariable in var that is supported
|
|
260
|
-
elif isinstance(var,
|
|
262
|
+
elif isinstance(var, list | tuple):
|
|
261
263
|
for v in var:
|
|
262
264
|
# sanity check since we don't support other types as lists
|
|
263
265
|
if not isinstance(v, MetVariable):
|
|
@@ -349,7 +349,7 @@ class ERA5ModelLevel(ECMWFAPI):
|
|
|
349
349
|
unique_dates = set(t.strftime("%Y-%m-%d") for t in times)
|
|
350
350
|
unique_times = set(t.strftime("%H:%M:%S") for t in times)
|
|
351
351
|
# param 152 = log surface pressure, needed for metview level conversion
|
|
352
|
-
grib_params = set(self.variable_ecmwfids
|
|
352
|
+
grib_params = set((*self.variable_ecmwfids, 152))
|
|
353
353
|
common = {
|
|
354
354
|
"class": "ea",
|
|
355
355
|
"date": "/".join(sorted(unique_dates)),
|