pycontrails 0.58.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pycontrails might be problematic.

Files changed (122)
  1. pycontrails/__init__.py +70 -0
  2. pycontrails/_version.py +34 -0
  3. pycontrails/core/__init__.py +30 -0
  4. pycontrails/core/aircraft_performance.py +679 -0
  5. pycontrails/core/airports.py +228 -0
  6. pycontrails/core/cache.py +889 -0
  7. pycontrails/core/coordinates.py +174 -0
  8. pycontrails/core/fleet.py +483 -0
  9. pycontrails/core/flight.py +2185 -0
  10. pycontrails/core/flightplan.py +228 -0
  11. pycontrails/core/fuel.py +140 -0
  12. pycontrails/core/interpolation.py +702 -0
  13. pycontrails/core/met.py +2931 -0
  14. pycontrails/core/met_var.py +387 -0
  15. pycontrails/core/models.py +1321 -0
  16. pycontrails/core/polygon.py +549 -0
  17. pycontrails/core/rgi_cython.cp314-win_amd64.pyd +0 -0
  18. pycontrails/core/vector.py +2249 -0
  19. pycontrails/datalib/__init__.py +12 -0
  20. pycontrails/datalib/_met_utils/metsource.py +746 -0
  21. pycontrails/datalib/ecmwf/__init__.py +73 -0
  22. pycontrails/datalib/ecmwf/arco_era5.py +345 -0
  23. pycontrails/datalib/ecmwf/common.py +114 -0
  24. pycontrails/datalib/ecmwf/era5.py +554 -0
  25. pycontrails/datalib/ecmwf/era5_model_level.py +490 -0
  26. pycontrails/datalib/ecmwf/hres.py +804 -0
  27. pycontrails/datalib/ecmwf/hres_model_level.py +466 -0
  28. pycontrails/datalib/ecmwf/ifs.py +287 -0
  29. pycontrails/datalib/ecmwf/model_levels.py +435 -0
  30. pycontrails/datalib/ecmwf/static/model_level_dataframe_v20240418.csv +139 -0
  31. pycontrails/datalib/ecmwf/variables.py +268 -0
  32. pycontrails/datalib/geo_utils.py +261 -0
  33. pycontrails/datalib/gfs/__init__.py +28 -0
  34. pycontrails/datalib/gfs/gfs.py +656 -0
  35. pycontrails/datalib/gfs/variables.py +104 -0
  36. pycontrails/datalib/goes.py +757 -0
  37. pycontrails/datalib/himawari/__init__.py +27 -0
  38. pycontrails/datalib/himawari/header_struct.py +266 -0
  39. pycontrails/datalib/himawari/himawari.py +667 -0
  40. pycontrails/datalib/landsat.py +589 -0
  41. pycontrails/datalib/leo_utils/__init__.py +5 -0
  42. pycontrails/datalib/leo_utils/correction.py +266 -0
  43. pycontrails/datalib/leo_utils/landsat_metadata.py +300 -0
  44. pycontrails/datalib/leo_utils/search.py +250 -0
  45. pycontrails/datalib/leo_utils/sentinel_metadata.py +748 -0
  46. pycontrails/datalib/leo_utils/static/bq_roi_query.sql +6 -0
  47. pycontrails/datalib/leo_utils/vis.py +59 -0
  48. pycontrails/datalib/sentinel.py +650 -0
  49. pycontrails/datalib/spire/__init__.py +5 -0
  50. pycontrails/datalib/spire/exceptions.py +62 -0
  51. pycontrails/datalib/spire/spire.py +604 -0
  52. pycontrails/ext/bada.py +42 -0
  53. pycontrails/ext/cirium.py +14 -0
  54. pycontrails/ext/empirical_grid.py +140 -0
  55. pycontrails/ext/synthetic_flight.py +431 -0
  56. pycontrails/models/__init__.py +1 -0
  57. pycontrails/models/accf.py +425 -0
  58. pycontrails/models/apcemm/__init__.py +8 -0
  59. pycontrails/models/apcemm/apcemm.py +983 -0
  60. pycontrails/models/apcemm/inputs.py +226 -0
  61. pycontrails/models/apcemm/static/apcemm_yaml_template.yaml +183 -0
  62. pycontrails/models/apcemm/utils.py +437 -0
  63. pycontrails/models/cocip/__init__.py +29 -0
  64. pycontrails/models/cocip/cocip.py +2742 -0
  65. pycontrails/models/cocip/cocip_params.py +305 -0
  66. pycontrails/models/cocip/cocip_uncertainty.py +291 -0
  67. pycontrails/models/cocip/contrail_properties.py +1530 -0
  68. pycontrails/models/cocip/output_formats.py +2270 -0
  69. pycontrails/models/cocip/radiative_forcing.py +1260 -0
  70. pycontrails/models/cocip/radiative_heating.py +520 -0
  71. pycontrails/models/cocip/unterstrasser_wake_vortex.py +508 -0
  72. pycontrails/models/cocip/wake_vortex.py +396 -0
  73. pycontrails/models/cocip/wind_shear.py +120 -0
  74. pycontrails/models/cocipgrid/__init__.py +9 -0
  75. pycontrails/models/cocipgrid/cocip_grid.py +2552 -0
  76. pycontrails/models/cocipgrid/cocip_grid_params.py +138 -0
  77. pycontrails/models/dry_advection.py +602 -0
  78. pycontrails/models/emissions/__init__.py +21 -0
  79. pycontrails/models/emissions/black_carbon.py +599 -0
  80. pycontrails/models/emissions/emissions.py +1353 -0
  81. pycontrails/models/emissions/ffm2.py +336 -0
  82. pycontrails/models/emissions/static/default-engine-uids.csv +239 -0
  83. pycontrails/models/emissions/static/edb-gaseous-v29b-engines.csv +596 -0
  84. pycontrails/models/emissions/static/edb-nvpm-v29b-engines.csv +215 -0
  85. pycontrails/models/extended_k15.py +1327 -0
  86. pycontrails/models/humidity_scaling/__init__.py +37 -0
  87. pycontrails/models/humidity_scaling/humidity_scaling.py +1075 -0
  88. pycontrails/models/humidity_scaling/quantiles/era5-model-level-quantiles.pq +0 -0
  89. pycontrails/models/humidity_scaling/quantiles/era5-pressure-level-quantiles.pq +0 -0
  90. pycontrails/models/issr.py +210 -0
  91. pycontrails/models/pcc.py +326 -0
  92. pycontrails/models/pcr.py +154 -0
  93. pycontrails/models/ps_model/__init__.py +18 -0
  94. pycontrails/models/ps_model/ps_aircraft_params.py +381 -0
  95. pycontrails/models/ps_model/ps_grid.py +701 -0
  96. pycontrails/models/ps_model/ps_model.py +1000 -0
  97. pycontrails/models/ps_model/ps_operational_limits.py +525 -0
  98. pycontrails/models/ps_model/static/ps-aircraft-params-20250328.csv +69 -0
  99. pycontrails/models/ps_model/static/ps-synonym-list-20250328.csv +104 -0
  100. pycontrails/models/sac.py +442 -0
  101. pycontrails/models/tau_cirrus.py +183 -0
  102. pycontrails/physics/__init__.py +1 -0
  103. pycontrails/physics/constants.py +117 -0
  104. pycontrails/physics/geo.py +1138 -0
  105. pycontrails/physics/jet.py +968 -0
  106. pycontrails/physics/static/iata-cargo-load-factors-20250221.csv +74 -0
  107. pycontrails/physics/static/iata-passenger-load-factors-20250221.csv +74 -0
  108. pycontrails/physics/thermo.py +551 -0
  109. pycontrails/physics/units.py +472 -0
  110. pycontrails/py.typed +0 -0
  111. pycontrails/utils/__init__.py +1 -0
  112. pycontrails/utils/dependencies.py +66 -0
  113. pycontrails/utils/iteration.py +13 -0
  114. pycontrails/utils/json.py +187 -0
  115. pycontrails/utils/temp.py +50 -0
  116. pycontrails/utils/types.py +163 -0
  117. pycontrails-0.58.0.dist-info/METADATA +180 -0
  118. pycontrails-0.58.0.dist-info/RECORD +122 -0
  119. pycontrails-0.58.0.dist-info/WHEEL +5 -0
  120. pycontrails-0.58.0.dist-info/licenses/LICENSE +178 -0
  121. pycontrails-0.58.0.dist-info/licenses/NOTICE +43 -0
  122. pycontrails-0.58.0.dist-info/top_level.txt +3 -0
pycontrails/core/met.py
@@ -0,0 +1,2931 @@
1
+ """Meteorology data models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ import json
7
+ import logging
8
+ import pathlib
9
+ import sys
10
+ import typing
11
+ import warnings
12
+ from abc import ABC, abstractmethod
13
+ from collections.abc import (
14
+ Generator,
15
+ Hashable,
16
+ Iterable,
17
+ Iterator,
18
+ Mapping,
19
+ MutableMapping,
20
+ Sequence,
21
+ )
22
+ from contextlib import ExitStack
23
+ from datetime import datetime
24
+ from typing import (
25
+ TYPE_CHECKING,
26
+ Any,
27
+ Generic,
28
+ Literal,
29
+ Self,
30
+ TypeVar,
31
+ overload,
32
+ )
33
+
34
+ if sys.version_info >= (3, 12):
35
+ from typing import override
36
+ else:
37
+ from typing_extensions import override
38
+
39
+ import numpy as np
40
+ import numpy.typing as npt
41
+ import pandas as pd
42
+ import xarray as xr
43
+
44
+ from pycontrails.core import interpolation
45
+ from pycontrails.core import vector as vector_module
46
+ from pycontrails.core.cache import CacheStore, DiskCacheStore
47
+ from pycontrails.core.met_var import AirPressure, Altitude, MetVariable
48
+ from pycontrails.physics import units
49
+ from pycontrails.utils import dependencies
50
+ from pycontrails.utils import temp as temp_module
51
+
52
+ logger = logging.getLogger(__name__)
53
+
54
+ # optional imports
55
+ if TYPE_CHECKING:
56
+ import open3d as o3d
57
+
58
+ XArrayType = TypeVar("XArrayType", xr.Dataset, xr.DataArray)
59
+ MetDataType = TypeVar("MetDataType", "MetDataset", "MetDataArray")
60
+ DatasetType = TypeVar("DatasetType", xr.Dataset, "MetDataset")
61
+
62
+ COORD_DTYPE = np.float64
63
+
64
+
65
+ class MetBase(ABC, Generic[XArrayType]):
66
+ """Abstract class for building Meteorology Data handling classes.
67
+
68
+ All support here should be generic to work on xr.DataArray
69
+ and xr.Dataset.
70
+ """
71
+
72
+ __slots__ = ("cachestore", "data")
73
+
74
+ #: DataArray or Dataset
75
+ data: XArrayType
76
+
77
+ #: Cache datastore to use for :meth:`save` or :meth:`load`
78
+ cachestore: CacheStore | None
79
+
80
+ #: Default dimension order for DataArray or Dataset (x, y, z, t)
81
+ dim_order = (
82
+ "longitude",
83
+ "latitude",
84
+ "level",
85
+ "time",
86
+ )
87
+
88
+ @classmethod
89
+ def _from_fastpath(cls, data: XArrayType, cachestore: CacheStore | None = None) -> Self:
90
+ """Create new instance from consistent data.
91
+
92
+ This is a low-level method that bypasses the standard constructor in certain
93
+ special cases. It is intended for internal use only.
94
+
95
+ In essence, this method skips any validation from __init__ and directly sets
96
+ ``data`` and ``attrs``. This is useful when creating a new instance from an existing
97
+ instance whose data has already been validated.
98
+ """
99
+ obj = cls.__new__(cls)
100
+ obj.data = data
101
+ obj.cachestore = cachestore
102
+ return obj
103
+
104
+ def __repr__(self) -> str:
105
+ data = getattr(self, "data", None)
106
+ return (
107
+ f"{self.__class__.__name__} with data:\n\n{data.__repr__() if data is not None else ''}"
108
+ )
109
+
110
+ def _repr_html_(self) -> str:
111
+ try:
112
+ return f"<b>{type(self).__name__}</b> with data:<br/ ><br/> {self.data._repr_html_()}"
113
+ except AttributeError:
114
+ return f"<b>{type(self).__name__}</b> without data"
115
+
116
+ def _validate_dim_contains_coords(self) -> None:
117
+ """Check that data contains four temporal-spatial coordinates.
118
+
119
+ Raises
120
+ ------
121
+ ValueError
122
+ If data does not contain all four coordinates (longitude, latitude, level, time).
123
+ """
124
+ missing = set(self.dim_order).difference(self.data.dims)
125
+ if not missing:
126
+ return
127
+
128
+ dim = sorted(missing)
129
+ msg = f"Meteorology data must contain dimension(s): {dim}."
130
+ if "level" in dim:
131
+ msg += (
132
+ " For single level data, set 'level' coordinate to constant -1 "
133
+ "using `ds = ds.expand_dims({'level': [-1]})`"
134
+ )
135
+ raise ValueError(msg)
136
+
137
+ def _validate_longitude(self) -> None:
138
+ """Check longitude bounds.
139
+
140
+ Assumes ``longitude`` dimension is already sorted.
141
+
142
+ Raises
143
+ ------
144
+ ValueError
145
+ If longitude values are not contained in the interval [-180, 180].
146
+ """
147
+ longitude = self.indexes["longitude"].to_numpy()
148
+ if longitude.dtype != COORD_DTYPE:
149
+ msg = f"Longitude values must have dtype {COORD_DTYPE}. Instantiate with 'copy=True'."
150
+ raise ValueError(msg)
151
+
152
+ if self.is_wrapped:
153
+ # Relax verification if the longitude has already been processed and wrapped
154
+ if longitude[-1] > 360.0:
155
+ raise ValueError(
156
+ "Longitude contains values > 360. Shift to WGS84 with "
157
+ "'data.assign_coords(longitude=(((data.longitude + 180) % 360) - 180))'"
158
+ )
159
+ if longitude[0] < -360.0:
160
+ raise ValueError(
161
+ "Longitude contains values < -360. Shift to WGS84 with "
162
+ "'data.assign_coords(longitude=(((data.longitude + 180) % 360) - 180))'"
163
+ )
164
+ return
165
+
166
+ # Strict!
167
+ if longitude[-1] > 180.0:
168
+ raise ValueError(
169
+ "Longitude contains values > 180. Shift to WGS84 with "
170
+ "'data.assign_coords(longitude=(((data.longitude + 180) % 360) - 180))'"
171
+ )
172
+ if longitude[0] < -180.0:
173
+ raise ValueError(
174
+ "Longitude contains values < -180. Shift to WGS84 with "
175
+ "'data.assign_coords(longitude=(((data.longitude + 180) % 360) - 180))'"
176
+ )
177
+
178
+ def _validate_latitude(self) -> None:
179
+ """Check latitude bounds.
180
+
181
+ Assumes ``latitude`` dimension is already sorted.
182
+
183
+ Raises
184
+ ------
185
+ ValueError
186
+ If latitude values are not contained in the interval [-90, 90].
187
+ """
188
+ latitude = self.indexes["latitude"].to_numpy()
189
+ if latitude.dtype != COORD_DTYPE:
190
+ msg = f"Latitude values must have dtype {COORD_DTYPE}. Instantiate with 'copy=True'."
191
+ raise ValueError(msg)
192
+
193
+ if latitude[0] < -90.0:
194
+ raise ValueError(
195
+ "Latitude contains values < -90 . "
196
+ "Latitude values must be contained in the interval [-90, 90]."
197
+ )
198
+ if latitude[-1] > 90.0:
199
+ raise ValueError(
200
+ "Latitude contains values > 90 . "
201
+ "Latitude values must be contained in the interval [-90, 90]."
202
+ )
203
+
204
+ def _validate_sorting(self) -> None:
205
+ """Check that all coordinates are sorted.
206
+
207
+ Raises
208
+ ------
209
+ ValueError
210
+ If one of the coordinates is not sorted or contains duplicate values.
211
+ """
212
+ indexes = self.indexes
213
+ for coord in self.dim_order:
214
+ arr = indexes[coord]
215
+ d = np.diff(arr)
216
+ zero = np.zeros((), dtype=d.dtype) # ensure same dtype
217
+
218
+ if np.any(d <= zero):
219
+ if np.any(d == zero):
220
+ msg = f"Coordinate '{coord}' contains duplicate values."
221
+ else:
222
+ msg = f"Coordinate '{coord}' not sorted."
223
+
224
+ msg += " Instantiate with 'copy=True'."
225
+ raise ValueError(msg)
226
+
227
+ def _validate_transpose(self) -> None:
228
+ """Check that data is transposed according to :attr:`dim_order`."""
229
+
230
+ def _check_da(da: xr.DataArray, key: Hashable | None = None) -> None:
231
+ if da.dims != self.dim_order:
232
+ if key is not None:
233
+ msg = (
234
+ f"Data dimension not transposed on variable '{key}'. "
235
+ "Instantiate with 'copy=True'."
236
+ )
237
+ else:
238
+ msg = "Data dimension not transposed. Instantiate with 'copy=True'."
239
+ raise ValueError(msg)
240
+
241
+ data = self.data
242
+ if isinstance(data, xr.DataArray):
243
+ _check_da(data)
244
+ return
245
+
246
+ for key, da in self.data.items():
247
+ _check_da(da, key)
248
+
249
+ def _validate_dims(self) -> None:
250
+ """Apply all validators."""
251
+ self._validate_dim_contains_coords()
252
+
253
+ # Apply this one first: validate_longitude and validate_latitude assume sorted
254
+ self._validate_sorting()
255
+ self._validate_longitude()
256
+ self._validate_latitude()
257
+ self._validate_transpose()
258
+ if self.data["level"].dtype != COORD_DTYPE:
259
+ msg = f"Level values must have dtype {COORD_DTYPE}. Instantiate with 'copy=True'."
260
+ raise ValueError(msg)
261
+
262
+ def _preprocess_dims(self, wrap_longitude: bool) -> None:
263
+ """Confirm DataArray or Dataset include required dimension in a consistent format.
264
+
265
+ Expects DataArray or Dataset to contain dimensions ``latitude``, ``longitude``, ``time``,
266
+ and ``level`` (in hPa/mbar).
267
+ Adds ``air_pressure`` and ``altitude`` coordinate variables
268
+ mapped to "level" dimension if "level" > 0.
269
+
270
+ Set ``level`` to -1 to signify single level.
271
+
272
+ .. versionchanged:: 0.40.0
273
+
274
+ All coordinate data (longitude, latitude, level) are promoted to ``float64``.
275
+ Auxiliary coordinates (altitude and air_pressure) are now cast to the same
276
+ dtype as the underlying grid data.
277
+
278
+ .. versionchanged:: 0.58.0
279
+
280
+ Duplicate dimension values are dropped, keeping the first occurrence.
281
+
282
+
283
+ Parameters
284
+ ----------
285
+ wrap_longitude : bool
286
+ If True, ensure longitude values cover the interval ``[-180, 180]``.
287
+
288
+ Raises
289
+ ------
290
+ ValueError
291
+ Raises if required dimension names are not found
292
+ """
293
+ self._validate_dim_contains_coords()
294
+
295
+ # Ensure spatial coordinates all have dtype COORD_DTYPE
296
+ indexes = self.indexes
297
+ for coord in ("longitude", "latitude", "level"):
298
+ arr = indexes[coord].to_numpy()
299
+ if arr.dtype != COORD_DTYPE:
300
+ self.data[coord] = arr.astype(COORD_DTYPE)
301
+
302
+ # Ensure time is np.datetime64[ns]
303
+ self.data["time"] = self.data["time"].astype("datetime64[ns]", copy=False)
304
+
305
+ # sortby to ensure each coordinate has ascending order
306
+ self.data = self.data.sortby(list(self.dim_order), ascending=True)
307
+
308
+ # Drop any duplicated dimension values
309
+ indexes = self.indexes
310
+ for coord in self.dim_order:
311
+ arr = indexes[coord]
312
+ d = np.diff(arr)
313
+ zero = np.zeros((), dtype=d.dtype) # ensure same dtype
314
+
315
+ if np.any(d == zero):
316
+ # Remove duplicates
317
+ filt = np.r_[True, d > zero] # prepend True keeps the first occurrence
318
+ self.data = self.data.isel({coord: filt})
319
+
320
+ if not self.is_wrapped:
321
+ # Ensure longitude is contained in interval [-180, 180)
322
+ # If longitude has value at 180, we might not want to shift it?
323
+ lon = self.indexes["longitude"].to_numpy()
324
+
325
+ # This longitude shifting can give rise to precision errors with float32
326
+ # Only shift if necessary
327
+ if np.any(lon >= 180.0) or np.any(lon < -180.0):
328
+ self.data = shift_longitude(self.data)
329
+ else:
330
+ self.data = self.data.sortby("longitude", ascending=True)
331
+
332
+ # wrap longitude, if requested
333
+ if wrap_longitude:
334
+ self.data = _wrap_longitude(self.data)
335
+
336
+ self._validate_longitude()
337
+ self._validate_latitude()
338
+
339
+ # transpose to have ordering (x, y, z, t, ...)
340
+ dim_order = [*self.dim_order, *(d for d in self.data.dims if d not in self.dim_order)]
341
+ self.data = self.data.transpose(*dim_order)
342
+
343
+ # single level data
344
+ if self.is_single_level:
345
+ # add level attributes to reflect surface level
346
+ level_attrs = self.data["level"].attrs
347
+ if not level_attrs:
348
+ level_attrs.update(units="", long_name="Single Level")
349
+ return
350
+
351
+ self.data = _add_vertical_coords(self.data)
352
+
353
+ @property
354
+ def hash(self) -> str:
355
+ """Generate a unique hash for this met instance.
356
+
357
+ Note this is not as robust as it could be since the output of `repr`
358
+ is truncated.
359
+
360
+ Returns
361
+ -------
362
+ str
363
+ Unique hash for met instance (sha1)
364
+ """
365
+ _hash = repr(self.data)
366
+ return hashlib.sha1(bytes(_hash, "utf-8")).hexdigest()
367
+
368
+ @property
369
+ @abstractmethod
370
+ def size(self) -> int:
371
+ """Return the size of (each) array in underlying :attr:`data`.
372
+
373
+ Returns
374
+ -------
375
+ int
376
+ Total number of grid points in underlying data
377
+ """
378
+
379
+ @property
380
+ @abstractmethod
381
+ def shape(self) -> tuple[int, int, int, int]:
382
+ """Return the shape of the dimensions.
383
+
384
+ Returns
385
+ -------
386
+ tuple[int, int, int, int]
387
+ Shape of underlying data
388
+ """
389
+
390
+ @property
391
+ def coords(self) -> dict[str, np.ndarray]:
392
+ """Get coordinates of underlying :attr:`data` coordinates.
393
+
394
+ Only return non-dimension coordinates.
395
+
396
+ See:
397
+ http://xarray.pydata.org/en/stable/user-guide/data-structures.html#coordinates
398
+
399
+ Returns
400
+ -------
401
+ dict[str, np.ndarray]
402
+ Dictionary of coordinates
403
+ """
404
+ variables = self.indexes
405
+ return {
406
+ "longitude": variables["longitude"].to_numpy(),
407
+ "latitude": variables["latitude"].to_numpy(),
408
+ "level": variables["level"].to_numpy(),
409
+ "time": variables["time"].to_numpy(),
410
+ }
411
+
412
+ @property
413
+ def indexes(self) -> dict[Hashable, pd.Index]:
414
+ """Low level access to underlying :attr:`data` indexes.
415
+
416
+ This method is typically faster for accessing coordinate indexes.
417
+
418
+ .. versionadded:: 0.25.2
419
+
420
+ Returns
421
+ -------
422
+ dict[Hashable, pd.Index]
423
+ Dictionary of indexes.
424
+
425
+ Examples
426
+ --------
427
+ >>> from pycontrails.datalib.ecmwf import ERA5
428
+ >>> times = (datetime(2022, 3, 1, 12), datetime(2022, 3, 1, 13))
429
+ >>> variables = "air_temperature", "specific_humidity"
430
+ >>> levels = [200, 300]
431
+ >>> era5 = ERA5(times, variables, levels)
432
+ >>> mds = era5.open_metdataset()
433
+ >>> mds.indexes["level"].to_numpy()
434
+ array([200., 300.])
435
+
436
+ >>> mda = mds["air_temperature"]
437
+ >>> mda.indexes["level"].to_numpy()
438
+ array([200., 300.])
439
+ """
440
+ return {k: v.index for k, v in self.data._indexes.items()} # type: ignore[attr-defined]
441
+
442
+ @property
443
+ def is_wrapped(self) -> bool:
444
+ """Check if the longitude dimension covers the closed interval ``[-180, 180]``.
445
+
446
+ Assumes the longitude dimension is sorted (this is established by the
447
+ :class:`MetDataset` or :class:`MetDataArray` constructor).
448
+
449
+ .. versionchanged:: 0.26.0
450
+
451
+ The previous implementation checked for the minimum and maximum longitude
452
+ dimension values to be duplicated. The current implementation only checks
453
+ that the interval ``[-180, 180]`` is covered by the longitude dimension. The
454
+ :func:`pycontrails.physics.geo.advect_longitude` is designed for compatibility
455
+ with this convention.
456
+
457
+ Returns
458
+ -------
459
+ bool
460
+ True if longitude coordinates cover ``[-180, 180]``
461
+
462
+ See Also
463
+ --------
464
+ :func:`pycontrails.physics.geo.advect_longitude`
465
+ """
466
+ longitude = self.indexes["longitude"].to_numpy()
467
+ return _is_wrapped(longitude)
468
+
469
+ @property
470
+ def is_single_level(self) -> bool:
471
+ """Check if instance contains "single level" or "surface level" data.
472
+
473
+ This method checks if ``level`` dimension contains a single value equal
474
+ to -1, the pycontrails convention for surface only data.
475
+
476
+ Returns
477
+ -------
478
+ bool
479
+ If instance contains single level data.
480
+ """
481
+ level = self.indexes["level"].to_numpy()
482
+ return len(level) == 1 and level[0] == -1
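
For illustration, a single-level skeleton can be built with the conventional level value of -1 (a sketch using :meth:`MetDataset.from_coords`, defined later in this module; the coordinate values are arbitrary):

>>> import numpy as np
>>> from pycontrails import MetDataset
>>> met = MetDataset.from_coords(
...     longitude=np.arange(-180.0, 180.0, 1.0),
...     latitude=np.arange(-90.0, 91.0, 1.0),
...     level=-1,  # pycontrails convention for single level / surface data
...     time=np.datetime64("2022-03-01"),
... )
>>> met.is_single_level
True
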
483
+
484
+ @abstractmethod
485
+ def broadcast_coords(self, name: str) -> xr.DataArray:
486
+ """Broadcast coordinates along other dimensions.
487
+
488
+ Parameters
489
+ ----------
490
+ name : str
491
+ Coordinate/dimension name to broadcast.
492
+ Can be a dimension or non-dimension coordinates.
493
+
494
+ Returns
495
+ -------
496
+ xr.DataArray
497
+ DataArray of the coordinate broadcasted along all other dimensions.
498
+ The DataArray will have the same shape as the gridded data.
499
+ """
500
+
501
+ def _save(self, dataset: xr.Dataset, **kwargs: Any) -> list[str]:
502
+ """Save dataset to netcdf files named for the met hash and hour.
503
+
504
+ Does not yet save in parallel.
505
+
506
+ .. versionchanged:: 0.34.1
507
+
508
+ If :attr:`cachestore` is None, this method assigns it
509
+ to new :class:`DiskCacheStore`.
510
+
511
+ Parameters
512
+ ----------
513
+ dataset : xr.Dataset
514
+ Dataset to save
515
+ **kwargs
516
+ Keyword args passed directly on to :func:`xarray.save_mfdataset`
517
+
518
+ Returns
519
+ -------
520
+ list[str]
521
+ List of filenames saved
522
+ """
523
+ self.cachestore = self.cachestore or DiskCacheStore()
524
+
525
+ # group by hour and save one dataset for each hour to temp file
526
+ times, datasets = zip(*dataset.groupby("time", squeeze=False), strict=True)
527
+
528
+ # Open ExitStack to control temp_file context manager
529
+ with ExitStack() as stack:
530
+ xarray_temp_filenames = [stack.enter_context(temp_module.temp_file()) for _ in times]
531
+ xr.save_mfdataset(datasets, xarray_temp_filenames, **kwargs)
532
+
533
+ # set filenames by hash
534
+ filenames = [f"{self.hash}-{t_idx}.nc" for t_idx, _ in enumerate(times)]
535
+
536
+ # put each hourly file into cache
537
+ self.cachestore.put_multiple(xarray_temp_filenames, filenames)
538
+
539
+ return filenames
540
+
541
+ def __len__(self) -> int:
542
+ return self.data.__len__()
543
+
544
+ @property
545
+ def attrs(self) -> dict[str, Any]:
546
+ """Pass through to :attr:`self.data.attrs`."""
547
+ return self.data.attrs
548
+
549
+ def downselect(self, bbox: tuple[float, ...]) -> Self:
550
+ """Downselect met data within spatial bounding box.
551
+
552
+ Parameters
553
+ ----------
554
+ bbox : tuple[float, ...]
555
+ List of coordinates defining a spatial bounding box in WGS84 coordinates.
556
+ For 2D queries, list is [west, south, east, north].
557
+ For 3D queries, list is [west, south, min-level, east, north, max-level]
558
+ with level defined in [:math:`hPa`].
559
+
560
+ Returns
561
+ -------
562
+ Self
563
+ Return downselected data
564
+ """
565
+ data = downselect(self.data, bbox)
566
+ return type(self)._from_fastpath(data, cachestore=self.cachestore)
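
For example, restricting a pressure-level skeleton to a North Atlantic box between 200 and 300 hPa might look like this (a sketch; the bounds are illustrative and the dataset is built with :meth:`from_coords`, defined later in this module):

>>> import numpy as np
>>> from pycontrails import MetDataset
>>> met = MetDataset.from_coords(
...     longitude=np.arange(-180.0, 180.0, 1.0),
...     latitude=np.arange(-90.0, 91.0, 1.0),
...     level=[200.0, 250.0, 300.0, 350.0],
...     time=np.datetime64("2022-03-01"),
... )
>>> # west, south, min-level, east, north, max-level
>>> regional = met.downselect((-60.0, 30.0, 200.0, 0.0, 60.0, 300.0))
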
567
+
568
+ @property
569
+ def is_zarr(self) -> bool:
570
+ """Check if underlying :attr:`data` is sourced from a Zarr group.
571
+
572
+ Implementation is very brittle, and may break as external libraries change.
573
+
574
+ Some ``dask`` intermediate artifact is cached when this is called. Typically,
575
+ subsequent calls to this method are much faster than the initial call.
576
+
577
+ .. versionadded:: 0.26.0
578
+
579
+ Returns
580
+ -------
581
+ bool
582
+ If ``data`` is based on a Zarr group.
583
+ """
584
+ return _is_zarr(self.data)
585
+
586
+ def downselect_met(
587
+ self,
588
+ met: MetDataType,
589
+ *,
590
+ longitude_buffer: tuple[float, float] = (0.0, 0.0),
591
+ latitude_buffer: tuple[float, float] = (0.0, 0.0),
592
+ level_buffer: tuple[float, float] = (0.0, 0.0),
593
+ time_buffer: tuple[np.timedelta64, np.timedelta64] = (
594
+ np.timedelta64(0, "h"),
595
+ np.timedelta64(0, "h"),
596
+ ),
597
+ ) -> MetDataType:
598
+ """Downselect ``met`` to encompass a spatiotemporal region of the data.
599
+
600
+ .. warning::
601
+
602
+ This method is analogous to :meth:`GeoVectorDataset.downselect_met`.
603
+ It does not change the instance data, but instead operates on the
604
+ ``met`` input. This method is different from :meth:`downselect` which
605
+ operates on the instance data.
606
+
607
+ .. versionchanged:: 0.54.5
608
+
609
+ Data is no longer copied when downselecting.
610
+
611
+ Parameters
612
+ ----------
613
+ met : MetDataset | MetDataArray
614
+ MetDataset or MetDataArray to downselect.
615
+ longitude_buffer : tuple[float, float], optional
616
+ Extend longitude domain past by ``longitude_buffer[0]`` on the low side
617
+ and ``longitude_buffer[1]`` on the high side.
618
+ Units must be the same as class coordinates.
619
+ Defaults to ``(0, 0)`` degrees.
620
+ latitude_buffer : tuple[float, float], optional
621
+ Extend latitude domain past by ``latitude_buffer[0]`` on the low side
622
+ and ``latitude_buffer[1]`` on the high side.
623
+ Units must be the same as class coordinates.
624
+ Defaults to ``(0, 0)`` degrees.
625
+ level_buffer : tuple[float, float], optional
626
+ Extend level domain past by ``level_buffer[0]`` on the low side
627
+ and ``level_buffer[1]`` on the high side.
628
+ Units must be the same as class coordinates.
629
+ Defaults to ``(0, 0)`` [:math:`hPa`].
630
+ time_buffer : tuple[np.timedelta64, np.timedelta64], optional
631
+ Extend time domain past by ``time_buffer[0]`` on the low side
632
+ and ``time_buffer[1]`` on the high side.
633
+ Units must be the same as class coordinates.
634
+ Defaults to ``(np.timedelta64(0, "h"), np.timedelta64(0, "h"))``.
635
+
636
+ Returns
637
+ -------
638
+ MetDataset | MetDataArray
639
+ Copy of downselected MetDataset or MetDataArray.
640
+ """
641
+ indexes = self.indexes
642
+ lon = indexes["longitude"].to_numpy()
643
+ lat = indexes["latitude"].to_numpy()
644
+ level = indexes["level"].to_numpy()
645
+ time = indexes["time"].to_numpy()
646
+
647
+ vector = vector_module.GeoVectorDataset(
648
+ longitude=[lon.min(), lon.max()],
649
+ latitude=[lat.min(), lat.max()],
650
+ level=[level.min(), level.max()],
651
+ time=[time.min(), time.max()],
652
+ )
653
+
654
+ return vector.downselect_met(
655
+ met,
656
+ longitude_buffer=longitude_buffer,
657
+ latitude_buffer=latitude_buffer,
658
+ level_buffer=level_buffer,
659
+ time_buffer=time_buffer,
660
+ )
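
As a usage sketch (the buffer values are illustrative), a small regional dataset can trim a larger global one to its own extent plus a margin:

>>> import numpy as np
>>> from pycontrails import MetDataset
>>> global_met = MetDataset.from_coords(
...     longitude=np.arange(-180.0, 180.0, 1.0),
...     latitude=np.arange(-90.0, 91.0, 1.0),
...     level=[250.0, 300.0],
...     time=np.datetime64("2022-03-01"),
... )
>>> regional = MetDataset.from_coords(
...     longitude=np.arange(-10.0, 10.0, 1.0),
...     latitude=np.arange(40.0, 60.0, 1.0),
...     level=[250.0, 300.0],
...     time=np.datetime64("2022-03-01"),
... )
>>> trimmed = regional.downselect_met(
...     global_met,
...     longitude_buffer=(2.0, 2.0),
...     latitude_buffer=(2.0, 2.0),
... )
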
661
+
662
+ def wrap_longitude(self) -> Self:
663
+ """Wrap longitude coordinates.
664
+
665
+ Returns
666
+ -------
667
+ Self
668
+ Copy of instance with wrapped longitude values.
669
+ Returns copy of data when longitude values are already wrapped
670
+ """
671
+ return type(self)._from_fastpath(_wrap_longitude(self.data), cachestore=self.cachestore)
672
+
673
+ def copy(self) -> Self:
674
+ """Create a shallow copy of the current class.
675
+
676
+ See :meth:`xarray.Dataset.copy` for reference.
677
+
678
+ Returns
679
+ -------
680
+ Self
681
+ Copy of the current class
682
+ """
683
+ return type(self)._from_fastpath(self.data.copy(), cachestore=self.cachestore)
684
+
685
+
686
+ class MetDataset(MetBase):
687
+ """Meteorological dataset with multiple variables.
688
+
689
+ Composition around xr.Dataset to enforce certain
690
+ variables and dimensions for internal usage
691
+
692
+ Parameters
693
+ ----------
694
+ data : xr.Dataset
695
+ :class:`xarray.Dataset` containing meteorological variables and coordinates
696
+ cachestore : :class:`CacheStore`, optional
697
+ Cache datastore for staging intermediates with :meth:`save`.
698
+ Defaults to None.
699
+ wrap_longitude : bool, optional
700
+ Wrap data along the longitude dimension. If True, duplicate and shift longitude
701
+ values (i.e., -180 -> 180) to ensure that the longitude dimension covers the entire
702
+ interval ``[-180, 180]``. Defaults to False.
703
+ copy : bool, optional
704
+ Copy data on construction. Defaults to True.
705
+ attrs : dict[str, Any], optional
706
+ Attributes to add to :attr:`data.attrs`. Defaults to None.
707
+ Generally, pycontrails :class:`Models` may use the following attributes:
708
+
709
+ - ``provider``: Name of the data provider (e.g. "ECMWF").
710
+ - ``dataset``: Name of the dataset (e.g. "ERA5").
711
+ - ``product``: Name of the product type (e.g. "reanalysis").
712
+
713
+ **attrs_kwargs : Any
714
+ Keyword arguments to add to :attr:`data.attrs`. Defaults to None.
715
+
716
+ Examples
717
+ --------
718
+ >>> import numpy as np
719
+ >>> import pandas as pd
720
+ >>> import xarray as xr
721
+ >>> from pycontrails.datalib.ecmwf import ERA5
722
+
723
+ >>> time = ("2022-03-01T00", "2022-03-01T02")
724
+ >>> variables = ["air_temperature", "specific_humidity"]
725
+ >>> pressure_levels = [200, 250, 300]
726
+ >>> era5 = ERA5(time, variables, pressure_levels)
727
+
728
+ >>> # Open directly as `MetDataset`
729
+ >>> met = era5.open_metdataset()
730
+ >>> # Use `data` attribute to access `xarray` object
731
+ >>> assert isinstance(met.data, xr.Dataset)
732
+
733
+ >>> # Alternatively, open with `xarray` and cast to `MetDataset`
734
+ >>> ds = xr.open_mfdataset(era5._cachepaths)
735
+ >>> met = MetDataset(ds)
736
+
737
+ >>> # Access sub-`DataArrays`
738
+ >>> mda = met["t"] # `MetDataArray` instance, needed for interpolation operations
739
+ >>> da = mda.data # Underlying `xarray` object
740
+
741
+ >>> # Check out a few values
742
+ >>> da[5:8, 5:8, 1, 1].values
743
+ array([[224.08959005, 224.41374427, 224.75945349],
744
+ [224.09456429, 224.42037658, 224.76525676],
745
+ [224.10036756, 224.42617985, 224.77106004]])
746
+
747
+ >>> # Mean temperature over entire array
748
+ >>> da.mean().load().item()
749
+ 223.5083
750
+ """
751
+
752
+ __slots__ = ()
753
+
754
+ data: xr.Dataset
755
+
756
+ def __init__(
757
+ self,
758
+ data: xr.Dataset,
759
+ cachestore: CacheStore | None = None,
760
+ wrap_longitude: bool = False,
761
+ copy: bool = True,
762
+ attrs: dict[str, Any] | None = None,
763
+ **attrs_kwargs: Any,
764
+ ) -> None:
765
+ self.cachestore = cachestore
766
+
767
+ data.attrs.update(attrs or {}, **attrs_kwargs)
768
+
769
+ # ensure input is an xarray Dataset before copying into data
770
+ if not isinstance(data, xr.Dataset):
771
+ raise TypeError("Input 'data' must be an xarray Dataset")
772
+
773
+ # copy Dataset into data
774
+ if copy:
775
+ self.data = data.copy()
776
+ self._preprocess_dims(wrap_longitude)
777
+
778
+ else:
779
+ if wrap_longitude:
780
+ raise ValueError("Set 'copy=True' when using 'wrap_longitude=True'.")
781
+ self.data = data
782
+ self._validate_dims()
783
+ if not self.is_single_level:
784
+ self.data = _add_vertical_coords(self.data)
785
+
786
+ def __getitem__(self, key: Hashable) -> MetDataArray:
787
+ """Return DataArray of variable ``key`` cast to a :class:`MetDataArray` object.
788
+
789
+ Parameters
790
+ ----------
791
+ key : Hashable
792
+ Variable name
793
+
794
+ Returns
795
+ -------
796
+ MetDataArray
797
+ MetDataArray instance associated to :attr:`data` variable `key`
798
+
799
+ Raises
800
+ ------
801
+ KeyError
802
+ If ``key`` not found in :attr:`data`
803
+ """
804
+ try:
805
+ da = self.data[key]
806
+ except KeyError as e:
807
+ raise KeyError(
808
+ f"Variable {key} not found. Available variables: {', '.join(self.data.data_vars)}. "
809
+ "To get items (e.g. 'time' or 'level') from underlying xr.Dataset object, "
810
+ "use the 'data' attribute."
811
+ ) from e
812
+ return MetDataArray._from_fastpath(da)
813
+
814
+ def get(self, key: str, default_value: Any = None) -> Any:
815
+ """Shortcut to :meth:`data.get(k, v)` method.
816
+
817
+ Parameters
818
+ ----------
819
+ key : str
820
+ Key to get from :attr:`data`
821
+ default_value : Any, optional
822
+ Return `default_value` if `key` not in :attr:`data`, by default `None`
823
+
824
+ Returns
825
+ -------
826
+ Any
827
+ Values returned from :attr:`data.get(key, default_value)`
828
+ """
829
+ return self.data.get(key, default_value)
830
+
831
+ def __setitem__(
832
+ self,
833
+ key: Hashable | list[Hashable] | Mapping,
834
+ value: Any,
835
+ ) -> None:
836
+ """Shortcut to set data variable on :attr:`data`.
837
+
838
+ Warns if ``key`` is already present in dataset.
839
+
840
+ Parameters
841
+ ----------
842
+ key : Hashable | list[Hashable] | Mapping
843
+ Variable name
844
+ value : Any
845
+ Value to set to variable names
846
+
847
+ See Also
848
+ --------
849
+ - :class:`xarray.Dataset.__setitem__`
850
+ """
851
+
852
+ # pull data of MetDataArray value
853
+ if isinstance(value, MetDataArray):
854
+ value = value.data
855
+
856
+ # warn if key is already in Dataset
857
+ override_keys: list[Hashable] = []
858
+
859
+ if isinstance(key, Hashable):
860
+ if key in self:
861
+ override_keys = [key]
862
+
863
+ # xarray.core.utils.is_dict_like
864
+ # https://github.com/pydata/xarray/blob/4cae8d0ec04195291b2315b1f21d846c2bad61ff/xarray/core/utils.py#L244
865
+ elif xr.core.utils.is_dict_like(key) or isinstance(key, list):
866
+ override_keys = [k for k in key if k in self.data]
867
+
868
+ if override_keys:
869
+ warnings.warn(
870
+ f"Overwriting data in keys `{override_keys}`. "
871
+ "Use `.update(...)` to suppress warning."
872
+ )
873
+
874
+ self.data.__setitem__(key, value)
875
+
876
+ def update(self, other: MutableMapping | None = None, **kwargs: Any) -> None:
877
+ """Shortcut to :meth:`data.update`.
878
+
879
+ See :meth:`xarray.Dataset.update` for reference.
880
+
881
+ Parameters
882
+ ----------
883
+ other : MutableMapping
884
+ Variables with which to update this dataset
885
+ **kwargs : Any
886
+ Variables defined by keyword arguments. If a variable exists both in
887
+ ``other`` and as a keyword argument, the keyword argument takes
888
+ precedence.
889
+
890
+ See Also
891
+ --------
892
+ - :meth:`xarray.Dataset.update`
893
+ """
894
+ other = other or {}
895
+ other.update(kwargs)
896
+
897
+ # pull data of MetDataArray value
898
+ for k, v in other.items():
899
+ if isinstance(v, MetDataArray):
900
+ other[k] = v.data
901
+
902
+ self.data.update(other)
903
+
904
+ def __iter__(self) -> Iterator[str]:
905
+ """Allow for the use as "key" in self.met, where "key" is a data variable."""
906
+ # From the typing perspective, `iter(self.data)`` returns Hashables (not
907
+ # necessarily strs). In everything we do, we use str variables.
908
+ # If we decide to extend this to support Hashable, we'll also want to
909
+ # change VectorDataset -- the underlying :attr:`data` should then be
910
+ # changed to `dict[Hashable, np.ndarray]`.
911
+ for key in self.data:
912
+ yield str(key)
913
+
914
+ def __contains__(self, key: Hashable) -> bool:
915
+ """Check if key ``key`` is in :attr:`data`.
916
+
917
+ Parameters
918
+ ----------
919
+ key : Hashable
920
+ Key to check
921
+
922
+ Returns
923
+ -------
924
+ bool
925
+ True if ``key`` is in :attr:`data`, False otherwise
926
+ """
927
+ return key in self.data
928
+
929
+ @property
930
+ @override
931
+ def shape(self) -> tuple[int, int, int, int]:
932
+ sizes = self.data.sizes
933
+ return sizes["longitude"], sizes["latitude"], sizes["level"], sizes["time"]
934
+
935
+ @property
936
+ @override
937
+ def size(self) -> int:
938
+ return np.prod(self.shape).item()
939
+
940
+ def ensure_vars(
941
+ self,
942
+ vars: MetVariable | str | Sequence[MetVariable | str | Sequence[MetVariable]],
943
+ raise_error: bool = True,
944
+ ) -> list[str]:
945
+ """Ensure variables exist in xr.Dataset.
946
+
947
+ Parameters
948
+ ----------
949
+ vars : MetVariable | str | Sequence[MetVariable | str | list[MetVariable]]
950
+ List of MetVariable (or string key), or individual MetVariable (or string key).
951
+ If ``vars`` contains an element with a list[MetVariable], then
952
+ only one variable in the list must be present in dataset.
953
+ raise_error : bool, optional
954
+ Raise KeyError if data does not contain variables.
955
+ Defaults to True.
956
+
957
+ Returns
958
+ -------
959
+ list[str]
960
+ List of met keys verified in MetDataset.
961
+ Returns an empty list if ``raise_error`` is False and any variable is missing.
962
+
963
+ Raises
964
+ ------
965
+ KeyError
966
+ Raises when dataset does not contain variable in ``vars``
967
+ """
968
+ if isinstance(vars, MetVariable | str):
969
+ vars = (vars,)
970
+
971
+ met_keys: list[str] = []
972
+ for variable in vars:
973
+ met_key: str | None = None
974
+
975
+ # input is a MetVariable or str
976
+ if isinstance(variable, MetVariable):
977
+ if (key := variable.standard_name) in self:
978
+ met_key = key
979
+ elif isinstance(variable, str):
980
+ if variable in self:
981
+ met_key = variable
982
+
983
+ # otherwise, assume input is a sequence
984
+ # Sequence[MetVariable] means that any variable in list will work
985
+ else:
986
+ for v in variable:
987
+ if (key := v.standard_name) in self:
988
+ met_key = key
989
+ break
990
+
991
+ if met_key is None:
992
+ if not raise_error:
993
+ return []
994
+
995
+ if isinstance(variable, MetVariable):
996
+ raise KeyError(f"Dataset does not contain variable `{variable.standard_name}`")
997
+ if isinstance(variable, list):
998
+ raise KeyError(
999
+ "Dataset does not contain one of variables "
1000
+ f"`{[v.standard_name for v in variable]}`"
1001
+ )
1002
+ raise KeyError(f"Dataset does not contain variable `{variable}`")
1003
+
1004
+ met_keys.append(met_key)
1005
+
1006
+ return met_keys
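
For example, requiring air temperature together with either humidity variable might look like this (a sketch; assumes the MetVariable definitions in :mod:`pycontrails.core.met_var` and a dataset ``met`` holding the corresponding variables):

>>> from pycontrails.core.met_var import AirTemperature, RelativeHumidity, SpecificHumidity
>>> keys = met.ensure_vars(
...     [AirTemperature, [SpecificHumidity, RelativeHumidity]],
...     raise_error=False,
... )
>>> # keys lists the verified names when all requirements are met,
>>> # and is [] when any requirement is missing (with raise_error=True, a KeyError is raised instead)
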
1007
+
1008
+ def save(self, **kwargs: Any) -> list[str]:
1009
+ """Save intermediate to :attr:`cachestore` as netcdf.
1010
+
1011
+ Load and restore using :meth:`load`.
1012
+
1013
+ Parameters
1014
+ ----------
1015
+ **kwargs : Any
1016
+ Keyword arguments passed directly to :meth:`xarray.Dataset.to_netcdf`
1017
+
1018
+ Returns
1019
+ -------
1020
+ list[str]
1021
+ Returns filenames saved
1022
+ """
1023
+ return self._save(self.data, **kwargs)
1024
+
1025
+ @classmethod
1026
+ def load(
1027
+ cls,
1028
+ hash: str,
1029
+ cachestore: CacheStore | None = None,
1030
+ chunks: dict[str, int] | None = None,
1031
+ ) -> Self:
1032
+ """Load saved intermediate from :attr:`cachestore`.
1033
+
1034
+ Parameters
1035
+ ----------
1036
+ hash : str
1037
+ Saved hash to load.
1038
+ cachestore : :class:`CacheStore`, optional
1039
+ Cache datastore to use for sourcing files.
1040
+ Defaults to DiskCacheStore.
1041
+ chunks : dict[str, int], optional
1042
+ Chunks kwarg passed to :func:`xarray.open_mfdataset()` when opening files.
1043
+
1044
+ Returns
1045
+ -------
1046
+ Self
1047
+ New MetDataset with loaded data.
1048
+ """
1049
+ cachestore = cachestore or DiskCacheStore()
1050
+ chunks = chunks or {}
1051
+ data = _load(hash, cachestore, chunks)
1052
+ return cls(data)
1053
+
1054
+ @override
1055
+ def broadcast_coords(self, name: str) -> xr.DataArray:
1056
+ da = xr.ones_like(self.data[next(iter(self.data.keys()))]) * self.data[name]
1057
+ da.name = name
1058
+
1059
+ return da
1060
+
1061
+ def to_vector(self, transfer_attrs: bool = True) -> vector_module.GeoVectorDataset:
1062
+ """Convert a :class:`MetDataset` to a :class:`GeoVectorDataset` by raveling data.
1063
+
1064
+ If :attr:`data` is lazy, it will be loaded.
1065
+
1066
+ Parameters
1067
+ ----------
1068
+ transfer_attrs : bool, optional
1069
+ Transfer attributes from :attr:`data` to output :class:`GeoVectorDataset`.
1070
+ By default, True, meaning that attributes are transferred.
1071
+
1072
+ Returns
1073
+ -------
1074
+ GeoVectorDataset
1075
+ Converted :class:`GeoVectorDataset`. The variables on the returned instance
1076
+ include all of those on the input instance, plus the four core spatiotemporal
1077
+ variables.
1078
+
1079
+ Examples
1080
+ --------
1081
+ >>> from pycontrails.datalib.ecmwf import ERA5
1082
+ >>> times = "2022-03-01", "2022-03-01T01"
1083
+ >>> variables = ["air_temperature", "specific_humidity"]
1084
+ >>> levels = [250, 200]
1085
+ >>> era5 = ERA5(time=times, variables=variables, pressure_levels=levels)
1086
+ >>> met = era5.open_metdataset()
1087
+ >>> met.to_vector(transfer_attrs=False)
1088
+ GeoVectorDataset [6 keys x 4152960 length, 0 attributes]
1089
+ Keys: longitude, latitude, level, time, air_temperature, ..., specific_humidity
1090
+ Attributes:
1091
+ time [2022-03-01 00:00:00, 2022-03-01 01:00:00]
1092
+ longitude [-180.0, 179.75]
1093
+ latitude [-90.0, 90.0]
1094
+ altitude [10362.8, 11783.9]
1095
+
1096
+ """
1097
+ coords_keys = self.data.dims
1098
+ indexes = self.indexes
1099
+ coords_vals = [indexes[key].values for key in coords_keys]
1100
+ coords_meshes = np.meshgrid(*coords_vals, indexing="ij")
1101
+ raveled_coords = (mesh.ravel() for mesh in coords_meshes)
1102
+ data = dict(zip(coords_keys, raveled_coords, strict=True))
1103
+
1104
+ out = vector_module.GeoVectorDataset(data, copy=False)
1105
+ for key, da in self.data.items():
1106
+ # The call to .values here will load the data if it is lazy
1107
+ out[key] = da.values.ravel() # type: ignore[index]
1108
+
1109
+ if transfer_attrs:
1110
+ out.attrs.update(self.attrs)
1111
+
1112
+ return out
1113
+
1114
+ def _get_pycontrails_attr_template(
1115
+ self,
1116
+ name: str,
1117
+ supported: tuple[str, ...],
1118
+ examples: dict[str, str],
1119
+ ) -> str:
1120
+ """Look up an attribute with a custom error message."""
1121
+ try:
1122
+ out = self.attrs[name]
1123
+ except KeyError as e:
1124
+ msg = f"Specify '{name}' attribute on underlying dataset."
1125
+
1126
+ for i, (k, v) in enumerate(examples.items()):
1127
+ if i == 0:
1128
+ msg = f"{msg} For example, set attrs['{name}'] = '{k}' for {v}."
1129
+ else:
1130
+ msg = f"{msg} Set attrs['{name}'] = '{k}' for {v}."
1131
+ raise KeyError(msg) from e
1132
+
1133
+ if out not in supported:
1134
+ warnings.warn(
1135
+ f"Unknown {name} '{out}'. Data may not be processed correctly. "
1136
+ f"Known {name}s are {supported}. Contact the pycontrails "
1137
+ "developers if you believe this is an error."
1138
+ )
1139
+
1140
+ return out
1141
+
1142
+ @property
1143
+ def provider_attr(self) -> str:
1144
+ """Look up the 'provider' attribute with a custom error message.
1145
+
1146
+ Returns
1147
+ -------
1148
+ str
1149
+ Provider of the data. If not one of 'ECMWF' or 'NCEP',
1150
+ a warning is issued.
1151
+ """
1152
+ supported = ("ECMWF", "NCEP")
1153
+ examples = {"ECMWF": "data provided by ECMWF", "NCEP": "GFS data"}
1154
+ return self._get_pycontrails_attr_template("provider", supported, examples)
1155
+
1156
+ @property
1157
+ def dataset_attr(self) -> str:
1158
+ """Look up the 'dataset' attribute with a custom error message.
1159
+
1160
+ Returns
1161
+ -------
1162
+ str
1163
+ Dataset of the data. If not one of 'ERA5', 'HRES', 'IFS',
1164
+ or 'GFS', a warning is issued.
1165
+ """
1166
+ supported = ("ERA5", "HRES", "IFS", "GFS")
1167
+ examples = {
1168
+ "ERA5": "ECMWF ERA5 reanalysis data",
1169
+ "HRES": "ECMWF HRES forecast data",
1170
+ "GFS": "NCEP GFS forecast data",
1171
+ }
1172
+ return self._get_pycontrails_attr_template("dataset", supported, examples)
1173
+
1174
+ @property
1175
+ def product_attr(self) -> str:
1176
+ """Look up the 'product' attribute with a custom error message.
1177
+
1178
+ Returns
1179
+ -------
1180
+ str
1181
+ Product of the data. If not one of 'forecast', 'ensemble', or 'reanalysis',
1182
+ a warning is issued.
1183
+
1184
+ """
1185
+ supported = ("reanalysis", "forecast", "ensemble")
1186
+ examples = {
1187
+ "reanalysis": "ECMWF ERA5 reanalysis data",
1188
+ "ensemble": "ECMWF ERA5 ensemble member data",
1189
+ }
1190
+ return self._get_pycontrails_attr_template("product", supported, examples)
1191
+
1192
+ @overload
1193
+ def standardize_variables(
1194
+ self, variables: Iterable[MetVariable], inplace: Literal[False] = ...
1195
+ ) -> Self: ...
1196
+
1197
+ @overload
1198
+ def standardize_variables(
1199
+ self, variables: Iterable[MetVariable], inplace: Literal[True]
1200
+ ) -> None: ...
1201
+
1202
+ def standardize_variables(
1203
+ self, variables: Iterable[MetVariable], inplace: bool = False
1204
+ ) -> Self | None:
1205
+ """Standardize variable names.
1206
+
1207
+ .. versionchanged:: 0.54.7
1208
+
1209
+ By default, this method returns a new :class:`MetDataset` instead
1210
+ of renaming in place. To retain the old behavior, set ``inplace=True``.
1211
+
1212
+ Parameters
1213
+ ----------
1214
+ variables : Iterable[MetVariable]
1215
+ Data source variables
1216
+ inplace : bool, optional
1217
+ If True, rename variables in place. Otherwise, return a new
1218
+ :class:`MetDataset` with renamed variables.
1219
+
1220
+ See Also
1221
+ --------
1222
+ :func:`standardize_variables`
1223
+ """
1224
+ data_renamed = standardize_variables(self.data, variables)
1225
+
1226
+ if inplace:
1227
+ self.data = data_renamed
1228
+ return None
1229
+
1230
+ return type(self)._from_fastpath(data_renamed, cachestore=self.cachestore)
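
For example, renaming raw ECMWF short names to pycontrails standard names (a sketch; assumes ``met`` contains variables matching the given MetVariable definitions):

>>> from pycontrails.core.met_var import AirTemperature, SpecificHumidity
>>> standardized = met.standardize_variables([AirTemperature, SpecificHumidity])  # returns a new MetDataset
>>> met.standardize_variables([AirTemperature, SpecificHumidity], inplace=True)  # renames in place, returns None
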
1231
+
1232
+ @classmethod
1233
+ def from_coords(
1234
+ cls,
1235
+ longitude: npt.ArrayLike | float,
1236
+ latitude: npt.ArrayLike | float,
1237
+ level: npt.ArrayLike | float,
1238
+ time: npt.ArrayLike | np.datetime64,
1239
+ ) -> Self:
1240
+ r"""Create a :class:`MetDataset` containing a coordinate skeleton from coordinate arrays.
1241
+
1242
+ Parameters
1243
+ ----------
1244
+ longitude, latitude : npt.ArrayLike | float
1245
+ Horizontal coordinates, in [:math:`\deg`]
1246
+ level : npt.ArrayLike | float
1247
+ Vertical coordinate, in [:math:`hPa`]
1248
+ time : npt.ArrayLike | np.datetime64
1249
+ Temporal coordinates, in [:math:`UTC`]. Will be sorted.
1250
+
1251
+ Returns
1252
+ -------
1253
+ Self
1254
+ MetDataset with no variables.
1255
+
1256
+ Examples
1257
+ --------
1258
+ >>> # Create skeleton MetDataset
1259
+ >>> longitude = np.arange(0, 10, 0.5)
1260
+ >>> latitude = np.arange(0, 10, 0.5)
1261
+ >>> level = [250, 300]
1262
+ >>> time = np.datetime64("2019-01-01")
1263
+ >>> met = MetDataset.from_coords(longitude, latitude, level, time)
1264
+ >>> met
1265
+ MetDataset with data:
1266
+ <xarray.Dataset> Size: 360B
1267
+ Dimensions: (longitude: 20, latitude: 20, level: 2, time: 1)
1268
+ Coordinates:
1269
+ * longitude (longitude) float64 160B 0.0 0.5 1.0 1.5 ... 8.0 8.5 9.0 9.5
1270
+ * latitude (latitude) float64 160B 0.0 0.5 1.0 1.5 ... 8.0 8.5 9.0 9.5
1271
+ * level (level) float64 16B 250.0 300.0
1272
+ * time (time) datetime64[ns] 8B 2019-01-01
1273
+ air_pressure (level) float32 8B 2.5e+04 3e+04
1274
+ altitude (level) float32 8B 1.036e+04 9.164e+03
1275
+ Data variables:
1276
+ *empty*
1277
+
1278
+ >>> met.shape
1279
+ (20, 20, 2, 1)
1280
+
1281
+ >>> met.size
1282
+ 800
1283
+
1284
+ >>> # Fill it up with some constant data
1285
+ >>> met["temperature"] = xr.DataArray(np.full(met.shape, 234.5), coords=met.coords)
1286
+ >>> met["humidity"] = xr.DataArray(np.full(met.shape, 0.5), coords=met.coords)
1287
+ >>> met
1288
+ MetDataset with data:
1289
+ <xarray.Dataset> Size: 13kB
1290
+ Dimensions: (longitude: 20, latitude: 20, level: 2, time: 1)
1291
+ Coordinates:
1292
+ * longitude (longitude) float64 160B 0.0 0.5 1.0 1.5 ... 8.0 8.5 9.0 9.5
1293
+ * latitude (latitude) float64 160B 0.0 0.5 1.0 1.5 ... 8.0 8.5 9.0 9.5
1294
+ * level (level) float64 16B 250.0 300.0
1295
+ * time (time) datetime64[ns] 8B 2019-01-01
1296
+ air_pressure (level) float32 8B 2.5e+04 3e+04
1297
+ altitude (level) float32 8B 1.036e+04 9.164e+03
1298
+ Data variables:
1299
+ temperature (longitude, latitude, level, time) float64 6kB 234.5 ... 234.5
1300
+ humidity (longitude, latitude, level, time) float64 6kB 0.5 0.5 ... 0.5
1301
+
1302
+ >>> # Convert to a GeoVectorDataset
1303
+ >>> vector = met.to_vector()
1304
+ >>> vector.dataframe.head()
1305
+ longitude latitude level time temperature humidity
1306
+ 0 0.0 0.0 250.0 2019-01-01 234.5 0.5
1307
+ 1 0.0 0.0 300.0 2019-01-01 234.5 0.5
1308
+ 2 0.0 0.5 250.0 2019-01-01 234.5 0.5
1309
+ 3 0.0 0.5 300.0 2019-01-01 234.5 0.5
1310
+ 4 0.0 1.0 250.0 2019-01-01 234.5 0.5
1311
+ """
1312
+ input_data = {
1313
+ "longitude": longitude,
1314
+ "latitude": latitude,
1315
+ "level": level,
1316
+ "time": time,
1317
+ }
1318
+
1319
+ # clean up input into coords
1320
+ coords: dict[str, np.ndarray] = {}
1321
+ for key, val in input_data.items():
1322
+ dtype = "datetime64[ns]" if key == "time" else COORD_DTYPE
1323
+ arr: np.ndarray = np.asarray(val, dtype=dtype)
1324
+
1325
+ if arr.ndim == 0:
1326
+ arr = arr.reshape(1)
1327
+ elif arr.ndim > 1:
1328
+ raise ValueError(f"{key} has too many dimensions")
1329
+
1330
+ arr = np.sort(arr)
1331
+ if arr.size == 0:
1332
+ raise ValueError(f"Coordinate {key} must be nonempty.")
1333
+
1334
+ coords[key] = arr
1335
+
1336
+ return cls(xr.Dataset({}, coords=coords))
1337
+
1338
+ @classmethod
1339
+ def from_zarr(cls, store: Any, **kwargs: Any) -> Self:
1340
+ """Create a :class:`MetDataset` from a path to a Zarr store.
1341
+
1342
+ Parameters
1343
+ ----------
1344
+ store : Any
1345
+ Path to Zarr store. Passed into :func:`xarray.open_zarr`.
1346
+ **kwargs : Any
1347
+ Other keyword only arguments passed into :func:`xarray.open_zarr`.
1348
+
1349
+ Returns
1350
+ -------
1351
+ Self
1352
+ MetDataset with data from Zarr store.
1353
+ """
1354
+ kwargs.setdefault("storage_options", {"read_only": True})
1355
+ ds = xr.open_zarr(store, **kwargs)
1356
+ return cls(ds)
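
For example, opening a hypothetical Zarr store (the path and chunking below are illustrative, not a real dataset):

>>> met = MetDataset.from_zarr(
...     "gs://my-bucket/era5-model-levels.zarr",  # hypothetical store path
...     chunks={"time": 1},
... )
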
1357
+
1358
+
1359
+ class MetDataArray(MetBase):
1360
+ """Meteorological DataArray of single variable.
1361
+
1362
+ Wrapper around :class:`xarray.DataArray` to enforce certain
1363
+ variables and dimensions for internal usage.
1364
+
1365
+ .. versionchanged:: 0.54.4
1366
+
1367
+ Remove ``validate`` parameter. Validation is now always performed.
1368
+
1369
+ Parameters
1370
+ ----------
1371
+ data : ArrayLike
1372
+ xr.DataArray or other array-like data source.
1373
+ When array-like input is provided, input ``**kwargs`` passed directly to
1374
+ xr.DataArray constructor.
1375
+ cachestore : :class:`CacheStore`, optional
1376
+ Cache datastore for staging intermediates with :meth:`save`.
1377
+ Defaults to DiskCacheStore.
1378
+ wrap_longitude : bool, optional
1379
+ Wrap data along the longitude dimension. If True, duplicate and shift longitude
1380
+ values (i.e., -180 -> 180) to ensure that the longitude dimension covers the entire
1381
+ interval ``[-180, 180]``. Defaults to False.
1382
+ copy : bool, optional
1383
+ Copy `data` parameter on construction, by default `True`. If `data` is lazy-loaded
1384
+ via `dask`, this parameter has no effect. If `data` is already loaded into memory,
1385
+ a copy of the data (rather than a view) may be created if `True`.
1386
+ name : Hashable, optional
1387
+ Name of the data variable. If not specified, the name will be set to "met".
1388
+
1389
+ Examples
1390
+ --------
1391
+ >>> import numpy as np
1392
+ >>> import xarray as xr
1393
+ >>> rng = np.random.default_rng(seed=456)
1394
+
1395
+ >>> # Cook up random xarray object
1396
+ >>> coords = {
1397
+ ... "longitude": np.arange(-20, 20),
1398
+ ... "latitude": np.arange(-30, 30),
1399
+ ... "level": [220, 240, 260, 280],
1400
+ ... "time": [np.datetime64("2021-08-01T12", "ns"), np.datetime64("2021-08-01T16", "ns")]
1401
+ ... }
1402
+ >>> da = xr.DataArray(rng.random((40, 60, 4, 2)), dims=coords.keys(), coords=coords)
1403
+
1404
+ >>> # Cast to `MetDataArray` in order to interpolate
1405
+ >>> from pycontrails import MetDataArray
1406
+ >>> mda = MetDataArray(da)
1407
+ >>> mda.interpolate(-11.4, 5.7, 234, np.datetime64("2021-08-01T13"))
1408
+ array([0.52358215])
1409
+
1410
+ >>> mda.interpolate(-11.4, 5.7, 234, np.datetime64("2021-08-01T13"), method='nearest')
1411
+ array([0.4188465])
1412
+
1413
+ >>> da.sel(longitude=-11, latitude=6, level=240, time=np.datetime64("2021-08-01T12")).item()
1414
+ 0.41884649899766946
1415
+ """
1416
+
1417
+ __slots__ = ()
1418
+
1419
+ data: xr.DataArray
1420
+
1421
+ def __init__(
1422
+ self,
1423
+ data: xr.DataArray,
1424
+ cachestore: CacheStore | None = None,
1425
+ wrap_longitude: bool = False,
1426
+ copy: bool = True,
1427
+ name: Hashable | None = None,
1428
+ ) -> None:
1429
+ self.cachestore = cachestore
1430
+
1431
+ if copy:
1432
+ self.data = data.copy()
1433
+ self._preprocess_dims(wrap_longitude)
1434
+ elif wrap_longitude:
1435
+ raise ValueError("Set 'copy=True' when using 'wrap_longitude=True'.")
1436
+ else:
1437
+ self.data = data
1438
+ self._validate_dims()
1439
+
1440
+ # Priority: name > data.name > "met"
1441
+ self.data.name = name or self.data.name or "met"
1442
+
1443
+ @property
1444
+ def values(self) -> np.ndarray:
1445
+ """Return underlying numpy array.
1446
+
1447
+ This method loads :attr:`data` if it is not already in memory.
1448
+
1449
+ Returns
1450
+ -------
1451
+ np.ndarray
1452
+ Underlying numpy array
1453
+
1454
+ See Also
1455
+ --------
1456
+ :meth:`xarray.Dataset.load`
1457
+ :meth:`xarray.DataArray.load`
1458
+
1459
+ """
1460
+ if not self.in_memory:
1461
+ self._check_memory("Extracting numpy array from")
1462
+ self.data.load()
1463
+
1464
+ return self.data.values
1465
+
1466
+ @property
1467
+ def name(self) -> Hashable:
1468
+ """Return the DataArray name.
1469
+
1470
+ Returns
1471
+ -------
1472
+ Hashable
1473
+ DataArray name
1474
+ """
1475
+ return self.data.name
1476
+
1477
+ @property
1478
+ def binary(self) -> bool:
1479
+ """Determine if all data is a binary value (0, 1).
1480
+
1481
+ Returns
1482
+ -------
1483
+ bool
1484
+ True if all data values are binary value (0, 1)
1485
+ """
1486
+ return np.array_equal(self.data, self.data.astype(bool))
1487
+
1488
+ @property
1489
+ @override
1490
+ def size(self) -> int:
1491
+ return self.data.size
1492
+
1493
+ @property
1494
+ @override
1495
+ def shape(self) -> tuple[int, int, int, int]:
1496
+ # https://github.com/python/mypy/issues/1178
1497
+ return typing.cast(tuple[int, int, int, int], self.data.shape)
1498
+
1499
+ @property
1500
+ def in_memory(self) -> bool:
1501
+ """Check if underlying :attr:`data` is loaded into memory.
1502
+
1503
+ This method uses protected attributes of underlying `xarray` objects, and may be subject
1504
+ to deprecation.
1505
+
1506
+ .. versionchanged:: 0.26.0
1507
+
1508
+ Rename from ``is_loaded`` to ``in_memory``.
1509
+
1510
+ Returns
1511
+ -------
1512
+ bool
1513
+ If underlying data exists as an `np.ndarray` in memory.
1514
+ """
1515
+ return self.data._in_memory
1516
+
1517
+ @overload
1518
+ def interpolate(
1519
+ self,
1520
+ longitude: float | npt.NDArray[np.floating],
1521
+ latitude: float | npt.NDArray[np.floating],
1522
+ level: float | npt.NDArray[np.floating],
1523
+ time: np.datetime64 | npt.NDArray[np.datetime64],
1524
+ *,
1525
+ method: str = ...,
1526
+ bounds_error: bool = ...,
1527
+ fill_value: float | np.float64 | None = ...,
1528
+ localize: bool = ...,
1529
+ lowmem: bool = ...,
1530
+ indices: interpolation.RGIArtifacts | None = ...,
1531
+ return_indices: Literal[False] = ...,
1532
+ ) -> npt.NDArray[np.floating]: ...
1533
+
1534
+ @overload
1535
+ def interpolate(
1536
+ self,
1537
+ longitude: float | npt.NDArray[np.floating],
1538
+ latitude: float | npt.NDArray[np.floating],
1539
+ level: float | npt.NDArray[np.floating],
1540
+ time: np.datetime64 | npt.NDArray[np.datetime64],
1541
+ *,
1542
+ method: str = ...,
1543
+ bounds_error: bool = ...,
1544
+ fill_value: float | np.float64 | None = ...,
1545
+ localize: bool = ...,
1546
+ lowmem: bool = ...,
1547
+ indices: interpolation.RGIArtifacts | None = ...,
1548
+ return_indices: Literal[True],
1549
+ ) -> tuple[npt.NDArray[np.floating], interpolation.RGIArtifacts]: ...
1550
+
1551
+ def interpolate(
1552
+ self,
1553
+ longitude: float | npt.NDArray[np.floating],
1554
+ latitude: float | npt.NDArray[np.floating],
1555
+ level: float | npt.NDArray[np.floating],
1556
+ time: np.datetime64 | npt.NDArray[np.datetime64],
1557
+ *,
1558
+ method: str = "linear",
1559
+ bounds_error: bool = False,
1560
+ fill_value: float | np.float64 | None = np.nan,
1561
+ localize: bool = False,
1562
+ lowmem: bool = False,
1563
+ indices: interpolation.RGIArtifacts | None = None,
1564
+ return_indices: bool = False,
1565
+ ) -> npt.NDArray[np.floating] | tuple[npt.NDArray[np.floating], interpolation.RGIArtifacts]:
1566
+ """Interpolate values over underlying DataArray.
1567
+
1568
+ Zero dimensional coordinates are reshaped to 1D arrays.
1569
+
1570
+ If ``lowmem == False``, this method automatically loads the underlying :attr:`data` into
1571
+ memory. Otherwise, it iterates through smaller subsets of :attr:`data`, releasing each
1572
+ subset from memory once interpolation against it is finished.
1573
+
1574
+ If ``method == "nearest"``, the out array will have the same ``dtype`` as
1575
+ the underlying :attr:`data`.
1576
+
1577
+ If ``method == "linear"``, the out array will be promoted to the most
1578
+ precise ``dtype`` of:
1579
+
1580
+ - underlying :attr:`data`
1581
+ - :attr:`data.longitude`
1582
+ - :attr:`data.latitude`
1583
+ - :attr:`data.level`
1584
+ - ``longitude``
1585
+ - ``latitude``
1586
+
1587
+ .. versionadded:: 0.24
1588
+
1589
+ This method can now handle singleton dimensions with ``method == "linear"``.
1590
+ Previously these degenerate dimensions caused nan values to be returned.
1591
+
1592
+ Parameters
1593
+ ----------
1594
+ longitude : float | npt.NDArray[np.floating]
1595
+ Longitude values to interpolate. Assumed to be 0 or 1 dimensional.
1596
+ latitude : float | npt.NDArray[np.floating]
1597
+ Latitude values to interpolate. Assumed to be 0 or 1 dimensional.
1598
+ level : float | npt.NDArray[np.floating]
1599
+ Level values to interpolate. Assumed to be 0 or 1 dimensional.
1600
+ time : np.datetime64 | npt.NDArray[np.datetime64]
1601
+ Time values to interpolate. Assumed to be 0 or 1 dimensional.
1602
+ method: str, optional
1603
+ Additional keyword arguments to pass to
1604
+ :class:`scipy.interpolate.RegularGridInterpolator`.
1605
+ Defaults to "linear".
1606
+ bounds_error: bool, optional
1607
+ Additional keyword arguments to pass to
1608
+ :class:`scipy.interpolate.RegularGridInterpolator`.
1609
+ Defaults to ``False``.
1610
+ fill_value: float | np.float64, optional
1611
+ Additional keyword arguments to pass to
1612
+ :class:`scipy.interpolate.RegularGridInterpolator`.
1613
+ Set to None to extrapolate outside the boundary when ``method`` is ``nearest``.
1614
+ Defaults to ``np.nan``.
1615
+ localize: bool, optional
1616
+ Experimental. If True, downselect gridded data to smallest bounding box containing
1617
+ all points. By default False.
1618
+ lowmem: bool, optional
1619
+ Experimental. If True, iterate through points binned by the time coordinate of the
1620
+ gridded data, and downselect gridded data to the smallest bounding box containing
1621
+ each binned set of points *before loading into memory*. This can significantly reduce
1622
+ memory consumption with large numbers of points at the cost of increased runtime.
1623
+ By default False.
1624
+ indices: tuple | None, optional
1625
+ Experimental. See :func:`interpolation.interp`. None by default.
1626
+ return_indices: bool, optional
1627
+ Experimental. See :func:`interpolation.interp`. False by default.
1628
+ Note that values returned differ when ``lowmem=True`` and ``lowmem=False``,
1629
+ so output should only be re-used in calls with the same ``lowmem`` value.
1630
+
1631
+ Returns
1632
+ -------
1633
+ np.ndarray
1634
+ Interpolated values
1635
+
1636
+ See Also
1637
+ --------
1638
+ :meth:`GeoVectorDataset.intersect_met`
1639
+
1640
+ Examples
1641
+ --------
1642
+ >>> from datetime import datetime
1643
+ >>> import numpy as np
1644
+ >>> import pandas as pd
1645
+ >>> from pycontrails.datalib.ecmwf import ERA5
1646
+
1647
+ >>> times = (datetime(2022, 3, 1, 12), datetime(2022, 3, 1, 15))
1648
+ >>> variables = "air_temperature"
1649
+ >>> levels = [200, 250, 300]
1650
+ >>> era5 = ERA5(times, variables, levels)
1651
+ >>> met = era5.open_metdataset()
1652
+ >>> mda = met["air_temperature"]
1653
+
1654
+ >>> # Interpolation at a grid point agrees with value
1655
+ >>> mda.interpolate(1, 2, 300, np.datetime64('2022-03-01T14:00'))
1656
+ array([241.91972984])
1657
+
1658
+ >>> da = mda.data
1659
+ >>> da.sel(longitude=1, latitude=2, level=300, time=np.datetime64('2022-03-01T14')).item()
1660
+ 241.9197298421629
1661
+
1662
+ >>> # Interpolation off grid
1663
+ >>> mda.interpolate(1.1, 2.1, 290, np.datetime64('2022-03-01 13:10'))
1664
+ array([239.83793798])
1665
+
1666
+ >>> # Interpolate along path
1667
+ >>> longitude = np.linspace(1, 2, 10)
1668
+ >>> latitude = np.linspace(2, 3, 10)
1669
+ >>> level = np.linspace(200, 300, 10)
1670
+ >>> time = pd.date_range("2022-03-01T14", periods=10, freq="5min")
1671
+ >>> mda.interpolate(longitude, latitude, level, time)
1672
+ array([220.44347694, 223.08900738, 225.74338924, 228.41642088,
1673
+ 231.10858599, 233.54857391, 235.71504913, 237.86478872,
1674
+ 239.99274623, 242.10792167])
1675
+
1676
+ >>> # Can easily switch to alternative low-memory implementation
1677
+ >>> mda.interpolate(longitude, latitude, level, time, lowmem=True)
1678
+ array([220.44347694, 223.08900738, 225.74338924, 228.41642088,
1679
+ 231.10858599, 233.54857391, 235.71504913, 237.86478872,
1680
+ 239.99274623, 242.10792167])
1681
+ """
1682
+ if lowmem:
1683
+ return self._interp_lowmem(
1684
+ longitude,
1685
+ latitude,
1686
+ level,
1687
+ time,
1688
+ method=method,
1689
+ bounds_error=bounds_error,
1690
+ fill_value=fill_value,
1691
+ indices=indices,
1692
+ return_indices=return_indices,
1693
+ )
1694
+
1695
+ # Load if necessary
1696
+ if not self.in_memory:
1697
+ self._check_memory("Interpolation over")
1698
+ self.data.load()
1699
+
1700
+ # Convert all inputs to 1d arrays
1701
+ # Not validating against ndim >= 2
1702
+ longitude, latitude, level, time = np.atleast_1d(longitude, latitude, level, time)
1703
+
1704
+ # Pass off to the interp function, which does all the heavy lifting
1705
+ return interpolation.interp(
1706
+ longitude=longitude,
1707
+ latitude=latitude,
1708
+ level=level,
1709
+ time=time,
1710
+ da=self.data,
1711
+ method=method,
1712
+ bounds_error=bounds_error,
1713
+ fill_value=fill_value,
1714
+ localize=localize,
1715
+ indices=indices,
1716
+ return_indices=return_indices,
1717
+ )
1718
+
1719
+ def _interp_lowmem(
1720
+ self,
1721
+ longitude: float | npt.NDArray[np.floating],
1722
+ latitude: float | npt.NDArray[np.floating],
1723
+ level: float | npt.NDArray[np.floating],
1724
+ time: np.datetime64 | npt.NDArray[np.datetime64],
1725
+ *,
1726
+ method: str = "linear",
1727
+ bounds_error: bool = False,
1728
+ fill_value: float | np.float64 | None = np.nan,
1729
+ indices: interpolation.RGIArtifacts | None = None,
1730
+ return_indices: bool = False,
1731
+ ) -> npt.NDArray[np.floating] | tuple[npt.NDArray[np.floating], interpolation.RGIArtifacts]:
1732
+ """Interpolate values against underlying DataArray.
1733
+
1734
+ This method is used by :meth:`interpolate` when ``lowmem=True``.
1735
+ Parameters and return types are identical to :meth:`interpolate`, except
1736
+ that the ``localize`` keyword argument is omitted.
1737
+ """
1738
+ # Convert all inputs to 1d arrays
1739
+ # Not validating against ndim >= 2
1740
+ longitude, latitude, level, time = np.atleast_1d(longitude, latitude, level, time)
1741
+
1742
+ if bounds_error:
1743
+ _lowmem_boundscheck(time, self.data)
1744
+
1745
+ # Create buffers for holding interpolation output
1746
+ # Use np.full rather than np.empty so points not covered
1747
+ # by masks are filled with correct out-of-bounds values.
1748
+ out = np.full(longitude.shape, fill_value, dtype=self.data.dtype)
1749
+ if return_indices:
1750
+ rgi_artifacts = interpolation.RGIArtifacts(
1751
+ xi_indices=np.full((4, longitude.size), -1, dtype=np.int64),
1752
+ norm_distances=np.full((4, longitude.size), np.nan, dtype=np.float64),
1753
+ out_of_bounds=np.full((longitude.size,), True, dtype=np.bool_),
1754
+ )
1755
+
1756
+ # Iterate over portions of points between adjacent time steps in gridded data
1757
+ for mask in _lowmem_masks(time, self.data["time"].values):
1758
+ if mask is None or not np.any(mask):
1759
+ continue
1760
+
1761
+ lon_sl = longitude[mask]
1762
+ lat_sl = latitude[mask]
1763
+ lev_sl = level[mask]
1764
+ t_sl = time[mask]
1765
+ if indices is not None:
1766
+ indices_sl = interpolation.RGIArtifacts(
1767
+ xi_indices=indices.xi_indices[:, mask],
1768
+ norm_distances=indices.norm_distances[:, mask],
1769
+ out_of_bounds=indices.out_of_bounds[mask],
1770
+ )
1771
+ else:
1772
+ indices_sl = None
1773
+
1774
+ coords = {"longitude": lon_sl, "latitude": lat_sl, "level": lev_sl, "time": t_sl}
1775
+ if any(np.all(np.isnan(coord)) for coord in coords.values()):
1776
+ continue
1777
+ da = interpolation._localize(self.data, coords)
1778
+ if not da._in_memory:
1779
+ logger.debug(
1780
+ "Loading %s MB subset of %s into memory.",
1781
+ round(da.nbytes / 1_000_000, 2),
1782
+ da.name,
1783
+ )
1784
+ da.load()
1785
+
1786
+ if return_indices:
1787
+ out[mask], rgi_sl = interpolation.interp(
1788
+ longitude=lon_sl,
1789
+ latitude=lat_sl,
1790
+ level=lev_sl,
1791
+ time=t_sl,
1792
+ da=da,
1793
+ method=method,
1794
+ bounds_error=bounds_error,
1795
+ fill_value=fill_value,
1796
+ localize=False, # would be no-op; da is localized already
1797
+ indices=indices_sl,
1798
+ return_indices=return_indices,
1799
+ )
1800
+ rgi_artifacts.xi_indices[:, mask] = rgi_sl.xi_indices
1801
+ rgi_artifacts.norm_distances[:, mask] = rgi_sl.norm_distances
1802
+ rgi_artifacts.out_of_bounds[mask] = rgi_sl.out_of_bounds
1803
+ else:
1804
+ out[mask] = interpolation.interp(
1805
+ longitude=lon_sl,
1806
+ latitude=lat_sl,
1807
+ level=lev_sl,
1808
+ time=t_sl,
1809
+ da=da,
1810
+ method=method,
1811
+ bounds_error=bounds_error,
1812
+ fill_value=fill_value,
1813
+ localize=False, # would be no-op; da is localized already
1814
+ indices=indices_sl,
1815
+ return_indices=return_indices,
1816
+ )
1817
+
1818
+ if return_indices:
1819
+ return out, rgi_artifacts
1820
+ return out
1821
+
1822
+ def _check_memory(self, msg_start: str) -> None:
1823
+ """Check the memory usage of the underlying data.
1824
+
1825
+ If the data is larger than 4 GB, a warning is issued. If the data is
1826
+ larger than 32 GB, a RuntimeError is raised.
1827
+ """
1828
+ n_bytes = self.data.nbytes
1829
+ mb = round(n_bytes / int(1e6), 2)
1830
+ logger.debug("Loading %s into memory consumes %s MB.", self.name, mb)
1831
+
1832
+ n_gb = n_bytes // int(1e9)
1833
+ if n_gb <= 4:
1834
+ return
1835
+
1836
+ # Prevent something stupid
1837
+ msg = (
1838
+ f"{msg_start} MetDataArray {self.name} requires loading "
1839
+ f"at least {n_gb} GB of data into memory. Downselect data if possible. "
1840
+ "If working with a GeoVectorDataset instance, this can be achieved "
1841
+ "with the method 'downselect_met'."
1842
+ )
1843
+
1844
+ if n_gb > 32:
1845
+ raise RuntimeError(msg)
1846
+ warnings.warn(msg)
1847
+
1848
+ def save(self, **kwargs: Any) -> list[str]:
1849
+ """Save intermediate to :attr:`cachestore` as netcdf.
1850
+
1851
+ Load and restore using :meth:`load`.
1852
+
1853
+ Parameters
1854
+ ----------
1855
+ **kwargs : Any
1856
+ Keyword arguments passed directly to :func:`xarray.save_mfdataset`
1857
+
1858
+ Returns
1859
+ -------
1860
+ list[str]
1861
+ Returns filenames of saved files
1862
+ """
1863
+ dataset = self.data.to_dataset()
1864
+ return self._save(dataset, **kwargs)
1865
+
1866
+ @classmethod
1867
+ def load(
1868
+ cls,
1869
+ hash: str,
1870
+ cachestore: CacheStore | None = None,
1871
+ chunks: dict[str, int] | None = None,
1872
+ ) -> Self:
1873
+ """Load saved intermediate from :attr:`cachestore`.
1874
+
1875
+ Parameters
1876
+ ----------
1877
+ hash : str
1878
+ Saved hash to load.
1879
+ cachestore : CacheStore, optional
1880
+ Cache datastore to use for sourcing files.
1881
+ Defaults to DiskCacheStore.
1882
+ chunks : dict[str, int], optional
1883
+ Chunks kwarg passed to :func:`xarray.open_mfdataset()` when opening files.
1884
+
1885
+ Returns
1886
+ -------
1887
+ MetDataArray
1888
+ New MetDataArray with loaded data.
1889
+ """
1890
+ cachestore = cachestore or DiskCacheStore()
1891
+ chunks = chunks or {}
1892
+ data = _load(hash, cachestore, chunks)
1893
+ return cls(data[next(iter(data.data_vars))])
1894
+
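A minimal sketch of the cache round trip, assuming ``mda`` is an existing ``MetDataArray`` constructed with ``cachestore=DiskCacheStore()`` and that ``mda.hash`` (an attribute assumed to be inherited from the met base class, not shown in this excerpt) identifies the saved files.

from pycontrails import DiskCacheStore, MetDataArray

paths = mda.save()  # writes one netCDF file per time slice to mda.cachestore
restored = MetDataArray.load(mda.hash, cachestore=mda.cachestore)  # `hash` attribute assumed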
1895
+ @property
1896
+ def proportion(self) -> float:
1897
+ """Compute proportion of points with value 1.
1898
+
1899
+ Returns
1900
+ -------
1901
+ float
1902
+ Proportion of points with value 1
1903
+
1904
+ Raises
1905
+ ------
1906
+ NotImplementedError
1907
+ If instance does not contain binary data.
1908
+ """
1909
+ if not self.binary:
1910
+ raise NotImplementedError("proportion method is only implemented for binary fields")
1911
+
1912
+ return self.data.sum().values.item() / self.data.count().values.item()
1913
+
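A toy example of ``binary`` and ``proportion`` (coordinates and values are made up): one cell out of four is set to 1.

import numpy as np
import xarray as xr
from pycontrails import MetDataArray

da = xr.DataArray(
    [[[[1.0]], [[0.0]]], [[[0.0]], [[0.0]]]],  # shape (2, 2, 1, 1): a single "hit"
    dims=["longitude", "latitude", "level", "time"],
    coords={
        "longitude": [0.0, 1.0],
        "latitude": [0.0, 1.0],
        "level": [250.0],
        "time": [np.datetime64("2022-03-01T00")],
    },
)
mda = MetDataArray(da)
mda.binary      # True: every value is 0 or 1
mda.proportion  # 0.25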
1914
+ def find_edges(self) -> Self:
1915
+ """Find edges of regions.
1916
+
1917
+ Returns
1918
+ -------
1919
+ Self
1920
+ MetDataArray with a binary field, 1 on the edge of the regions,
1921
+ 0 outside and inside the regions.
1922
+
1923
+ Raises
1924
+ ------
1925
+ NotImplementedError
1926
+ If the instance is not binary.
1927
+ """
1928
+ if not self.binary:
1929
+ raise NotImplementedError("find_edges method is only implemented for binary fields")
1930
+
1931
+ # edge detection algorithm using differentiation to reduce the areas to lines
1932
+ def _edges(da: xr.DataArray) -> xr.DataArray:
1933
+ lat_diff = da.differentiate("latitude")
1934
+ lon_diff = da.differentiate("longitude")
1935
+ diff = da.where((lat_diff != 0) | (lon_diff != 0), 0)
1936
+
1937
+ # TODO: what is desired behavior here?
1938
+ # set boundaries to close contour regions
1939
+ diff[dict(longitude=0)] = da[dict(longitude=0)]
1940
+ diff[dict(longitude=-1)] = da[dict(longitude=-1)]
1941
+ diff[dict(latitude=0)] = da[dict(latitude=0)]
1942
+ diff[dict(latitude=-1)] = da[dict(latitude=-1)]
1943
+
1944
+ return diff
1945
+
1946
+ # load data into memory (required for value assignment in _edges())
1947
+ self.data.load()
1948
+
1949
+ data = self.data.groupby("level", squeeze=False).map(_edges)
1950
+ return type(self)(data, cachestore=self.cachestore)
1951
+
1952
+ def to_polygon_feature(
1953
+ self,
1954
+ level: float | int | None = None,
1955
+ time: np.datetime64 | datetime | None = None,
1956
+ fill_value: float = np.nan,
1957
+ iso_value: float | None = None,
1958
+ min_area: float = 0.0,
1959
+ epsilon: float = 0.0,
1960
+ lower_bound: bool = True,
1961
+ precision: int | None = None,
1962
+ interiors: bool = True,
1963
+ convex_hull: bool = False,
1964
+ include_altitude: bool = False,
1965
+ properties: dict[str, Any] | None = None,
1966
+ ) -> dict[str, Any]:
1967
+ """Create GeoJSON Feature artifact from spatial array on a single level and time slice.
1968
+
1969
+ Computed polygons always contain an exterior linear ring as defined by the
1970
+ `GeoJSON Polygon specification <https://www.rfc-editor.org/rfc/rfc7946.html#section-3.1.6>`.
1971
+ Polygons may also contain interior linear rings (holes). This method does not support
1972
+ nesting beyond the GeoJSON specification. See the :mod:`pycontrails.core.polygon`
1973
+ for additional polygon support.
1974
+
1975
+ .. versionchanged:: 0.25.12
1976
+
1977
+ The previous implementation included several additional parameters which have
1978
+ been removed:
1979
+
1980
+ - The ``approximate`` parameter
1981
+ - A ``path`` parameter to save output as JSON
1982
+ - Passing arbitrary kwargs to :func:`skimage.measure.find_contours`.
1983
+
1984
+ The new implementation adds parameters that were previously lacking:
1985
+
1986
+ - ``fill_value``
1987
+ - ``min_area``
1988
+ - ``include_altitude``
1989
+
1990
+ .. versionchanged:: 0.38.0
1991
+
1992
+ Change default value of ``epsilon`` from 0.15 to 0.
1993
+
1994
+ .. versionchanged:: 0.41.0
1995
+
1996
+ Convert continuous fields to binary fields before computing polygons.
1997
+ The parameters ``min_area`` and ``epsilon`` are now expressed in terms of
1998
+ longitude/latitude units instead of pixels.
1999
+
2000
+ Parameters
2001
+ ----------
2002
+ level : float, optional
2003
+ Level slice to create polygons.
2004
+ If the "level" coordinate is length 1, then the single level slice will be selected
2005
+ automatically.
2006
+ time : datetime, optional
2007
+ Time slice to create polygons.
2008
+ If the "time" coordinate is length 1, then the single time slice will be selected
2009
+ automatically.
2010
+ fill_value : float, optional
2011
+ Value used for filling missing data and for padding the underlying data array.
2012
+ Set to ``np.nan`` by default, which ensures that regions with missing data are
2013
+ never included in polygons.
2014
+ iso_value : float, optional
2015
+ Value in field to create iso-surface.
2016
+ Defaults to the average of the min and max value of the array. (This is the
2017
+ same convention as used by ``skimage``.)
2018
+ min_area : float, optional
2019
+ Minimum area of each polygon. Polygons with area less than ``min_area`` are
2020
+ not included in the output. The unit of this parameter is in longitude/latitude
2021
+ degrees squared. Set to 0 to omit any polygon filtering based on a minimal area
2022
+ conditional. By default, 0.0.
2023
+ epsilon : float, optional
2024
+ Control the extent to which the polygon is simplified. A value of 0 does not alter
2025
+ the geometry of the polygon. The unit of this parameter is in longitude/latitude
2026
+ degrees. By default, 0.0.
2027
+ lower_bound : bool, optional
2028
+ Whether to use ``iso_value`` as a lower or upper bound on values in polygon interiors.
2029
+ By default, True.
2030
+ precision : int, optional
2031
+ Number of decimal places to round coordinates to. If None, no rounding is performed.
2032
+ interiors : bool, optional
2033
+ If True, include interior linear rings (holes) in the output. True by default.
2034
+ convex_hull : bool, optional
2035
+ EXPERIMENTAL. If True, compute the convex hull of each polygon. Only implemented
2036
+ for depth=1. False by default. A warning is issued if the underlying algorithm
2037
+ fails to make valid polygons after computing the convex hull.
2038
+ include_altitude : bool, optional
2039
+ If True, include the array altitude [:math:`m`] as a z-coordinate in the
2040
+ `GeoJSON output <https://www.rfc-editor.org/rfc/rfc7946#section-3.1.1>`.
2041
+ False by default.
2042
+ properties : dict, optional
2043
+ Additional properties to include in the GeoJSON output. By default, None.
2044
+
2045
+ Returns
2046
+ -------
2047
+ dict[str, Any]
2048
+ Python representation of GeoJSON Feature with MultiPolygon geometry.
2049
+
2050
+ Notes
2051
+ -----
2052
+ :class:`Cocip` and :class:`CocipGrid` set some quantities to 0 and other quantities
2053
+ to ``np.nan`` in regions where no contrails form. When computing polygons from
2054
+ :class:`Cocip` or :class:`CocipGrid` output, take care that the choice of
2055
+ ``fill_value`` correctly includes or excludes contrail-free regions. See the
2056
+ :class:`Cocip` documentation for details about ``np.nan`` in model output.
2057
+
2058
+ See Also
2059
+ --------
2060
+ :meth:`to_polyhedra`
2061
+ :func:`polygons.find_multipolygons`
2062
+
2063
+ Examples
2064
+ --------
2065
+ >>> from pprint import pprint
2066
+ >>> from pycontrails.datalib.ecmwf import ERA5
2067
+ >>> era5 = ERA5("2022-03-01", variables="air_temperature", pressure_levels=250)
2068
+ >>> mda = era5.open_metdataset()["air_temperature"]
2069
+ >>> mda.shape
2070
+ (1440, 721, 1, 1)
2071
+
2072
+ >>> pprint(mda.to_polygon_feature(iso_value=239.5, precision=2, epsilon=0.1))
2073
+ {'geometry': {'coordinates': [[[[167.88, -22.5],
2074
+ [167.75, -22.38],
2075
+ [167.62, -22.5],
2076
+ [167.75, -22.62],
2077
+ [167.88, -22.5]]],
2078
+ [[[43.38, -33.5],
2079
+ [43.5, -34.12],
2080
+ [43.62, -33.5],
2081
+ [43.5, -33.38],
2082
+ [43.38, -33.5]]]],
2083
+ 'type': 'MultiPolygon'},
2084
+ 'properties': {},
2085
+ 'type': 'Feature'}
2086
+
2087
+ """
2088
+ if convex_hull and interiors:
2089
+ raise ValueError("Set 'interiors=False' to use the 'convex_hull' parameter.")
2090
+
2091
+ arr, altitude = _extract_2d_arr_and_altitude(self, level, time)
2092
+ if not include_altitude:
2093
+ altitude = None # this logic used below
2094
+ elif altitude is None:
2095
+ raise ValueError(
2096
+ "The parameter 'include_altitude' is True, but altitude is not "
2097
+ "found on MetDataArray instance. Either set altitude, or pass "
2098
+ "include_altitude=False."
2099
+ )
2100
+
2101
+ if not np.isnan(fill_value):
2102
+ np.nan_to_num(arr, copy=False, nan=fill_value)
2103
+
2104
+ # default iso_value
2105
+ if iso_value is None:
2106
+ iso_value = (np.nanmax(arr) + np.nanmin(arr)) / 2
2107
+ warnings.warn(f"The 'iso_value' parameter was not specified. Using value: {iso_value}")
2108
+
2109
+ # We'll get a nice error message if dependencies are not installed
2110
+ from pycontrails.core import polygon
2111
+
2112
+ # Convert to nested lists of coordinates for GeoJSON representation
2113
+ indexes = self.indexes
2114
+ longitude = indexes["longitude"].to_numpy()
2115
+ latitude = indexes["latitude"].to_numpy()
2116
+
2117
+ mp = polygon.find_multipolygon(
2118
+ arr,
2119
+ threshold=iso_value,
2120
+ min_area=min_area,
2121
+ epsilon=epsilon,
2122
+ lower_bound=lower_bound,
2123
+ interiors=interiors,
2124
+ convex_hull=convex_hull,
2125
+ longitude=longitude,
2126
+ latitude=latitude,
2127
+ precision=precision,
2128
+ )
2129
+
2130
+ return polygon.multipolygon_to_geojson(mp, altitude, properties)
2131
+
2132
+ def to_polygon_feature_collection(
2133
+ self,
2134
+ time: np.datetime64 | datetime | None = None,
2135
+ fill_value: float = np.nan,
2136
+ iso_value: float | None = None,
2137
+ min_area: float = 0.0,
2138
+ epsilon: float = 0.0,
2139
+ lower_bound: bool = True,
2140
+ precision: int | None = None,
2141
+ interiors: bool = True,
2142
+ convex_hull: bool = False,
2143
+ include_altitude: bool = False,
2144
+ properties: dict[str, Any] | None = None,
2145
+ ) -> dict[str, Any]:
2146
+ """Create GeoJSON FeatureCollection artifact from spatial array at time slice.
2147
+
2148
+ See the :meth:`to_polygon_feature` method for a description of the parameters.
2149
+
2150
+ Returns
2151
+ -------
2152
+ dict[str, Any]
2153
+ Python representation of GeoJSON FeatureCollection. This dictionary is
2154
+ composed of individual GeoJSON Features, one per :attr:`self.data["level"]`.
2155
+ """
2156
+ base_properties = properties or {}
2157
+ features = []
2158
+ for level in self.data["level"]:
2159
+ properties = base_properties.copy()
2160
+ properties.update(level=level.item())
2161
+ properties.update({f"level_{k}": v for k, v in self.data["level"].attrs.items()})
2162
+
2163
+ feature = self.to_polygon_feature(
2164
+ level=level,
2165
+ time=time,
2166
+ fill_value=fill_value,
2167
+ iso_value=iso_value,
2168
+ min_area=min_area,
2169
+ epsilon=epsilon,
2170
+ lower_bound=lower_bound,
2171
+ precision=precision,
2172
+ interiors=interiors,
2173
+ convex_hull=convex_hull,
2174
+ include_altitude=include_altitude,
2175
+ properties=properties,
2176
+ )
2177
+ features.append(feature)
2178
+
2179
+ return {
2180
+ "type": "FeatureCollection",
2181
+ "features": features,
2182
+ }
2183
+
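Continuing the ERA5 example from the :meth:`to_polygon_feature` docstring (so ``mda`` holds a single pressure level and requires ERA5 data access), a sketch of the FeatureCollection output:

fc = mda.to_polygon_feature_collection(iso_value=239.5)
fc["type"]                                # "FeatureCollection"
len(fc["features"])                       # 1: one Feature per pressure level
fc["features"][0]["properties"]["level"]  # 250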
2184
+ @overload
2185
+ def to_polyhedra(
2186
+ self,
2187
+ *,
2188
+ time: datetime | None = ...,
2189
+ iso_value: float = ...,
2190
+ simplify_fraction: float = ...,
2191
+ lower_bound: bool = ...,
2192
+ return_type: Literal["geojson"],
2193
+ path: str | None = ...,
2194
+ altitude_scale: float = ...,
2195
+ output_vertex_normals: bool = ...,
2196
+ closed: bool = ...,
2197
+ ) -> dict: ...
2198
+
2199
+ @overload
2200
+ def to_polyhedra(
2201
+ self,
2202
+ *,
2203
+ time: datetime | None = ...,
2204
+ iso_value: float = ...,
2205
+ simplify_fraction: float = ...,
2206
+ lower_bound: bool = ...,
2207
+ return_type: Literal["mesh"],
2208
+ path: str | None = ...,
2209
+ altitude_scale: float = ...,
2210
+ output_vertex_normals: bool = ...,
2211
+ closed: bool = ...,
2212
+ ) -> o3d.geometry.TriangleMesh: ...
2213
+
2214
+ def to_polyhedra(
2215
+ self,
2216
+ *,
2217
+ time: datetime | None = None,
2218
+ iso_value: float = 0.0,
2219
+ simplify_fraction: float = 1.0,
2220
+ lower_bound: bool = True,
2221
+ return_type: str = "geojson",
2222
+ path: str | None = None,
2223
+ altitude_scale: float = 1.0,
2224
+ output_vertex_normals: bool = False,
2225
+ closed: bool = True,
2226
+ ) -> dict | o3d.geometry.TriangleMesh:
2227
+ """Create a collection of polyhedra from spatial array corresponding to a single time slice.
2228
+
2229
+ Parameters
2230
+ ----------
2231
+ time : datetime, optional
2232
+ Time slice to create mesh.
2233
+ If the "time" coordinate is length 1, then the single time slice will be selected
2234
+ automatically.
2235
+ iso_value : float, optional
2236
+ Value in field to create iso-surface. Defaults to 0.
2237
+ simplify_fraction : float, optional
2238
+ Apply `open3d` `simplify_quadric_decimation` method to simplify the polyhedra geometry.
2239
+ This parameter must be in the half-open interval (0.0, 1.0].
2240
+ Defaults to 1.0, corresponding to no reduction.
2241
+ lower_bound : bool, optional
2242
+ Whether to use ``iso_value`` as a lower or upper bound on values in polyhedra interiors.
2243
+ By default, True.
2244
+ return_type : str, optional
2245
+ Must be one of "geojson" or "mesh". Defaults to "geojson".
2246
+ If "geojson", this method returns a dictionary representation of a geojson MultiPolygon
2247
+ object whose polygons are polyhedra faces.
2248
+ If "mesh", this method returns an `open3d` `TriangleMesh` instance.
2249
+ path : str, optional
2250
+ Output geojson or mesh to file.
2251
+ If `return_type` is "mesh", see `Open3D File I/O for Mesh
2252
+ <http://www.open3d.org/docs/release/tutorial/geometry/file_io.html#Mesh>`_ for
2253
+ file type options.
2254
+ altitude_scale : float, optional
2255
+ Rescale the altitude dimension of the mesh, [:math:`m`]
2256
+ output_vertex_normals : bool, optional
2257
+ If ``path`` is defined, write out vertex normals.
2258
+ Defaults to False.
2259
+ closed : bool, optional
2260
+ If True, pad spatial array along all axes to ensure polyhedra are "closed".
2261
+ This flag often gives rise to cleaner visualizations. Defaults to True.
2262
+
2263
+ Returns
2264
+ -------
2265
+ dict | open3d.geometry.TriangleMesh
2266
+ Python representation of geojson object or `Open3D Triangle Mesh
2267
+ <http://www.open3d.org/docs/release/tutorial/geometry/mesh.html>`_ depending on the
2268
+ `return_type` parameter.
2269
+
2270
+ Raises
2271
+ ------
2272
+ ModuleNotFoundError
2273
+ Method requires the `vis` optional dependencies
2274
+ ValueError
2275
+ If input parameters are invalid.
2276
+
2277
+ See Also
2278
+ --------
2279
+ :meth:`to_polygon_feature`
2280
+ :func:`skimage.measure.marching_cubes`
2281
+ :class:`open3d.geometry.TriangleMesh`
2282
+
2283
+ Notes
2284
+ -----
2285
+ Uses the `scikit-image Marching Cubes <https://scikit-image.org/docs/dev/auto_examples/edges/plot_marching_cubes.html>`_
2286
+ algorithm to reconstruct a surface from the point-cloud like arrays.
2287
+ """
2288
+ try:
2289
+ from skimage import measure
2290
+ except ModuleNotFoundError as e:
2291
+ dependencies.raise_module_not_found_error(
2292
+ name="MetDataArray.to_polyhedra method",
2293
+ package_name="scikit-image",
2294
+ pycontrails_optional_package="vis",
2295
+ module_not_found_error=e,
2296
+ )
2297
+
2298
+ try:
2299
+ import open3d as o3d
2300
+ except ModuleNotFoundError as e:
2301
+ dependencies.raise_module_not_found_error(
2302
+ name="MetDataArray.to_polyhedra method",
2303
+ package_name="open3d",
2304
+ pycontrails_optional_package="open3d",
2305
+ module_not_found_error=e,
2306
+ )
2307
+
2308
+ if len(self.data["level"]) == 1:
2309
+ raise ValueError(
2310
+ "Found single `level` coordinate in DataArray. This method requires at least two."
2311
+ )
2312
+
2313
+ # select time
2314
+ if time is None and len(self.data["time"]) == 1:
2315
+ time = self.data["time"].values[0]
2316
+
2317
+ if time is None:
2318
+ raise ValueError(
2319
+ "time input must be defined when the length of the time coordinates are > 1"
2320
+ )
2321
+
2322
+ if simplify_fraction > 1 or simplify_fraction <= 0:
2323
+ raise ValueError("Parameter `simplify_fraction` must be in the interval (0, 1].")
2324
+
2325
+ return_types = ["geojson", "mesh"]
2326
+ if return_type not in return_types:
2327
+ raise ValueError(f"Parameter `return_type` must be one of {', '.join(return_types)}")
2328
+
2329
+ # 3d array of longitude, latitude, altitude values
2330
+ volume = self.data.sel(time=time).values
2331
+
2332
+ # invert if iso_value is an upper bound on interior values
2333
+ if not lower_bound:
2334
+ volume = -volume
2335
+ iso_value = -iso_value
2336
+
2337
+ # convert from array index back to coordinates
2338
+ longitude = self.indexes["longitude"].values
2339
+ latitude = self.indexes["latitude"].values
2340
+ altitude = units.pl_to_m(self.indexes["level"].values)
2341
+
2342
+ # Pad volume on all axes to close the volumes
2343
+ if closed:
2344
+ # pad values to domain
2345
+ longitude0 = longitude[0] - (longitude[1] - longitude[0])
2346
+ longitude1 = longitude[-1] + longitude[-1] - longitude[-2]
2347
+ longitude = np.pad(longitude, pad_width=1, constant_values=(longitude0, longitude1))
2348
+
2349
+ latitude0 = latitude[0] - (latitude[1] - latitude[0])
2350
+ latitude1 = latitude[-1] + latitude[-1] - latitude[-2]
2351
+ latitude = np.pad(latitude, pad_width=1, constant_values=(latitude0, latitude1))
2352
+
2353
+ altitude0 = altitude[0] - (altitude[1] - altitude[0])
2354
+ altitude1 = altitude[-1] + altitude[-1] - altitude[-2]
2355
+ altitude = np.pad(altitude, pad_width=1, constant_values=(altitude0, altitude1))
2356
+
2357
+ # Pad along axes to ensure polygons are closed
2358
+ volume = np.pad(volume, pad_width=1, constant_values=iso_value)
2359
+
2360
+ # Use marching cubes to obtain the surface mesh
2361
+ # Coordinates of verts are indexes of volume array
2362
+ verts, faces, normals, _ = measure.marching_cubes(
2363
+ volume, iso_value, allow_degenerate=False, gradient_direction="ascent"
2364
+ )
2365
+
2366
+ # Convert from indexes to longitude, latitude, altitude values
2367
+ verts[:, 0] = longitude[verts[:, 0].astype(int)]
2368
+ verts[:, 1] = latitude[verts[:, 1].astype(int)]
2369
+ verts[:, 2] = altitude[verts[:, 2].astype(int)]
2370
+
2371
+ # rescale altitude
2372
+ verts[:, 2] = verts[:, 2] * altitude_scale
2373
+
2374
+ # create mesh in open3d
2375
+ mesh = o3d.geometry.TriangleMesh()
2376
+ mesh.vertices = o3d.utility.Vector3dVector(verts)
2377
+ mesh.triangles = o3d.utility.Vector3iVector(faces)
2378
+ mesh.vertex_normals = o3d.utility.Vector3dVector(normals)
2379
+
2380
+ # simplify mesh according to sim
2381
+ if simplify_fraction < 1:
2382
+ target_n_triangles = int(faces.shape[0] * simplify_fraction)
2383
+ mesh = mesh.simplify_quadric_decimation(target_number_of_triangles=target_n_triangles)
2384
+ mesh.compute_vertex_normals()
2385
+
2386
+ if path is not None:
2387
+ path = str(pathlib.Path(path).absolute())
2388
+
2389
+ if return_type == "geojson":
2390
+ verts = np.round(
2391
+ np.asarray(mesh.vertices), decimals=4
2392
+ ) # rounding to reduce the size of resultant json arrays
2393
+ faces = np.asarray(mesh.triangles)
2394
+
2395
+ # TODO: technically this is not valid GeoJSON because each polygon (triangle)
2396
+ # does not have the last element equal to the first (not a linear ring)
2397
+ # but it still works for now in Deck.GL
2398
+ coords = [[verts[face].tolist()] for face in faces]
2399
+
2400
+ geojson = {
2401
+ "type": "Feature",
2402
+ "properties": {},
2403
+ "geometry": {"type": "MultiPolygon", "coordinates": coords},
2404
+ }
2405
+
2406
+ if path is not None:
2407
+ if not path.endswith(".json"):
2408
+ path += ".json"
2409
+ with open(path, "w") as file:
2410
+ json.dump(geojson, file)
2411
+ return geojson
2412
+
2413
+ if path is not None:
2414
+ o3d.io.write_triangle_mesh(
2415
+ path,
2416
+ mesh,
2417
+ write_ascii=False,
2418
+ compressed=True,
2419
+ write_vertex_normals=output_vertex_normals,
2420
+ write_vertex_colors=False,
2421
+ write_triangle_uvs=True,
2422
+ print_progress=False,
2423
+ )
2424
+ return mesh
2425
+
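A usage sketch for the polyhedra export, assuming ``mda`` holds a single time slice with at least two pressure levels and that the ``vis`` and ``open3d`` optional dependencies are installed (the iso-value and output filename are arbitrary):

# GeoJSON export: one MultiPolygon whose polygons are the mesh faces
geojson = mda.to_polyhedra(iso_value=1e-4, return_type="geojson", simplify_fraction=0.5)

# Mesh export: an open3d TriangleMesh, optionally written straight to disk
mesh = mda.to_polyhedra(
    iso_value=1e-4,
    return_type="mesh",
    simplify_fraction=0.5,
    path="contrail_volume.ply",  # arbitrary output filename
)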
2426
+ @override
2427
+ def broadcast_coords(self, name: str) -> xr.DataArray:
2428
+ da = xr.ones_like(self.data) * self.data[name]
2429
+ da.name = name
2430
+
2431
+ return da
2432
+
2433
+
2434
+ def _is_wrapped(longitude: np.ndarray) -> bool:
2435
+ """Check if ``longitude`` covers ``[-180, 180]``."""
2436
+ return longitude[0] <= -180.0 and longitude[-1] >= 180.0
2437
+
2438
+
2439
+ def _is_zarr(ds: xr.Dataset | xr.DataArray) -> bool:
2440
+ """Check if ``ds`` appears to be Zarr-based.
2441
+
2442
+ Neither ``xarray`` nor ``dask`` readily expose such information, so
2443
+ implementation is very narrow and brittle.
2444
+ """
2445
+ if isinstance(ds, xr.Dataset):
2446
+ # Attempt 1: xarray binds the zarr close function to the instance
2447
+ # This gives us an indication that the data is Zarr-based
2448
+ try:
2449
+ if ds._close.__func__.__qualname__ == "ZarrStore.close": # type: ignore
2450
+ return True
2451
+ except AttributeError:
2452
+ pass
2453
+
2454
+ # Grab the first data variable, get underlying DataArray
2455
+ ds = ds[next(iter(ds.data_vars))]
2456
+
2457
+ # Attempt 2: Examine the dask graph
2458
+ darr = ds.variable._data # dask array in some cases
2459
+ try:
2460
+ # Get the first dask instruction
2461
+ dask0 = darr.dask[next(iter(darr.dask))] # type: ignore[union-attr]
2462
+ except AttributeError:
2463
+ return False
2464
+ return dask0.array.array.array.__class__.__name__ == "ZarrArrayWrapper"
2465
+
2466
+
2467
+ def shift_longitude(data: XArrayType, bound: float = -180.0) -> XArrayType:
2468
+ """Shift longitude values from any input domain to [bound, 360 + bound) domain.
2469
+
2470
+ Sorts data by ascending longitude values.
2471
+
2472
+
2473
+ Parameters
2474
+ ----------
2475
+ data : XArrayType
2476
+ :class:`xr.Dataset` or :class:`xr.DataArray` with longitude dimension
2477
+ bound : float, optional
2478
+ Lower bound of the domain.
2479
+ Output domain will be [bound, 360 + bound).
2480
+ Defaults to -180, which results in longitude domain [-180, 180).
2481
+
2482
+
2483
+ Returns
2484
+ -------
2485
+ XArrayType
2486
+ :class:`xr.Dataset` or :class:`xr.DataArray` with longitude values on [bound, 360 + bound).
2487
+ """
2488
+ return data.assign_coords(
2489
+ longitude=((data["longitude"].values - bound) % 360.0) + bound
2490
+ ).sortby("longitude", ascending=True)
2491
+
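A small self-contained example of the shift (data values are arbitrary): a grid stored on [0, 360) is moved onto [-180, 180) and re-sorted.

import numpy as np
import xarray as xr
from pycontrails.core.met import shift_longitude

da = xr.DataArray(
    np.arange(4.0),
    dims="longitude",
    coords={"longitude": [0.0, 90.0, 180.0, 270.0]},  # [0, 360) style grid
)
shifted = shift_longitude(da)
shifted["longitude"].values  # array([-180., -90., 0., 90.])
shifted.values               # array([2., 3., 0., 1.]): data reordered with the coords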
2492
+
2493
+ def _wrap_longitude(data: XArrayType) -> XArrayType:
2494
+ """Wrap longitude grid coordinates.
2495
+
2496
+ This function assumes the longitude dimension on ``data``:
2497
+
2498
+ - is sorted
2499
+ - is contained in [-180, 180]
2500
+
2501
+ These assumptions are checked by :class:`MetDataset` and :class`MetDataArray`
2502
+ constructors.
2503
+
2504
+ .. versionchanged:: 0.26.0
2505
+
2506
+ This function now ensures every value in the interval ``[-180, 180]``
2507
+ is covered by the longitude dimension of the returned object. See
2508
+ :meth:`MetDataset.is_wrapped` for more details.
2509
+
2510
+
2511
+ Parameters
2512
+ ----------
2513
+ data : XArrayType
2514
+ :class:`xr.Dataset` or :class:`xr.DataArray` with longitude dimension
2515
+
2516
+ Returns
2517
+ -------
2518
+ XArrayType
2519
+ Copy of :class:`xr.Dataset` or :class:`xr.DataArray` with wrapped longitude values.
2520
+
2521
+ Raises
2522
+ ------
2523
+ ValueError
2524
+ If longitude values are already wrapped.
2525
+ """
2526
+ lon = data._indexes["longitude"].index.to_numpy() # type: ignore[attr-defined]
2527
+ if _is_wrapped(lon):
2528
+ raise ValueError("Longitude values are already wrapped")
2529
+
2530
+ lon0 = lon[0]
2531
+ lon1 = lon[-1]
2532
+
2533
+ # Try to prevent something stupid
2534
+ if lon1 - lon0 < 330.0:
2535
+ warnings.warn("Wrapping longitude will create a large spatial gap of more than 30 degrees.")
2536
+
2537
+ objs = [data]
2538
+ if lon0 > -180.0: # if the lowest longitude is not low enough, duplicate highest
2539
+ dup1 = data.sel(longitude=[lon1]).assign_coords(longitude=[lon1 - 360.0])
2540
+ objs.insert(0, dup1)
2541
+ if lon1 < 180.0: # if the highest longitude is not high enough, duplicate lowest
2542
+ dup0 = data.sel(longitude=[lon0]).assign_coords(longitude=[lon0 + 360.0])
2543
+ objs.append(dup0)
2544
+
2545
+ # Because we explicitly raise a ValueError if longitude already wrapped,
2546
+ # we know that len(objs) > 1, so the concatenation here is nontrivial
2547
+ wrapped = xr.concat(objs, dim="longitude")
2548
+ wrapped["longitude"] = wrapped["longitude"].astype(lon.dtype, copy=False)
2549
+
2550
+ # If there is only one longitude chunk in parameter data, increment
2551
+ # data.chunks can be None, using getattr for extra safety
2552
+ # NOTE: This probably doesn't play well with Zarr data ...
2553
+ # we don't want to be rechunking them.
2554
+ # Ideally we'd raise if data was Zarr-based
2555
+ chunks = getattr(data, "chunks", None) or {}
2556
+ chunks = dict(chunks) # chunks can be frozen
2557
+ lon_chunks = chunks.get("longitude", ())
2558
+ if len(lon_chunks) == 1:
2559
+ chunks["longitude"] = (lon_chunks[0] + len(objs) - 1,)
2560
+ wrapped = wrapped.chunk(chunks)
2561
+
2562
+ return wrapped
2563
+
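To illustrate the duplication at the antimeridian (a sketch only; ``_wrap_longitude`` is private and is normally reached through the ``wrap_longitude`` constructor flag): a grid ending at 179 gains a copy of its -180 column relabelled +180, so the full interval [-180, 180] is covered.

import numpy as np
import xarray as xr
from pycontrails.core.met import _wrap_longitude

lon = np.arange(-180.0, 180.0, 1.0)  # [-180, ..., 179]: +180 is not covered
da = xr.DataArray(np.zeros(lon.size), dims="longitude", coords={"longitude": lon})
wrapped = _wrap_longitude(da)
wrapped["longitude"].values[[0, -1]]  # array([-180., 180.])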
2564
+
2565
+ def _extract_2d_arr_and_altitude(
2566
+ mda: MetDataArray,
2567
+ level: float | int | None,
2568
+ time: np.datetime64 | datetime | None,
2569
+ ) -> tuple[np.ndarray, float | None]:
2570
+ """Extract underlying 2D array indexed by longitude and latitude.
2571
+
2572
+ Parameters
2573
+ ----------
2574
+ mda : MetDataArray
2575
+ MetDataArray to extract from
2576
+ level : float | int | None
2577
+ Pressure level to slice at
2578
+ time : np.datetime64 | datetime
2579
+ Time to slice at
2580
+
2581
+ Returns
2582
+ -------
2583
+ arr : np.ndarray
2584
+ Copy of 2D array at given level and time
2585
+ altitude : float | None
2586
+ Altitude of slice [:math:`m`]. None if "altitude" not found on ``mda``
2587
+ (i.e., for a surface-level :class:`MetDataArray`).
2588
+ """
2589
+ # Determine level if not specified
2590
+ if level is None:
2591
+ level_coord = mda.indexes["level"].values
2592
+ if len(level_coord) == 1:
2593
+ level = level_coord[0]
2594
+ else:
2595
+ raise ValueError(
2596
+ "Parameter 'level' must be defined when the length of the 'level' "
2597
+ "coordinates is not 1."
2598
+ )
2599
+
2600
+ # Determine time if not specified
2601
+ if time is None:
2602
+ time_coord = mda.indexes["time"].values
2603
+ if len(time_coord) == 1:
2604
+ time = time_coord[0]
2605
+ else:
2606
+ raise ValueError(
2607
+ "Parameter 'time' must be defined when the length of the 'time' "
2608
+ "coordinates is not 1."
2609
+ )
2610
+
2611
+ da = mda.data.sel(level=level, time=time)
2612
+ arr = da.values.copy()
2613
+ if arr.ndim != 2:
2614
+ raise RuntimeError("Malformed data array")
2615
+
2616
+ try:
2617
+ altitude = da["altitude"].values.item() # item not implemented on dask arrays
2618
+ except KeyError:
2619
+ altitude = None
2620
+ else:
2621
+ altitude = round(altitude)
2622
+
2623
+ return arr, altitude
2624
+
2625
+
2626
+ def downselect(data: XArrayType, bbox: tuple[float, ...]) -> XArrayType:
2627
+ """Downselect :class:`xr.Dataset` or :class:`xr.DataArray` with spatial bounding box.
2628
+
2629
+ Parameters
2630
+ ----------
2631
+ data : XArrayType
2632
+ xr.Dataset or xr.DataArray to downselect
2633
+ bbox : tuple[float, ...]
2634
+ Tuple of coordinates defining a spatial bounding box in WGS84 coordinates.
2635
+
2636
+ - For 2D queries, ``bbox`` takes the form ``(west, south, east, north)``
2637
+ - For 3D queries, ``bbox`` takes the form
2638
+ ``(west, south, min-level, east, north, max-level)``
2639
+
2640
+ with level defined in [:math:`hPa`].
2641
+
2642
+ Returns
2643
+ -------
2644
+ XArrayType
2645
+ Downselected xr.Dataset or xr.DataArray
2646
+
2647
+ Raises
2648
+ ------
2649
+ ValueError
2650
+ If parameter ``bbox`` has wrong length.
2651
+ """
2652
+ if len(bbox) == 4:
2653
+ west, south, east, north = bbox
2654
+ level_min = -np.inf
2655
+ level_max = np.inf
2656
+
2657
+ elif len(bbox) == 6:
2658
+ west, south, level_min, east, north, level_max = bbox
2659
+
2660
+ else:
2661
+ raise ValueError(
2662
+ f"bbox {bbox} is not length 4 [west, south, east, north] "
2663
+ "or length 6 [west, south, min-level, east, north, max-level]"
2664
+ )
2665
+
2666
+ if west <= east:
2667
+ # Return a view of the data
2668
+ # If data is lazy, this will not load the data
2669
+ return data.sel(
2670
+ longitude=slice(west, east),
2671
+ latitude=slice(south, north),
2672
+ level=slice(level_min, level_max),
2673
+ )
2674
+
2675
+ # In this case, the bbox spans the antimeridian
2676
+ # If data is lazy, this will load the data (data.where is not lazy AFAIK)
2677
+ cond = (
2678
+ (data["latitude"] >= south)
2679
+ & (data["latitude"] <= north)
2680
+ & (data["level"] >= level_min)
2681
+ & (data["level"] <= level_max)
2682
+ & ((data["longitude"] >= west) | (data["longitude"] <= east))
2683
+ )
2684
+ return data.where(cond, drop=True)
2685
+
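Two small queries against a toy dataset (coordinates are arbitrary): a plain 2D box, and a 3D box spanning the antimeridian, which falls through to the slower ``where`` branch.

import numpy as np
import xarray as xr
from pycontrails.core.met import downselect

ds = xr.Dataset(
    {"v": (("longitude", "latitude", "level"), np.zeros((8, 5, 2)))},
    coords={
        "longitude": np.arange(-180.0, 180.0, 45.0),  # [-180, -135, ..., 135]
        "latitude": np.arange(-60.0, 61.0, 30.0),     # [-60, -30, 0, 30, 60]
        "level": [250.0, 300.0],
    },
)

# 2D query: (west, south, east, north)
atlantic = downselect(ds, bbox=(-60.0, 0.0, 0.0, 60.0))

# 3D query across the antimeridian: (west, south, min-level, east, north, max-level)
pacific = downselect(ds, bbox=(135.0, -30.0, 200.0, -135.0, 30.0, 300.0))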
2686
+
2687
+ def standardize_variables(ds: xr.Dataset, variables: Iterable[MetVariable]) -> xr.Dataset:
2688
+ """Rename all variables in dataset from short name to standard name.
2689
+
2690
+ This function does not change any variables in ``ds`` that are not found in ``variables``.
2691
+
2692
+ When there are multiple variables with the same short name, the last one is used.
2693
+
2694
+ Parameters
2695
+ ----------
2696
+ ds : DatasetType
2697
+ An :class:`xr.Dataset`.
2698
+ variables : Iterable[MetVariable]
2699
+ Data source variables
2700
+
2701
+ Returns
2702
+ -------
2703
+ DatasetType
2704
+ Dataset with variables renamed to standard names
2705
+ """
2706
+ variables_dict: dict[Hashable, str] = {v.short_name: v.standard_name for v in variables}
2707
+ name_dict = {var: variables_dict[var] for var in ds.data_vars if var in variables_dict}
2708
+ return ds.rename(name_dict)
2709
+
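For example (a toy dataset, relying on the standard variable definitions in :mod:`pycontrails.core.met_var`), ECMWF-style short names are renamed to standard names while unrecognised variables are left alone:

import numpy as np
import xarray as xr
from pycontrails.core.met import standardize_variables
from pycontrails.core.met_var import AirTemperature, SpecificHumidity

ds = xr.Dataset(
    {
        "t": ("x", np.array([220.0])),    # ECMWF short name for air temperature
        "q": ("x", np.array([1.0e-5])),   # ECMWF short name for specific humidity
        "extra": ("x", np.array([0.0])),  # not in `variables`: left unchanged
    }
)
ds = standardize_variables(ds, [AirTemperature, SpecificHumidity])
list(ds.data_vars)  # ["air_temperature", "specific_humidity", "extra"]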
2710
+
2711
+ def originates_from_ecmwf(met: MetDataset | MetDataArray) -> bool:
2712
+ """Check if data appears to have originated from an ECMWF source.
2713
+
2714
+ .. versionadded:: 0.27.0
2715
+
2716
+ Experimental. Implementation is brittle.
2717
+
2718
+ Parameters
2719
+ ----------
2720
+ met : MetDataset | MetDataArray
2721
+ Dataset or array to inspect.
2722
+
2723
+ Returns
2724
+ -------
2725
+ bool
2726
+ True if data appears to be derived from an ECMWF source.
2727
+
2728
+ See Also
2729
+ --------
2730
+ - :class:`ERA5`
2731
+ - :class:`HRES`
2732
+
2733
+ """
2734
+ if isinstance(met, MetDataset):
2735
+ try:
2736
+ return met.provider_attr == "ECMWF"
2737
+ except KeyError:
2738
+ pass
2739
+ return "ecmwf" in met.attrs.get("history", "")
2740
+
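A usage sketch (requires ERA5 data access, as in the docstring examples earlier in this module):

from pycontrails.datalib.ecmwf import ERA5
from pycontrails.core.met import originates_from_ecmwf

met = ERA5("2022-03-01", variables="air_temperature", pressure_levels=250).open_metdataset()
originates_from_ecmwf(met)  # True: the ERA5 datalib tags its output with an ECMWF provider attribute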
2741
+
2742
+ def _load(hash: str, cachestore: CacheStore, chunks: dict[str, int]) -> xr.Dataset:
2743
+ """Load xarray data from hash.
2744
+
2745
+ Parameters
2746
+ ----------
2747
+ hash : str
2748
+ Hash of the cached files to load.
2749
+ cachestore : CacheStore
2750
+ Cache datastore used to source the files.
2751
+ chunks : dict[str, int]
2752
+ Chunks kwarg passed to :func:`xarray.open_mfdataset` when opening files.
2753
+
2754
+ Returns
2755
+ -------
2756
+ xr.Dataset
2757
+ Dataset loaded from the cache.
2758
+ """
2759
+ disk_path = cachestore.get(f"{hash}*.nc")
2760
+ return xr.open_mfdataset(disk_path, chunks=chunks)
2761
+
2762
+
2763
+ def _add_vertical_coords(data: XArrayType) -> XArrayType:
2764
+ """Add "air_pressure" and "altitude" coordinates to data.
2765
+
2766
+ .. versionchanged:: 0.52.1
2767
+ Ensure that the ``dtype`` of the additional vertical coordinates agree
2768
+ with the ``dtype`` of the underlying gridded data.
2769
+ """
2770
+
2771
+ data["level"].attrs.update(units="hPa", long_name="Pressure", positive="down")
2772
+
2773
+ # XXX: use the dtype of the data to determine the precision of these coordinates
2774
+ # There are two competing conventions here:
2775
+ # - coordinate data should be float64
2776
+ # - gridded data is typically float32
2777
+ # - air_pressure and altitude often play both roles
2778
+ # It is more important for air_pressure and altitude to be grid-aligned than to be
2779
+ # coordinate-aligned, so we use the dtype of the data to determine the precision of
2780
+ # these coordinates
2781
+ dtype = (
2782
+ np.result_type(*data.data_vars.values(), np.float32)
2783
+ if isinstance(data, xr.Dataset)
2784
+ else data.dtype
2785
+ )
2786
+
2787
+ level = data["level"].values
2788
+
2789
+ if "air_pressure" not in data.coords:
2790
+ data = data.assign_coords(air_pressure=("level", level * 100.0))
2791
+ data.coords["air_pressure"].attrs.update(
2792
+ standard_name=AirPressure.standard_name,
2793
+ long_name=AirPressure.long_name,
2794
+ units=AirPressure.units,
2795
+ )
2796
+ if data.coords["air_pressure"].dtype != dtype:
2797
+ data.coords["air_pressure"] = data.coords["air_pressure"].astype(dtype, copy=False)
2798
+
2799
+ if "altitude" not in data.coords:
2800
+ data = data.assign_coords(altitude=("level", units.pl_to_m(level)))
2801
+ data.coords["altitude"].attrs.update(
2802
+ standard_name=Altitude.standard_name,
2803
+ long_name=Altitude.long_name,
2804
+ units=Altitude.units,
2805
+ )
2806
+ if data.coords["altitude"].dtype != dtype:
2807
+ data.coords["altitude"] = data.coords["altitude"].astype(dtype, copy=False)
2808
+
2809
+ return data
2810
+
2811
+
2812
+ def _lowmem_boundscheck(time: npt.NDArray[np.datetime64], da: xr.DataArray) -> None:
2813
+ """Extra bounds check required with low-memory interpolation strategy.
2814
+
2815
+ Because the main loop in `_interp_lowmem` processes points between time steps
2816
+ in gridded data, it will never encounter points that are out-of-bounds in time
2817
+ and may fail to produce requested out-of-bounds errors.
2818
+ """
2819
+ da_time = da["time"].to_numpy()
2820
+ if not np.all((time >= da_time.min()) & (time <= da_time.max())):
2821
+ axis = da.get_axis_num("time")
2822
+ msg = f"One of the requested xi is out of bounds in dimension {axis}"
2823
+ raise ValueError(msg)
2824
+
2825
+
2826
+ def _lowmem_masks(
2827
+ time: npt.NDArray[np.datetime64], t_met: npt.NDArray[np.datetime64]
2828
+ ) -> Generator[npt.NDArray[np.bool_], None, None]:
2829
+ """Generate sequence of masks for low-memory interpolation."""
2830
+ t_met_max = t_met.max()
2831
+ t_met_min = t_met.min()
2832
+ inbounds = (time >= t_met_min) & (time <= t_met_max)
2833
+ if not np.any(inbounds):
2834
+ return
2835
+
2836
+ earliest = np.nanmin(time)
2837
+ istart = 0 if earliest < t_met_min else np.flatnonzero(t_met <= earliest).max()
2838
+ latest = np.nanmax(time)
2839
+ iend = t_met.size - 1 if latest > t_met_max else np.flatnonzero(t_met >= latest).min()
2840
+ if istart == iend:
2841
+ yield inbounds
2842
+ return
2843
+
2844
+ # Sequence of masks covers elements in time in the interval [t_met[istart], t_met[iend]].
2845
+ # The first iteration masks elements in the interval [t_met[istart], t_met[istart+1]]
2846
+ # (inclusive of both endpoints).
2847
+ # Subsequent iterations mask elements in the interval (t_met[i], t_met[i+1]]
2848
+ # (inclusive of right endpoint only).
2849
+ for i in range(istart, iend):
2850
+ mask = ((time >= t_met[i]) if i == istart else (time > t_met[i])) & (time <= t_met[i + 1])
2851
+ if np.any(mask):
2852
+ yield mask
2853
+
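A quick illustration of the mask sequence (arbitrary timestamps): points are split into one group per gridded time interval, with only the first interval closed on both ends.

import numpy as np
from pycontrails.core.met import _lowmem_masks

t_met = np.array(["2022-03-01T00", "2022-03-01T01", "2022-03-01T02"], dtype="datetime64[h]")
time = np.array(
    ["2022-03-01T00:10", "2022-03-01T00:50", "2022-03-01T01:30"], dtype="datetime64[m]"
)
list(_lowmem_masks(time, t_met))
# [array([ True,  True, False]), array([False, False,  True])]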
2854
+
2855
+ def maybe_downselect_mds(
2856
+ big_mds: MetDataset,
2857
+ little_mds: MetDataset | None,
2858
+ t0: np.datetime64,
2859
+ t1: np.datetime64,
2860
+ ) -> MetDataset:
2861
+ """Possibly downselect ``big_mds`` in the time domain to cover ``[t0, t1]``.
2862
+
2863
+ If possible, ``little_mds`` is recycled to avoid re-loading data.
2864
+
2865
+ This implementation assumes ``t0 <= t1``, but this is not enforced.
2866
+
2867
+ If ``little_mds`` already covers the time range, it is returned as-is.
2868
+
2869
+ If ``big_mds`` doesn't cover the time range, no error is raised.
2870
+
2871
+ Parameters
2872
+ ----------
2873
+ big_mds : MetDataset
2874
+ Larger MetDataset
2875
+ little_mds : MetDataset | None
2876
+ Smaller MetDataset. This is assumed to be a subset of ``big_mds``,
2877
+ though the implementation may work if this is not the case.
2878
+ t0, t1 : np.datetime64
2879
+ Time range to cover
2880
+
2881
+ Returns
2882
+ -------
2883
+ MetDataset
2884
+ MetDataset covering the time range ``[t0, t1]`` comprised of data from
2885
+ ``little_mds`` when possible, otherwise from ``big_mds``.
2886
+ """
2887
+ if little_mds is None:
2888
+ big_time = big_mds.indexes["time"].values
2889
+ i0 = np.searchsorted(big_time, t0, side="right").item()
2890
+ i0 = max(0, i0 - 1)
2891
+ i1 = np.searchsorted(big_time, t1, side="left").item()
2892
+ i1 = min(i1 + 1, big_time.size)
2893
+ return MetDataset._from_fastpath(big_mds.data.isel(time=slice(i0, i1)))
2894
+
2895
+ little_time = little_mds.indexes["time"].values
2896
+ if t0 >= little_time[0] and t1 <= little_time[-1]:
2897
+ return little_mds
2898
+
2899
+ big_time = big_mds.indexes["time"].values
2900
+ i0 = np.searchsorted(big_time, t0, side="right").item()
2901
+ i0 = max(0, i0 - 1)
2902
+ i1 = np.searchsorted(big_time, t1, side="left").item()
2903
+ i1 = min(i1 + 1, big_time.size)
2904
+ big_ds = big_mds.data.isel(time=slice(i0, i1))
2905
+ big_time = big_ds._indexes["time"].index.values # type: ignore[attr-defined]
2906
+
2907
+ # Select exactly the times in big_ds that are not in little_ds
2908
+ _, little_indices, big_indices = np.intersect1d(
2909
+ little_time, big_time, assume_unique=True, return_indices=True
2910
+ )
2911
+ little_ds = little_mds.data.isel(time=little_indices)
2912
+ filt = np.ones_like(big_time, dtype=bool)
2913
+ filt[big_indices] = False
2914
+ big_ds = big_ds.isel(time=filt)
2915
+
2916
+ # Manually load relevant parts of big_ds into memory before xr.concat
2917
+ # It appears that without this, xr.concat will forget the in-memory
2918
+ # arrays in little_ds
2919
+ for var, da in little_ds.items():
2920
+ if da._in_memory:
2921
+ da2 = big_ds[var]
2922
+ if not da2._in_memory:
2923
+ da2.load()
2924
+
2925
+ ds = xr.concat([little_ds, big_ds], dim="time")
2926
+ if not ds._indexes["time"].index.is_monotonic_increasing: # type: ignore[attr-defined]
2927
+ # Rarely would we enter this: t0 would have to be before the first
2928
+ # time in little_mds, and the various advection-based models generally
2929
+ # proceed forward in time.
2930
+ ds = ds.sortby("time")
2931
+ return MetDataset._from_fastpath(ds)
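A sketch of the intended calling pattern (``met`` is an existing MetDataset; everything else is illustrative): an advection-style loop steps forward in time, recycling the previous slice whenever it still covers the new window so already-loaded arrays are not re-read.

little_mds = None
times = met.indexes["time"].values
for t0, t1 in zip(times[:-1], times[1:]):
    little_mds = maybe_downselect_mds(met, little_mds, t0, t1)
    # interpolate or advect against little_mds here; it spans only [t0, t1]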